Object Detection - Detectron 구조 파헤쳐보기
28 Sep 2021 | Deep Learning자고로 프로그래머라면 자신이 사용할 코드는 이해하고있어야되지 않을까.
그래서 베이스라인 코드를 직접 파헤쳐보는 시간을 가졌다.
Register Dataset
일단 기본적으로 Dataset을 준비해야한다.
coco 형식의 json 값이 있다고 가정하자.
try:
register_coco_instances(customDict.name.train_dataset, {}, customDict.path.train_json, customDict.path.image_root)
except AssertionError:
pass
try:
register_coco_instances(customDict.name.test_dataset, {}, customDict.path.test_json, customDict.path.image_root)
except AssertionError:
pass
MetadataCatalog.get(customDict.name.train_dataset).thing_classes = ["여러가지 오브젝트들"]
이러면 Dataset은 끝이다. 놀랍게도.
Load Config
내가 만든건 Custom Config, 즉 Arg 값이고 Detectron도 기본적으로 config값으로 모든 인자 값을 저장한다.
그래서 뭐 열심히 붙여주는 클래스 만들어서 붙여줬다.
# Set Detectron Config by Arg
parser = ConfigArgParser(customDict)
Augmentation
import detectron2.data.transforms as T
def MyMapper(dataset_dict):
dataset_dict = copy.deepcopy(dataset_dict)
image = utils.read_image(dataset_dict['file_name'], format='BGR')
transform_list = [
T.RandomFlip(prob=0.5, horizontal=False, vertical=True),
]
image, transforms = T.apply_transform_gens(transform_list, image)
dataset_dict['image'] = torch.as_tensor(image.transpose(2,0,1).astype('float32'))
annos = [
utils.transform_instance_annotations(obj, transforms, image.shape[:2])
for obj in dataset_dict.pop('annotations')
if obj.get('iscrowd', 0) == 0
]
instances = utils.annotations_to_instances(annos, image.shape[:2])
dataset_dict['instances'] = utils.filter_empty_instances(instances)
return dataset_dict
뭐 요로코롬 넣어줄 수 있다.
dataset_dict
를 반환해주는 형식으로 만들어 덕타이핑해준다.
Trainer
class MyTrainer(DefaultTrainer):
@classmethod
def build_train_loader(cls, cfg, sampler=None):
return build_detection_train_loader(
cfg, mapper = MyMapper, sampler = sampler
)
@classmethod
def build_evaluator(cls, cfg, dataset_name, output_folder=None):
if output_folder is None:
os.makedirs(customDict.path.output_eval_dir, exist_ok = True)
output_folder = customDict.path.output_eval_dir
return COCOEvaluator(dataset_name, cfg, False, output_folder)
DefaultTrainer
를 상속해줘야 한다.
Train
# train
os.makedirs(cfg.OUTPUT_DIR, exist_ok = True)
trainer = MyTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()
다음 train만 해주면 된다. 참 쉽다!