CastleJo's Dev Log

Object Detection - Hooking tqdm into Detectron


Like the wandb integration, this works by going through the trainer's hooks.

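However, the progress bar looks really ugly when it overlaps with the existing log output, so I removed the existing log printer and put the tqdm bar in its place.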
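For anyone who has not worked with hooks before, the idea can be sketched in plain Python. `ToyBar`, `ToyPeriodicWriter`, and `train` are hypothetical stand-ins made up for this illustration, not Detectron2 or tqdm classes. The sketch assumes the writer only runs every `period` steps, so the bar has to jump forward by the number of steps finished since the last call.

```python
class ToyBar:
    """Hypothetical stand-in for a tqdm bar: tracks position and postfix."""

    def __init__(self, total):
        self.total = total
        self.n = 0
        self.postfix = {}

    def update(self, delta):
        self.n += delta

    def set_postfix(self, values):
        self.postfix = dict(values)


class ToyPeriodicWriter:
    """Hypothetical hook that refreshes the bar every `period` steps."""

    def __init__(self, bar, period):
        self.bar = bar
        self.period = period

    def after_step(self, iteration, metrics):
        # `iteration` is zero-based, so iteration + 1 steps have finished.
        finished = iteration + 1
        if finished % self.period == 0 or finished == self.bar.total:
            self.bar.update(finished - self.bar.n)
            self.bar.set_postfix(metrics)


def train(max_iter, hooks):
    """Toy training loop that calls every hook after each step."""
    for iteration in range(max_iter):
        loss = 1.0 / (iteration + 1)  # placeholder for a real training step
        for hook in hooks:
            hook.after_step(iteration, {"total_loss": f"{loss:.4g}"})


bar = ToyBar(total=25)
train(25, [ToyPeriodicWriter(bar, period=10)])
```

The training loop itself knows nothing about the bar. Everything display-related lives in the hook, which is why swapping the log printer for a progress bar only touches the hook list.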

Trainer Override

import os

from tqdm import tqdm
from fvcore.nn.precise_bn import get_bn_modules
from detectron2.engine import DefaultTrainer, hooks
from detectron2.utils import comm
from detectron2.utils.events import JSONWriter, TensorboardXWriter
...

class Trainer(DefaultTrainer):
    ...

    def __init__(self, cfg):
        # Create the bar before super().__init__(): DefaultTrainer's constructor
        # calls build_hooks(), which reads self.showTQDM.
        self.showTQDM = tqdm(range(cfg.SOLVER.MAX_ITER))
        super().__init__(cfg)

    # Override
    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.
        Returns:
            list[HookBase]:
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        # ------- Nothing above this line needs to be changed -------

        if comm.is_main_process():
            # Here the default print/log frequency of each writer is used.
            # run writers in the end, so that evaluation metrics are written
            # CustomMetricPrinter (defined below) takes the place of the
            # default CommonMetricPrinter and receives the tqdm bar.
            writerList = [
                CustomMetricPrinter(self.showTQDM, self.cfg.SOLVER.MAX_ITER),
                JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
                TensorboardXWriter(self.cfg.OUTPUT_DIR),
                # WandB_Printer(name = self.cfg.OUTPUT_DIR.split("/")[1], project="object-detection",entity="cv4")
            ]

            ret.append(hooks.PeriodicWriter(writerList, period=10))

        return ret

The tqdm bar is created in the trainer's constructor, and when the hooks are built it is handed to the writer through the writer's constructor.
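One detail that is easy to miss is the order inside the constructor. As far as I can tell, `DefaultTrainer.__init__` calls `build_hooks()` itself, so anything `build_hooks()` reads has to exist before `super().__init__()` runs. The plain-Python sketch below only illustrates that ordering. `BaseTrainer`, `ProgressTrainer`, and `BrokenTrainer` are hypothetical names, and a dict stands in for the tqdm bar.

```python
class BaseTrainer:
    """Hypothetical base class whose constructor builds the hooks itself."""

    def __init__(self):
        self.hooks = self.build_hooks()

    def build_hooks(self):
        return []


class ProgressTrainer(BaseTrainer):
    def __init__(self, max_iter):
        # Set the attribute first: super().__init__() calls build_hooks().
        self.progress = {"total": max_iter, "done": 0}
        super().__init__()

    def build_hooks(self):
        hooks = super().build_hooks()
        hooks.append(("writer", self.progress))  # share the same object
        return hooks


class BrokenTrainer(BaseTrainer):
    def __init__(self, max_iter):
        super().__init__()  # build_hooks() runs here ...
        self.progress = {"total": max_iter, "done": 0}  # ... so this is too late

    def build_hooks(self):
        return [("writer", self.progress)]


trainer = ProgressTrainer(100)

try:
    BrokenTrainer(100)
    broken_error = None
except AttributeError as exc:
    broken_error = exc
```

The writer ends up holding the very same object the trainer created, which is all the "pass it through the constructor" step needs to achieve.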

Implementing CustomMetricPrinter

Now let's implement the CustomMetricPrinter that appears in writerList.

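from typing import Optional

from detectron2.utils.events import CommonMetricPrinter, get_event_storage


class CustomMetricPrinter(CommonMetricPrinter):
    def __init__(self, tqdmModule, max_iter: Optional[int] = None, window_size: int = 20):
        super().__init__(max_iter=max_iter, window_size=window_size)
        self.tqdmModule = tqdmModule

    def write(self):
        storage = get_event_storage()
        iteration = storage.iter
        if iteration == self._max_iter:
            # write() is called once more after training ends; nothing new to show.
            return

        try:
            lr = "{:.5g}".format(storage.history("lr").latest())
        except KeyError:
            lr = "N/A"

        # Show the learning rate and the windowed median of every loss.
        showDict = {"lr": lr}
        for k, v in storage.histories().items():
            if "loss" in k:
                showDict[k] = f"{v.median(self._window_size):.4g}"

        # Assumption: nothing else advances the bar, so move it to the number of
        # finished iterations (storage.iter is zero-based) before refreshing.
        self.tqdmModule.update(iteration + 1 - self.tqdmModule.n)
        self.tqdmModule.set_postfix(showDict)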

The original CommonMetricPrinter is the writer class that sends the training log to the logger.
Subclass it, pull out just the parts you need, and feed those to the bar instead.
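To make "just the parts you need" concrete, here is a plain-Python sketch of how the postfix dictionary is put together. `build_postfix` is a hypothetical helper that does not exist in Detectron2. It assumes each history is a simple list of floats and uses `statistics.median` over the last `window_size` values in place of the history buffer's own median.

```python
from statistics import median


def build_postfix(histories, lr=None, window_size=20):
    """Keep the learning rate plus a windowed median for every loss entry."""
    postfix = {"lr": f"{lr:.5g}" if lr is not None else "N/A"}
    for name, values in histories.items():
        if "loss" in name:  # skips entries such as "time" or "data_time"
            postfix[name] = f"{median(values[-window_size:]):.4g}"
    return postfix


histories = {
    "total_loss": [1.0, 0.8, 0.6],
    "loss_cls": [0.5, 0.3],
    "time": [0.2, 0.2, 0.2],
}
postfix = build_postfix(histories, lr=0.001, window_size=2)
```

Using a median over a short window keeps a single noisy iteration from making the numbers next to the bar jump around.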