CastleJo의 개발일지

PyTorch - nn.module

|

nn.module 공식 문서

기본 구조

nn.Module.forward

  • return 값으로 모델의 연산 결과를 반환한다.
  • 모델(value) 와 같이 함수로 출력하면, forward 함수를 거친다.
  • 멤버 Tensor에 자동으로 backward 함수를 계산하게 저장된다.

Backward

  • Layer에 있는 Parameter들의 미분을 수행
  • Forward의 결괏값과 실제값간의 차이에 대해 미분을 수행(Autograd)
  • 해당 값으로 Parameter 업데이트
  • 직접 구현할 일은 많이 없으나 순서는 이해하는게 좋다.

Container

  • nn.Sequential
    Module을 선형적으로 연산해주는 파이프라인을 만들 수 있다.
  • nn.ModuleList
    Module들을 Python의 List와 같이 담을 수 있다.
  • nn.ModuleDict
    비슷하게 Dict타입으로 담을 수 있다.

Hook

forward, backward 전,후에 데코레이터패턴과 같은 개념으로 추가적인 함수를 실행한다.

  • module.register_forward_pre_hook(func)
  • module.register_forward_hook(func)
  • module.register_full_backward_hook(func)

full backward경우는 추후 grad_fn으로만 대체한다는 말이 있다.

Apply

상당히 많이 쓰이는 기술이다. 큰 모델 중 특정 서브 모델에만 특정 함수를 적용하고 싶을 때가 있는데, 그 때 사용한다.
일반적으로 가중치를 초기화시켜줄 때 많이 사용된다고 한다.

알아두면 좋은 것들

Linear vs LazyLinear

늦은 초기화를 사용하여 사용 시점에 생성을 해 메모리를 절약할 수 있는 기술이다.

Python List vs PyTorch ModuleList

PythonList로 연산을 담으면 모듈들이 출력되지 않는다.
하지만 Pytorch List로 담으면 자동으로 __repr__이 호출되어 이름이 붙는다.

Tensor vs Parameter vs Buffer

Module 안의 가중치값을 저장할 땐 Parameter를 기본적으로 사용한다.
Tensor도 같은 역할을 할 수 있지만, Backward 함수를 계산할 수 있는 grad_fn 값을 저장하지 않는다.
Buffer는 좀 다른 개념인데, 모델 저장시 값이 저장되지만 따로 연산을 수행하진 않는다.

module.get_submodule()

모듈 안의 서브모듈들을 이름을 기준으로 검색할 수 있다.

Docstring

파이썬 Class를 만들 때, 맨 앞에 ‘’’ 설명 ‘’’ 으로 표시되어있는 부분이다.

module.extra_repr

모듈을 print할 때 모듈을 설명해줄 수 있는 이름을 정할 수 있다.