CastleJo의 개발일지

Object Detection - Detectron wandb 연결하기

|

mmdetection이나 yolo와는 달리 detectron2는 wandb 연동 api를 기본으로 제공해주지 않는 것 같다. (아니라면 죄송합니다)

그래서 직접 구현했다. trainer 부분을 override해주는 방식으로 접근하면 감이 잡힌다.

DefaultTrainer 구조 탐험

여기서 구현부를 볼 수 있다.

다 이해할 필요는 없고, pytorch 기반으로 되어 있으니 결국 hook base로 로그와 같은 것들이 돌아가고 있음을 볼 수 있다.

build_hooks(self)에서 hook들을 등록하는 것을 볼 수 있는데, 함수의 중간 부분만 바꿔야 하기 때문에 함수 전체를 그대로 가져온 다음 override 했다.

Override

from detectron2.engine import DefaultTrainer
...

class Trainer(DefaultTrainer):
    """DefaultTrainer subclass whose hook list also reports metrics to Weights & Biases."""

    ...

    # Override
    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.

        On the main process the writers are customized: the regular writers
        are registered with period=10 and a separate WandB_Printer writer is
        registered with period=1.

        Returns:
            list[HookBase]:
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        # ------- Nothing above this line needs to be changed -------

        if comm.is_main_process():
            # Here the default print/log frequency of each writer is used.
            # run writers in the end, so that evaluation metrics are written
            writerList = [
                CustomMetricPrinter(self.showTQDM, self.cfg.SOLVER.MAX_ITER),
                JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
                TensorboardXWriter(self.cfg.OUTPUT_DIR),
            ]
            ret.append(hooks.PeriodicWriter(writerList, period=10))

            # The W&B run is named after the second "/"-separated component of
            # OUTPUT_DIR. Fall back to the whole path when it contains no "/"
            # instead of failing with IndexError.
            dirParts = self.cfg.OUTPUT_DIR.split("/")
            runName = dirParts[1] if len(dirParts) > 1 else dirParts[0]

            # The W&B writer gets its own PeriodicWriter so that it can use a
            # different period (1) from the other writers (10).
            ret.append(
                hooks.PeriodicWriter(
                    [WandB_Printer(name=runName, project="TODO!!", entity="TODO!!")],
                    period=1,
                )
            )

        return ret

마지막쯤에 WandB_Printer라는 writer를 넣어준다.

trainer에는 hook들을 등록할 수 있고, PeriodicWriter 같은 hook 안에는 writer(EventWriter)들을 넣을 수 있다.

WandB Printer 구현

그럼 이제 WandB_Printer를 구현해보자.
EventWriter들은 write() 함수를 필수적으로 가지고 있어야 한다.

from detectron2.utils.events import EventWriter, get_event_storage

class WandB_Printer(EventWriter):
  """Event writer that sends the metrics held in the event storage to W&B."""

  def __init__(self, name, project, entity) -> None:
    """Log in to W&B and start the run that write() will log to.

    Args:
      name: run name passed to wandb.init.
      project: project passed to wandb.init.
      entity: entity passed to wandb.init.
    """
    # Window length handed to each history's median() in _makeStorageDict.
    self._window_size = 20

    # An API key can be obtained from wandb.ai/authorize.
    wandb.login(key="TODO!!")
    self.wandb = wandb.init(project=project, entity=entity, name=name)

  def write(self):
    """Log one dict built from the current event storage to the W&B run."""
    self.wandb.log(self._makeStorageDict(get_event_storage()))

  def _makeStorageDict(self, storage):
    """Return {metric name: windowed median} for every history in storage.

    Each median is formatted with 4 significant digits and parsed back to a
    float. Values whose name contains "AP" are additionally scaled by 0.01.
    """
    metrics = {}
    for metricName, history in storage.histories().items():
      # The format/parse round trip rounds the value to 4 significant digits.
      rounded = float(f"{history.median(self._window_size):.4g}")
      # AP to mAP
      metrics[metricName] = rounded * 0.01 if "AP" in metricName else rounded
    return metrics

이렇게 하면 wandb로 로그를 보낼 수 있다. 아래는 mmdetection을 쓰는 팀원들과 같은 이름으로 지표를 비교할 수 있도록 key 이름을 맞춰 커스텀한 전체 코드이다.

