CastleJo의 개발일지

Ai Math - CNN, RNN 첫 걸음

|

CNN

Convolution

\[h_{i}=\sigma\left(\sum_{j=1}^{k} V_{j} x_{i+j-1}\right)\]
  • 커널을 입력벡터 사이에서 움직여가면서 선형모델과 합성함수가 적용되는 구조
  • 고정된 가중치 행렬 V로, i의 사이즈에 관계가 없음

  • 2차원 Convolution 연산을 하면, 입력 크기 - 커널 크기 + 1 의 출력 크기가 된다.

Convolution 연산의 역전파

\[\begin{aligned} \frac{\partial}{\partial x}[f * g](x) &=\frac{\partial}{\partial x} \int_{\mathbb{R}^{d}} f(y) g(x-y) \mathrm{d} y \\ &=\int_{\mathbb{R}^{d}} f(y) \frac{\partial g}{\partial x}(x-y) \mathrm{d} y \\ &=\left[f * g^{\prime}\right](x) \end{aligned}\]
  • Convolution 연산은 커널이 모든 입력데이터에 공통으로 적용되기 때문에 역전파를 계산할 때도 Convolution 연산이 나오게 된다.

RNN

  • 소리, 문자열, 주가 등의 데이터를 Sequence 데이터로 분류
  • 시퀀스 데이터는 독립동등분포 가정을 잘 위배하기 떄문에 순서를 바꾸거나 과거 정보에 손실이 발생하면 데이터의 확률 분포도 쉽게 바뀌게된다.

  • 이전의 정보를 가지고 앞으로 발생할 데이터의 확률분포를 다루기 위해 조건부확률을 이용할 수 있다.

\(P\left(X_{1}, \ldots, X_{t}\right)=P\left(X_{t} \mid X_{1}, \ldots, X_{t-1}\right) P\left(X_{1}, \ldots, X_{t-1}\right)\) \(=\prod_{s=1}^{t} P\left(X_{S} \mid X_{s-1}, \ldots, X_{1}\right)\)

RNN 이해하기

  • 기본적인 RNN 모형
\[\begin{aligned} &\mathbf{O}=\mathbf{H} \mathbf{W}^{(2)}+\mathbf{b}^{(2)} \\ &\mathbf{H}=\sigma\left(\mathbf{X W}^{(1)}+\mathbf{b}^{(1)}\right) \end{aligned}\] \[\begin{aligned} \mathbf{O}_{t} &=\mathbf{H}_{t} \mathbf{W}^{(2)}+\mathbf{b}^{(2)} \\ \mathbf{H}_{t} &=\sigma\left(\mathbf{X}_{t} \mathbf{W}_{X}^{(1)}+\mathbf{H}_{t-1} \mathbf{W}_{H}^{(1)}+\mathbf{b}^{(1)}\right) \end{aligned}\]
  • 이전 순서의 잠재변수와 현재의 입력을 활용한다.

  • 잠재변수인 $H_t$를 복제해 다음 순서의 잠재변수를 인코딩하는데 사용한다.

  • RNN의 역전파는 잠재변수의 연결그래프에 따라 순차적으로 계산한다. 이를 BPTT(Backpropagation Through Time) 이라고 한다.

BPTT

  • BPTT를 통해 RNN의 가중치행렬의 미분을 계산해보면 아래와 같이 미분의 곱으로 이루어진 항이 계산된다.
\[L\left(x, y, w_{h}, w_{o}\right)=\sum_{t=1}^{T} \ell\left(y_{t}, o_{t}\right)\] \[\partial_{w_{h}} L\left(x, y, w_{h}, w_{o}\right)=\sum_{t=1}^{T} \partial_{w_{h}} \ell\left(y_{t}, o_{t}\right)=\sum_{t=1}^{T} \partial_{o_{t}} \ell\left(y_{t}, o_{t}\right) \partial_{h_{t}} g\left(h_{t}, w_{h}\right)\left[\partial_{w_{h}} h_{t}\right]\]
  • 시퀀스의 길이가 길어질수록 역전파 알고리즘의 계산이 불안정해지므로 길이를 끊는 것이 필요하다.

이런 문제를 해결하기 위해 LSTMGRU 방식이 등장했고, 추후 다뤄볼 예정이다.

피어세션 정리

주요 키워드

  • CNN 역전파 관련 참고 자료 (그래프 수식 多) https://ratsgo.github.io/deep%20learning/2017/04/05/CNNbackprop/

  • RNN 역전파 방법 BPTT의 미분식 더 공부해보고 자료 공유하기로 http://aikorea.org/blog/rnn-tutorial-1/

  • 전날 토의했던 가능도 관련 수식과 라그랑주 승수 계산 관련 자료 https://namyoungkim.github.io/statistics/2017/09/17/probability/

회고

뜻은 전부 이해했는데 이를 수식으로 표현해보려니 너무 어렵다..