CastleJo의 개발일지

PyTorch - Dataset, DataLoader

|

Dataset

PyTorch에서는 데이터 전치를 위한 데이터 병렬화, 데이터 증식 및 배치 처리 등을 추상화해주는 여러 유틸리티 클래스를 제공한다.

가장 기본적인 map-style Dataset은 기본적으로 3개의 메서드를 오버라이드 하는 것으로 시작한다.

from torch.utils.data import Dataset

class CustomDataset(Dataset):
    # 가장 기본적인 생성자 메서드
    def __init__(self,*args, **kwargs):
        super.__init__(self)
    # Dataset의 최대 요소 수를 반환하는 데 사용되며, 인덱스가 적절한 범위 내에 있는지 확인하는 용도로도 호출된다.  
    def __len__(self):
        length : int
        return length
    # 해당 데이터셋의 {idx}번째 데이터를 반환하는데 사용된다.  
    def __getitem(self, idx):
        return X,y

DataLoader

모델 학습을 위해 Dataset을 불러때 사용되는 라이브러리이다.
데이터를 미니 배치형태로 제공해주는 역할을 한다.

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

인자가 많은데, 하나하나 알아보자.

  • dataset
    당연하지만 위에서 만든 Dataset 이 들어가야 한다. 유일한 가변인자로 무조건 입력해주어야 한다.

  • batch_size
    배치 하나하나당 사이즈를 설정해주는 인자이다.

  • shuffle
    데이터를 섞어서 배치할 지 결정하는 인자로, torch.seed에서 영향을 받는다.
  • sampler / batch_sampler
    index를 컨트롤하는 방법인데, 여기를 참고하자.
  • num_workers
    내장 라이브러리인 threadpool나, golang의 고루틴의 pool을 생각하면 이해가 빠를 것 같다. 스레드의 수라고 생각된다.
  • collate_fn
    꽤 많이 쓰이는 함수로 zero-padding 등 데이터의 사이즈를 맞추기 위해 사용된다.
    예를 들어 하나의 batch에서 동일한 길이를 반환하는 함수를 만들어 적용시킨다고 생각해보자.

      my_collate_fn(samples):
          collate_X = []
          collate_y = []
    
          for batch in samples:
              collate_X.append(batch['X'])
              collate_y.append(batch['y'])
          length = max([len(x) for x in collate_X])
          for i in range(len(collate_X)):
              while len(collate_X[i]) < length:
                  collate_X[i] = torch.cat([collate_X[i], torch.tensor([0.])])
            
          return {'X':torch.stack(collate_X), 'y': torch.stach(collate_y)}  
    

    이렇게 한다면, 데이터 리스트를 받을 때 가장 큰 길이에 맞춰 적은 데이터들에 (0.) 의 데이터들이 추가될것이다.

  • pin_memory
    Tensor을 CUDA 고정 메모리에 할당하는 방식인데, C언어의 register 태그를 생각하면 편할것이다. 잘 사용되지는 않는다.
  • drop_last
    batch를 2 이상의 수로 받아올 때, 마지막에 나머지가 남을수도 있다.
    이를 버릴지 말지 결정하는 인자이다.
  • timeout
    말 그대로 데이터를 불러올 때 타임아웃을 지정해준다.
  • worker_init_fn
    worker가 존재할 때, worker들을 컨트롤해주는 함수이다.

return 값

반복 가능한 객체인, 그러니 즉 __iter__(self) 가 구현되어있는 객체가 리턴된다.

그러므로 next(iter({DataLoader})) 와 같은 방식으로 호출할 수 있다.

Pytorch의 여러가지 Dataset 관련 모듈들

  • torchvision.dataset : MNIST나 CIFAR-10과 같은 이미지 데이터셋을 제공한다.
  • torchtext.dataset : IMDb나 AG_NEWS와 같은 텍스트 데이터셋을 제공한다.
  • torchvision.transforms : 이미지 데이터셋에 쓰이는 필터나 변환 도구들을 제공한다.
  • torchvision.utils : 데이터를 저장하고 시각화하는 도구가 들어가있다.