class WandB_Printer(EventWriter):
  """Event writer that logs selected detectron2 metrics to W&B under
  mmdetection-style names.

  Only the storage keys listed by _makeMatchDictList are sent; each one is
  renamed to "<subKey>/<customKey>" (or just "<customKey>" when subKey is
  None).
  """

  def __init__(self, name, project, entity) -> None:
    """Build the key-mapping table, log in to W&B and start a run.

    Args:
      name: run name passed to wandb.init.
      project: project passed to wandb.init.
      entity: entity passed to wandb.init.
    """
    # Window length handed to each history's median() in _makeStorageDict.
    self._window_size = 20
    self._matchDictList = self._makeMatchDictList()

    wandb.login(key="-----")

    self.wandb = wandb.init(project=project, entity=entity, name=name)

  def write(self):
    """Log the renamed metrics from the current event storage to the run."""
    storage = get_event_storage()

    sendDict = self.makeSendDict(storage)
    self.wandb.log(sendDict)

  def makeSendDict(self, storage):
    """Return the dict to log: mapped storage values under their W&B names.

    Keys from the mapping table that are absent from the storage are skipped.
    """
    sendDict = {}

    storageDict = self._makeStorageDict(storage)

    for item in self._matchDictList:
      sendDict = self._findValue(storageDict, item["key"], sendDict, item["customKey"], item["subKey"])

    return sendDict

  def _makeStorageDict(self, storage):
    """Return {metric name: windowed median} for every history in storage.

    Each median is formatted with 4 significant digits and parsed back to a
    float. Values whose name contains "AP" are additionally scaled by 0.01.
    """
    # NOTE(review): the windowed median is applied to every metric, including
    # the evaluation APs - confirm that smoothing is intended for those.
    storageDict = {}
    for k, history in storage.histories().items():
      value = float(f"{history.median(self._window_size):.4g}")
      # AP to mAP (presumably a 0-100 percentage rescaled to 0-1 - confirm)
      storageDict[k] = value * 0.01 if "AP" in k else value

    return storageDict

  def _findValue(self, storageDict, key, retDict, customKey, subKey=None):
    """Copy storageDict[key] into retDict under its W&B name, if present.

    The target name is customKey, prefixed with "subKey/" when subKey is not
    None. retDict is updated in place and also returned.
    """
    if key in storageDict:
      if subKey is None:
        retDict[customKey] = storageDict[key]
      else:
        retDict["/".join([subKey, customKey])] = storageDict[key]

    return retDict

  def _makeMatchDictList(self):
    """Return the list of {"key", "customKey", "subKey"} mapping entries."""
    # (detectron2 storage key, W&B name) pairs logged under "train/".
    # loss_rpn_loc is the RPN localization loss, so it is the one that maps to
    # loss_rpn_bbox; loss_box_reg / loss_cls are the ROI-head losses and map
    # to loss_bbox / loss_cls.
    trainPairs = [
      ("total_loss", "loss"),
      ("loss_rpn_loc", "loss_rpn_bbox"),
      ("loss_rpn_cls", "loss_rpn_cls"),
      ("loss_box_reg", "loss_bbox"),
      ("loss_cls", "loss_cls"),
    ]

    # Suffixes of the overall "bbox/AP*" keys and of their "bbox_mAP*" names.
    overallPairs = [
      ("AP", ""),
      ("AP50", "_50"),
      ("AP75", "_75"),
      ("APl", "_l"),
      ("APm", "_m"),
      ("APs", "_s"),
    ]

    # Per-class AP entries (these have no mmdetection counterpart).
    classNames = [
      "General trash",
      "Paper",
      "Paper pack",
      "Metal",
      "Glass",
      "Plastic",
      "Styrofoam",
      "Plastic bag",
      "Battery",
      "Clothing",
    ]

    matchDictList = [
      self._makeMatchDict(key="lr", customKey="detectron_learning_rate", subKey=None)
    ]
    matchDictList.extend(
      self._makeMatchDict(key=key, customKey=customKey, subKey="train")
      for key, customKey in trainPairs
    )
    matchDictList.extend(
      self._makeMatchDict(key="bbox/" + apKey, customKey="bbox_mAP" + suffix, subKey="val")
      for apKey, suffix in overallPairs
    )
    matchDictList.extend(
      self._makeMatchDict(key="bbox/AP-" + clsName, customKey="bbox_mAP_" + clsName, subKey="val")
      for clsName in classNames
    )

    return matchDictList

  def _makeMatchDict(self, key, customKey, subKey):
    """Return one mapping entry: storage key, W&B name and optional prefix."""
    return {"key": key, "customKey": customKey, "subKey": subKey}