CastleJo의 개발일지

Object Detection - Detectron Repeatfactor Sampler 설정하기

|

Repeatfactor Sampler 란 전달받은 repeat factor를 기준으로 더 많은 빈도로, 혹은 더 적은 빈도로 데이터를 불러오기 위한 샘플링 방법이다.
즉 이런 방식으로 훈련한다면 데이터 불균형 문제를 어느정도 해결할 수 있다.

기존의 sampler와 같이 사용할 수 있지만, 설정해줘야 할 것들이 몇가지 있다.

일단 호출은 좀 복잡하다.

일단 dataset_dict를 불러와야 한다.

from detectron2.data.build import get_detection_dataset_dicts

def getDatasetDicts(cfg):
	return get_detection_dataset_dicts(
			cfg.DATASETS.TRAIN,
			filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
			min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
			if cfg.MODEL.KEYPOINT_ON
			else 0,
			proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
	)

이를 기반으로 repeat_factor를 불러올 수 있다. 또한 threshold를 설정해주자.
thresh는 0~1 사이의 값으로, max(1.0, math.sqrt(repeat_thresh / cat_freq)) 해당 식으로 빈도를 정하게 된다.

from detectron2.data.samplers import RepeatFactorTrainingSampler

datasetDict = getDatasetDict(cfg)
repeat_thresh = 0.5
repeat_factor = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(datasetDict,repeat_thresh)

이제 호출만 하면 된다.

from detectron2.data.samplers import RepeatFactorTrainingSampler
sampler = RepeatFactorTrainingSampler(repeat_factor, shuffle=True, seed=42)