Object Detection - Detectron tqdm 연결하기
06 Oct 2021 | Deep Learning
wandb 적용과 같은 방식으로, hook에 접근하는 방식이다.
다만, 기존의 로그와 겹치면 매우… 보기 안 좋기 때문에 기존의 로그를 삭제하고 덧붙여 넣는 식으로 만들었다.
Trainer Override
from detectron2.engine import DefaultTrainer
...
class Trainer(DefaultTrainer):
    """DefaultTrainer variant that reports training progress on a tqdm bar.

    The bar is created in ``__init__`` and handed to ``CustomMetricPrinter``
    in ``build_hooks``, where that printer is the console writer in the
    writer list.
    """

    ...

    def __init__(self, cfg):
        """Create the progress bar, then run the regular trainer setup.

        Args:
            cfg: detectron2 config; ``cfg.SOLVER.MAX_ITER`` sizes the bar.
        """
        # The bar must exist before super().__init__() runs: build_hooks()
        # reads self.showTQDM and is invoked during DefaultTrainer's
        # construction (NOTE(review): per detectron2's DefaultTrainer).
        # Only the main process attaches a printer to the bar (see
        # build_hooks), so disable it everywhere else; otherwise every other
        # rank would draw a bar that never receives an update.
        self.showTQDM = tqdm(
            range(cfg.SOLVER.MAX_ITER),
            disable=not comm.is_main_process(),
        )
        super().__init__(cfg)

    # Override
    def build_hooks(self):
        """
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.

        Returns:
            list[HookBase]:
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]

        # Do PreciseBN before checkpointer, because it updates the model and need to
        # be saved by checkpointer.
        # This is not always the best: if checkpointing has a different frequency,
        # some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        # Do evaluation after checkpointer, because then if it fails,
        # we can use the saved checkpoint to debug.
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        # ------- Nothing above this line needs to be changed --------
        if comm.is_main_process():
            # Run writers in the end, so that evaluation metrics are written.
            # CustomMetricPrinter draws on the tqdm bar created in __init__;
            # the JSON and TensorBoard writers are configured with OUTPUT_DIR.
            writerList = [
                CustomMetricPrinter(self.showTQDM, self.cfg.SOLVER.MAX_ITER),
                JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
                TensorboardXWriter(self.cfg.OUTPUT_DIR),
                # WandB_Printer(name = self.cfg.OUTPUT_DIR.split("/")[1], project="object-detection",entity="cv4")
            ]
            # All writers are flushed every 10 iterations.
            ret.append(hooks.PeriodicWriter(writerList, period=10))
        return ret
Trainer의 생성자에서 tqdm 객체를 생성해주고, hook을 걸어줄 때 writer의 생성자 인자로 그 tqdm 객체를 넘겨주는 방식이다.
CustomMetricPrinter 구현
writerList에 있는 CustomMetricPrinter를 구현해보자.
class CustomMetricPrinter(CommonMetricPrinter):
    """Metric writer that shows the learning rate and losses on a tqdm bar.

    Subclasses ``CommonMetricPrinter`` and overrides ``write`` so the values
    are shown as the bar's postfix.
    """

    def __init__(self, tqdmModule, max_iter: Optional[int] = None, window_size: int = 20):
        """
        Args:
            tqdmModule: the tqdm progress bar to draw on (despite the name,
                this is a bar instance, not the ``tqdm`` module).
            max_iter: total number of training iterations; forwarded to the
                parent class.
            window_size: number of recent values each loss median is taken
                over; forwarded to the parent class.
        """
        super().__init__(max_iter=max_iter, window_size=window_size)
        self.tqdmModule = tqdmModule

    def write(self):
        """Advance the bar to the current iteration and refresh its postfix."""
        storage = get_event_storage()
        iteration = storage.iter
        if iteration == self._max_iter:
            # NOTE(review): presumably mirrors the parent class's guard that
            # skips the extra write issued after training ends - confirm.
            return

        try:
            lr = "{:.5g}".format(storage.history("lr").latest())
        except KeyError:
            # No learning rate has been recorded yet.
            lr = "N/A"

        showDict = {"lr": lr}
        # Every scalar whose name contains "loss", smoothed with a windowed median.
        showDict.update(
            (name, f"{history.median(self._window_size):.4g}")
            for name, history in storage.histories().items()
            if "loss" in name
        )

        bar = self.tqdmModule
        # set_postfix() alone never moves the bar's counter, so sync it to
        # the iteration count here. Updating by the difference (instead of a
        # fixed step) keeps the position right when training resumes from a
        # checkpoint or when something else has already advanced the bar.
        # Assumes storage.iter is the zero-based index of the step that just
        # finished - TODO confirm against the detectron2 version in use.
        bar.update(iteration + 1 - bar.n)
        bar.set_postfix(showDict)
기존의 CommonMetricPrinter는 훈련할 때 Log를 Logger로 쏴주는 writer 클래스였다.
이를 상속하여 필요한 부분만 쏙 빼서 넣어주면 된다.