smp에서 swin transformer 사용하기
22 Jan 2022 | Pytorch개요
해당 포스트는 SMP의 인코더 등록하기를 참고하여 작성되었습니다.
Swin Transformer는 대부분 mmDet, mmSeg에서 지원하는 모델을 이용해 사용합니다.
하지만 SMP에서도 인코더를 등록하는 방식으로 사용할 수 있습니다.
이번 포스트에서는 Swin Transformer를 SMP의 인코더로 등록하는 방법을 적겠습니다.
하지만 Dilated convolution 연산이 포함되어있는 Deeplab v3와 같은 디코더의 경우 이전 레이어의 연산값이 포함되어야 하는데, 이를 구현하지 못하여 현재는 PAN 전용으로 만들었습니다.
전체 코드는 여기 있습니다.
하나하나 차근차근 접근해보겠습니다.
인코더 정의하기
공식 코드를 그대로 사용했습니다.
해당 코드를 잘 복붙해옵시다.
다음은 인코더를 정의해야합니다.
# Custom SwinEncoder 정의
class SwinEncoder(torch.nn.Module, EncoderMixin):
def __init__(self, **kwargs):
super().__init__()
# A number of channels for each encoder feature tensor, list of integers
self._out_channels: List[int] = [128, 256, 512, 1024]
# A number of stages in decoder (in other words number of downsampling operations), integer
# use in in forward pass to reduce number of returning features
self._depth: int = 3
# RGB 채널 (다르다면 수정하기)
self._in_channels: int = 3
kwargs.pop('depth')
self.model = SwinTransformer(**kwargs)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
outs = self.model(x)
return list(outs)
Encoder로 사용되려면 EncoderMixin
을 상속해야 합니다.
다음과 같은 구조로 정의해줍니다.
인코더 등록하기
이제는 만들어진 인코더를 등록하는 과정입니다.
# Swin을 smp의 encoder로 사용할 수 있게 등록
def register_encoder():
smp.encoders.encoders["swin_encoder"] = {
"encoder": SwinEncoder, # encoder class here
"pretrained_settings": { # pretrained 값 설정
"imagenet": {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
"url": "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth",
"input_space": "RGB",
"input_range": [0, 1],
},
},
"params": { # 기본 파라미터
"pretrain_img_size": 384,
"embed_dim": 128,
"depths": [2, 2, 18, 2],
'num_heads': [4, 8, 16, 32],
"window_size": 12,
"drop_path_rate": 0.3,
}
}
등록하는 과정은 smp.encoder.encoders
라는 Dictionary에 “호출할 이름”으로 정의해줍니다.
정의할 Dictionary 안에는 encoder
, pretrained_settings
,params
가 포함된 Dictionary가 들어가야합니다.
encoder
에는 미리 정의한 SwinEncoder 클래스를 인자로 넣어줍니다.pretrained_settings
는 내부에서도 검색할 키워드와, 그 키워드의 값들을 적어줍니다.
url의 경우 https://github.com/microsoft/Swin-Transformer 내부의 url을 사용했습니다.params
는SwinEncoder
에 인자로 전달할**kwargs
이므로 Dictionary 형태로 전달해줍니다.
눈치채셨듯이, 만약 다른 크기의 모델(Swin-L, Swin-B 등등)을 사용하고 싶다면 url과 params를 configs에 맞춰 정의해주시면 됩니다.
모델 생성하기
다 끝났습니다.
register_encoder()
model = smp.PAN(
encoder_name="swin_encoder",
encoder_weights="imagenet",
encoder_output_stride=32,
in_channels=3,
classes=n_classes
)
인코더를 등록한 다음 이름과 pretrained를 잘 맞춰 넣어주면 됩니다.