Object Detection - Detectron wandb 연결하기
05 Oct 2021 | Deep Learning
mmdetection이나 yolo와는 달리 detectron2는 wandb 연동 api를 제공해주지 않는 것 같다. (아니면 죄송합니다)
그래서 직접 구현했다. trainer 부분을 override해주는 방식으로 접근하면 감이 잡힌다.
DefaultTrainer 구조 탐험
여기서 구현부를 볼 수 있다.
다 이해할 필요는 없고, pytorch 기반으로 되어 있으니 결국 hook base로 로그와 같은 것들이 돌아가고 있음을 볼 수 있다.
build_hooks(self)에서 hook들을 등록하는 것을 볼 수 있는데, 함수의 일부분만 바꿔야 되기 때문에 전체를 다 긁어온 다음 override 했다.
Override
from detectron2.engine import DefaultTrainer
...
class Trainer(DefaultTrainer):
    """DefaultTrainer subclass whose hook list also reports metrics to Weights & Biases.

    Only ``build_hooks`` is shown here; the other members of the class are
    elided in this snippet (the ``...`` below).
    """

    ...

    # Override
    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.

        Everything up to the evaluation hook is copied from the parent
        implementation. Only the writer section at the end is customised:
        the regular writers run every 10 iterations and a separate
        ``WandB_Printer`` runs every iteration.

        Returns:
            list[HookBase]:
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        # ------- Nothing above this line needs to be changed --------
        if comm.is_main_process():
            # Run writers in the end, so that evaluation metrics are written.
            # NOTE(review): CustomMetricPrinter and self.showTQDM are defined
            # outside this snippet - confirm their contract there.
            writerList = [
                CustomMetricPrinter(self.showTQDM, self.cfg.SOLVER.MAX_ITER),
                JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
                TensorboardXWriter(self.cfg.OUTPUT_DIR),
            ]
            ret.append(hooks.PeriodicWriter(writerList, period=10))

            # The run name is the second "/"-separated component of OUTPUT_DIR
            # (e.g. "./output" -> "output"). Fall back to the whole string when
            # the path has no "/", instead of failing with an IndexError.
            dir_parts = self.cfg.OUTPUT_DIR.split("/")
            run_name = dir_parts[1] if len(dir_parts) > 1 else dir_parts[0]

            # Kept inside the main-process guard: constructing the writer calls
            # wandb.init(), which would otherwise start one run per process.
            ret.append(
                hooks.PeriodicWriter(
                    [WandB_Printer(name=run_name, project="TODO!!", entity="TODO!!")],
                    period=1,
                )
            )
        return ret
마지막쯤 WandB_Printer라는 writer를 넣어준다.
trainer는 hook들을 등록할 수 있는데, hook 안에는 writer와 같은 event들이 들어갈 수 있다.
WandB Printer 구현
그럼 이제 WandB_Printer를 구현해보자.
EventWriter들은 write() 함수를 필수적으로 가지고 있어야 한다.
from detectron2.utils.events import EventWriter, get_event_storage
class WandB_Printer(EventWriter):
    """Event writer that forwards the scalars of the current event storage to Weights & Biases.

    Args:
        name: display name of the W&B run.
        project: W&B project the run is logged to.
        entity: W&B user or team owning the project.
        api_key: W&B API key (obtainable at wandb.ai/authorize). When omitted,
            ``wandb.login`` is called without an explicit key so that wandb
            uses its own credential lookup (e.g. the ``WANDB_API_KEY``
            environment variable) and no secret has to live in the source.
    """

    def __init__(self, name, project, entity, api_key=None) -> None:
        # Number of most recent values used when smoothing a noisy scalar.
        self._window_size = 20
        wandb.login(key=api_key)
        self.wandb = wandb.init(project=project, entity=entity, name=name)

    def write(self):
        """Send the current value of every scalar in the event storage to W&B."""
        storage = get_event_storage()
        sendDict = self._makeStorageDict(storage)
        self.wandb.log(sendDict)

    def close(self):
        """Finish the W&B run so that buffered data is uploaded."""
        self.wandb.finish()

    def _makeStorageDict(self, storage):
        """Collect one value per scalar from ``storage``.

        Scalars flagged as noisy (smoothing hint True, e.g. training losses)
        are reported as the median of the last ``_window_size`` entries.
        Scalars stored with ``smoothing_hint=False`` (e.g. evaluation results)
        are reported as their latest value, so an AP is never replaced by the
        median of several past evaluations.

        Values are rounded to 4 significant digits. Keys containing "AP" are
        scaled by 0.01 to convert percentages to the [0, 1] range.

        Returns:
            dict[str, float]: scalar name -> value to log.
        """
        smoothing_hints = storage.smoothing_hints()
        storageDict = {}
        for key, history in storage.histories().items():
            if smoothing_hints.get(key, True):
                value = history.median(self._window_size)
            else:
                value = history.latest()
            # Round-trip through the "4g" format to keep 4 significant digits.
            value = float(f"{value:.4g}")
            if "AP" in key:
                # AP (percent) to mAP (fraction)
                value *= 0.01
            storageDict[key] = value
        return storageDict
