PyTorch - Dataset, DataLoader
19 Aug 2021 | Deep LearningDataset
PyTorch에서는 데이터 전치를 위한 데이터 병렬화, 데이터 증식 및 배치 처리 등을 추상화해주는 여러 유틸리티 클래스를 제공한다.
가장 기본적인 map-style Dataset은 기본적으로 3개의 메서드를 오버라이드 하는 것으로 시작한다.
from torch.utils.data import Dataset
class CustomDataset(Dataset):
# 가장 기본적인 생성자 메서드
def __init__(self,*args, **kwargs):
super.__init__(self)
# Dataset의 최대 요소 수를 반환하는 데 사용되며, 인덱스가 적절한 범위 내에 있는지 확인하는 용도로도 호출된다.
def __len__(self):
length : int
return length
# 해당 데이터셋의 {idx}번째 데이터를 반환하는데 사용된다.
def __getitem(self, idx):
return X,y
DataLoader
모델 학습을 위해 Dataset을 불러때 사용되는 라이브러리이다.
데이터를 미니 배치형태로 제공해주는 역할을 한다.
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
인자가 많은데, 하나하나 알아보자.
-
dataset
당연하지만 위에서 만든 Dataset 이 들어가야 한다. 유일한 가변인자로 무조건 입력해주어야 한다. -
batch_size
배치 하나하나당 사이즈를 설정해주는 인자이다. - shuffle
데이터를 섞어서 배치할 지 결정하는 인자로, torch.seed에서 영향을 받는다. - sampler / batch_sampler
index를 컨트롤하는 방법인데, 여기를 참고하자. - num_workers
내장 라이브러리인 threadpool나, golang의 고루틴의 pool을 생각하면 이해가 빠를 것 같다. 스레드의 수라고 생각된다. -
collate_fn
꽤 많이 쓰이는 함수로 zero-padding 등 데이터의 사이즈를 맞추기 위해 사용된다.
예를 들어 하나의 batch에서 동일한 길이를 반환하는 함수를 만들어 적용시킨다고 생각해보자.my_collate_fn(samples): collate_X = [] collate_y = [] for batch in samples: collate_X.append(batch['X']) collate_y.append(batch['y']) length = max([len(x) for x in collate_X]) for i in range(len(collate_X)): while len(collate_X[i]) < length: collate_X[i] = torch.cat([collate_X[i], torch.tensor([0.])]) return {'X':torch.stack(collate_X), 'y': torch.stach(collate_y)}
이렇게 한다면, 데이터 리스트를 받을 때 가장 큰 길이에 맞춰 적은 데이터들에 (0.) 의 데이터들이 추가될것이다.
- pin_memory
Tensor을 CUDA 고정 메모리에 할당하는 방식인데, C언어의register
태그를 생각하면 편할것이다. 잘 사용되지는 않는다. - drop_last
batch를 2 이상의 수로 받아올 때, 마지막에 나머지가 남을수도 있다.
이를 버릴지 말지 결정하는 인자이다. - timeout
말 그대로 데이터를 불러올 때 타임아웃을 지정해준다. - worker_init_fn
worker가 존재할 때, worker들을 컨트롤해주는 함수이다.
return 값
반복 가능한 객체인, 그러니 즉 __iter__(self)
가 구현되어있는 객체가 리턴된다.
그러므로 next(iter({DataLoader}))
와 같은 방식으로 호출할 수 있다.
Pytorch의 여러가지 Dataset 관련 모듈들
torchvision.dataset
: MNIST나 CIFAR-10과 같은 이미지 데이터셋을 제공한다.torchtext.dataset
: IMDb나 AG_NEWS와 같은 텍스트 데이터셋을 제공한다.torchvision.transforms
: 이미지 데이터셋에 쓰이는 필터나 변환 도구들을 제공한다.torchvision.utils
: 데이터를 저장하고 시각화하는 도구가 들어가있다.