Object Detection - Detectron Repeatfactor Sampler 설정하기
30 Sep 2021 | Deep LearningRepeatfactor Sampler 란 전달받은 repeat factor를 기준으로 더 많은 빈도로, 혹은 더 적은 빈도로 데이터를 불러오기 위한 샘플링 방법이다.
즉 이런 방식으로 훈련한다면 데이터 불균형 문제를 어느정도 해결할 수 있다.
기존의 sampler와 같이 사용할 수 있지만, 설정해줘야 할 것들이 몇가지 있다.
일단 호출은 좀 복잡하다.
일단 dataset_dict를 불러와야 한다.
from detectron2.data.build import get_detection_dataset_dicts
def getDatasetDicts(cfg):
return get_detection_dataset_dicts(
cfg.DATASETS.TRAIN,
filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
if cfg.MODEL.KEYPOINT_ON
else 0,
proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
)
이를 기반으로 repeat_factor를 불러올 수 있다. 또한 threshold를 설정해주자.
thresh는 0~1 사이의 값으로, max(1.0, math.sqrt(repeat_thresh / cat_freq))
해당 식으로 빈도를 정하게 된다.
from detectron2.data.samplers import RepeatFactorTrainingSampler
datasetDict = getDatasetDict(cfg)
repeat_thresh = 0.5
repeat_factor = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(datasetDict,repeat_thresh)
이제 호출만 하면 된다.
from detectron2.data.samplers import RepeatFactorTrainingSampler
sampler = RepeatFactorTrainingSampler(repeat_factor, shuffle=True, seed=42)