이렇게 하면 wandb로 로그를 보낼 수 있고, 아래는 mmdetection과 협업하기 위하여 커스텀한 전체 코드이다.
class WandB_Printer(EventWriter):
    """Event writer that logs detectron2 scalars to W&B under mmdetection-style names.

    Only the scalars listed in the key tables below are sent; each one is
    renamed to ``<subKey>/<customKey>`` (or just ``<customKey>`` when it has
    no section) so that runs can be compared with mmdetection runs.

    Args:
        name: display name of the W&B run.
        project: W&B project the run is logged to.
        entity: W&B user or team owning the project.
        api_key: W&B API key (obtainable at wandb.ai/authorize). When omitted,
            ``wandb.login`` is called without an explicit key so that wandb
            uses its own credential lookup (e.g. the ``WANDB_API_KEY``
            environment variable) and no secret has to live in the source.
    """

    # (detectron2 key, name published to W&B, W&B section or None)
    _TRAIN_KEYS = (
        ("lr", "detectron_learning_rate", None),
        ("total_loss", "loss", "train"),
        # RPN localisation loss is "loss_rpn_loc" in detectron2; "loss_box_reg"
        # is the ROI-head box regression loss and maps to "loss_bbox".
        ("loss_rpn_loc", "loss_rpn_bbox", "train"),
        ("loss_box_reg", "loss_bbox", "train"),
        ("loss_rpn_cls", "loss_rpn_cls", "train"),
    )
    _VAL_KEYS = (
        ("bbox/AP", "bbox_mAP", "val"),
        ("bbox/AP50", "bbox_mAP_50", "val"),
        ("bbox/AP75", "bbox_mAP_75", "val"),
        ("bbox/APl", "bbox_mAP_l", "val"),
        ("bbox/APm", "bbox_mAP_m", "val"),
        ("bbox/APs", "bbox_mAP_s", "val"),
    )
    # Per-class APs; these have no counterpart in mmdetection.
    _CLASS_NAMES = (
        "General trash",
        "Paper",
        "Paper pack",
        "Metal",
        "Glass",
        "Plastic",
        "Styrofoam",
        "Plastic bag",
        "Battery",
        "Clothing",
    )

    def __init__(self, name, project, entity, api_key=None) -> None:
        # Number of most recent values used when smoothing a noisy scalar.
        self._window_size = 20
        self._matchDictList = self._makeMatchDictList()
        wandb.login(key=api_key)
        self.wandb = wandb.init(project=project, entity=entity, name=name)

    def write(self):
        """Send the renamed subset of the current scalars to W&B."""
        storage = get_event_storage()
        sendDict = self.makeSendDict(storage)
        self.wandb.log(sendDict)

    def close(self):
        """Finish the W&B run so that buffered data is uploaded."""
        self.wandb.finish()

    def makeSendDict(self, storage):
        """Build the dict passed to ``wandb.log``.

        Returns:
            dict[str, float]: W&B name -> value, for every table entry whose
            detectron2 key is currently present in ``storage``.
        """
        sendDict = {}
        storageDict = self._makeStorageDict(storage)
        for item in self._matchDictList:
            sendDict = self._findValue(
                storageDict, item["key"], sendDict, item["customKey"], item["subKey"]
            )
        return sendDict

    def _makeStorageDict(self, storage):
        """Collect one value per scalar from ``storage``.

        Scalars flagged as noisy (smoothing hint True, e.g. training losses)
        are reported as the median of the last ``_window_size`` entries.
        Scalars stored with ``smoothing_hint=False`` (e.g. evaluation results)
        are reported as their latest value, so an AP is never replaced by the
        median of several past evaluations.

        Values are rounded to 4 significant digits. Keys containing "AP" are
        scaled by 0.01 to convert percentages to the [0, 1] range.

        Returns:
            dict[str, float]: scalar name -> value.
        """
        smoothing_hints = storage.smoothing_hints()
        storageDict = {}
        for key, history in storage.histories().items():
            if smoothing_hints.get(key, True):
                value = history.median(self._window_size)
            else:
                value = history.latest()
            # Round-trip through the "4g" format to keep 4 significant digits.
            value = float(f"{value:.4g}")
            if "AP" in key:
                # AP (percent) to mAP (fraction)
                value *= 0.01
            storageDict[key] = value
        return storageDict

    def _findValue(self, storageDict, key, retDict, customKey, subKey=None):
        """Copy ``storageDict[key]`` into ``retDict`` under its W&B name, if present.

        The target name is ``customKey`` or, when ``subKey`` is given,
        ``"<subKey>/<customKey>"``. ``retDict`` is modified in place and also
        returned; a missing ``key`` leaves it untouched.
        """
        if key in storageDict:
            if subKey is None:
                retDict[customKey] = storageDict[key]
            else:
                retDict["/".join([subKey, customKey])] = storageDict[key]
        return retDict

    def _makeMatchDictList(self):
        """Return the list of key mappings, in table order, followed by per-class APs."""
        matchDictList = [
            self._makeMatchDict(key=key, customKey=customKey, subKey=subKey)
            for key, customKey, subKey in self._TRAIN_KEYS + self._VAL_KEYS
        ]
        matchDictList.extend(
            self._makeMatchDict(
                key=f"bbox/AP-{className}",
                customKey=f"bbox_mAP_{className}",
                subKey="val",
            )
            for className in self._CLASS_NAMES
        )
        return matchDictList

    def _makeMatchDict(self, key, customKey, subKey):
        """Bundle one mapping (detectron2 key, W&B name, W&B section) into a dict."""
        return {"key": key, "customKey": customKey, "subKey": subKey}