[논문 리뷰] Characterizing signal propagation to close the performance gap in unnormalized resnets
“Characterizing signal propagation to close the performance gap in unnormalized resnets” 논문을 개인 공부 및 리뷰를 위해 쓴 글입니다.
논문 순서
1. Batch Normalization Biases Residual Blocks towards the Identity Function in Deep Networks, De and Smith, NeurIPS 2020
2. Characterizing Signal Propagation to Close the Performance Gap in Unnormalized ResNets, Brock et al., ICLR 2021
3. High Performance Large-Scale Image Recognition without Normalization, Brock et al., ICMR 2021
논문 링크 : https://arxiv.org/abs/2101.08692
1. Introduction
- 배치 정규화(Batch Normalization)은 딥러닝의 핵심이 되었으며, 거의 모든 SOTA 이미지 분류기에 사용되고 있다.
- 다음은 배치 정규화의 pros and cons 이다.
Pros
- loss landscope를 더 매끄럽게 하여 더 큰 학습률로 학습을 가능하게 해준다.
- 배치 통계의 미니 배치 추정치에서 발생하는 노이즈(noise)는 implicit regularization을 도입한다.
Loss Landscope 이미지 출처 3 - 배치 정규화는 identity skip connection을 가진 deep residual network에서 초기화(initialization)할 때 좋은 신호 전파(signal propagation)를 보장하여, 깊은 레이어로 resnet을 학습할 수 있게 해준다.
Cons
-
장치(device)당 배치 크기가 너무 작거나 너무 클 때 성능이 저하되는 등, 배치 크기에 따라 동작이 크게 달라지며, training 및 추론 시간(inference time) 모델의 동작 사이에 불일치를 초래한다.
- 추론 시간(inference time) : 기계 학습 모델이 새로운 데이터를 처리하고 예측을 하는데 걸리는 시간의 양 4
이미지 출처 5 -
배치 정규화는 메모리 오버헤드가 추가로 발생할 수 있다.
-
다른 하드웨어에서 학습된 배치 정규화 모델을 복제(replication)하는 것은 어렵다.
Contribution
- 그들은 정규화 레이어없이 깊은 ResNet을 학습할 수 있는 방안을 찾아, SOTA와 대적할 수 있는 test accuracy을 찾는 것이다.
- 그들의 contribution은 다음과 같다.
- deep residual network에서 forward pass에서 초기화 시 신호 전파를 검사하는 데 도움이 되는 시각화 세트인 신호 전파 플롯(SPP)을 소개한다.
- ReLU 또는 Swish 활성화 및 Gaussian 가중치를 사용하여 unnormalized ResNets에서 주요 고장(failure) 모드(신호 전파 식별)를 식별한다.
- 이러한 비선형성의 평균 출력은 양수이기 때문에, 각 채널의 숨겨진 활성화의 제곱 평균은 네트워크 깊이가 증가함에 따라 빠르게 증가한다.
- 이를 해결하기 위해, Scaleed Weight Standardization을 제안한다.
Gaussian Distribution = 정규 분포(Normal Distribution)
Alternate normalizers
- unnormalized 네트워크에서, 0으로 초기화된 각 잔차 분기의 끝에 학습 가능한(learnable) 스칼라(scalar)를 도입하여 대체하기도 한다.
- 서로 다른 상황에서 배치 정규화의 한계를 극복하기 위해 다양한 대체 정규화 체계가 제안되었으며, 각각은 hidden activation의 서로 다른 구성 요소에서 작동한다.
- LayerNorm(Ba 등, 2016), InstanceNorm(Ulyanov 등, 2016), GroupNorm(Wu & He, 2018) 등이 그 예
- 이러한 대안은 배치 크기에 대한 의존성을 제거하고 일반적으로 매우 작은 배치 크기에 대해 배치 정규화보다 더 잘 작동하지만, 추론 시간 동안 추가 계산 비용을 도입하는 등 자체적인 한계를 도입하기도 한다.
2. Signal Propagation Plots
- Signal propagation : 파동이 한 지점에서 다른 지점으로 전파의 이동 6
- random Gaussian inputs 또는 real training samples로 조건화될 때 네트워크 내의 다른 지점에서 hidden activations의 통계를 표시하는 것이 매우 유익하다.
- Signal Propagation Plots (SPPs) 도입하는 이유이다.
- Average Channel Squared Mean : NHW 축에 걸친 평균의 제곱으로 계산한 다음 C 축에 걸쳐 평균을 구한다.
- 전파가 좋을수록, 각 채널의 평균 활성화가 0에 가까워짐
- Average Channel Variance : NHW 축에 대한 채널 분산을 취한 다음 C 축에 대한 평균을 구하여 계산된다.
- 신호 폭발(explosion) 또는 감쇠(attenuation)를 명확하게 보여줌
- Residual Average Channel Variance : 잔차 가지(residual branch)의 레이어가 올바르게 초기화되었는지 여부를 평가하는 데 도움이 된다.
3. Normalizer-Free Resnets (NF-RESNETS)
- 이제 SPP를 이용하여 신호 전파가 좋고 안정적인 unnormalized resnet을 개발해본다.
- SPP를 통해 확인한 바로는, 표준 초기화의 경우, 배치 정규화는 각 잔차 블록에 대한 입력을 입력 신호의 표준 편차에 비례하는 인수로 다운스케일링(downscaling;축소)한다. 그리고, 각 잔차 블록은 신호의 분산을 거의 일정한 요인만큼 증가시킨다.
- 이런 효과를 흉내내기 위해서, $x_{l+1}=x_l+\alpha f_l(x_l/\beta_l)$ 형태의 잔차 블록을 사용한다.
- $x_l$ : $l$번째 잔차 블록에 대한 입력
- $f_l(\cdot)$ : $l$번째 잔차 가지(residual branch)
- 잔차 가지에 의해 계산된 함수 $f_l(\cdot)$는 초기화 시 분산 보존으로 매개 변수화된다.
- $Var(f_l(z))=Var(z)$
- 이러한 제약을 통해 네트워크의 신호 성장에 대해 추론하고 분산을 분석적으로 추정할 수 있다.
- $\beta_l$ : $\sqrt{Var(x_l)}$인 고정된 스칼라 값
- 초기화 시 활성화 $x_l$의 예상 경험적 표준 편차
- 이렇게 하면 $f_l(\cdot)$에 대한 입력이 단위 분산(unit variance)이 되게 해준다.
- $\alpha$ : 블록 간의 분산 성장률을 제어하는 스칼라 하이퍼파라미터 값
- $Var(x_l)=Var(x_{l-1}+\alpha^2)$와 초기 값 $Var(x_0)=1$, $\beta=\sqrt{Var(x_l)}$로 설정해 잔차 블록 $l$에서 expected empirical variance를 계산한다.
- normalized ResNet의 신호 분산은 normalized 입력을 수신하는 shortcut convolution으로 인해 각 전이 레이어(transition layer)에서 재설정된다.
- 이 논문에서는 전이 레이어(transition layer)의 shortcut convolution이 $(x_l/\beta_l)$에서 작동하도록 하여 이 재설정(reset)을 모방한다.
- 이것을 Normalizer-Free ResNet (NF-ResNets)이라 부른다.
4.1 ReLU Activations Induce Mean Shifts
- 위에서 초록색 선이 He 초기화를 이용한 가우시안 가중치로 초기화한 NF-ResNet이다.
- 제곱 채널 평균의 평균 값(Avg Squared Channel Mean)은 깊이에 따라 빠르게 증가하여 평균 채널 분산을 초과하는 큰 값을 달성한다.
- 이는 서로 다른 훈련 입력에 대한 hidden activation이 강한 상관관계를 갖는 “mean shift”을 나타낸다.
- ReLU 활성화가 제거되면 Avg Squared Channel Mean은 모든 블록 깊이에 대해 0에 가깝게 유지되며 잔차 가지(residual branch)에 대한 경험적 분산은 1을 중심으로 변동한다.
- 결론은, He-초기화된 가중치를 가진 NF-ResNet 모델은 불안정하며, 깊이가 증가함에 따라 훈련하기가 점점 더 어렵다.
4.2 Scaled Weight Standardization
- mean shift 현상을 막고 잔차 가지 $f_l(\cdot)$이 분산을 보존하기 위해, scaled weight standardization을 제안한다.
- $\hat W_{i,j}=\gamma\cdot \frac{W_{i,j}-\mu w_{i,\cdot}}{\sigma w_{i,\cdot}\sqrt{N}}$
- 평균 $\mu$와 분산 $\sigma$는 컨볼루션 필터의 팬인(fan-in) 정도에 걸쳐 계산된다.
- $W$ : gaussian 가중치로 초기화
- $\gamma$ : 고정 상수
- Scaled WS를 적용한 변형의 출력, $z=\hat W g(x)$는 $\mathbb{E}(z_i)=0$을 가져, mean shift를 제거한다.
- Scaled WS는 학습 중에 저렴하고 추론에서 자유롭고, (활성화보다는 매개 변수의 수에 따라) 잘 조정되며, 배치 요소 간의 의존성과 training 및 test 동작의 불일치를 초래하지 않으며, 분산 훈련에서 구현이 다르지 않다.
- 4.1 그리메서 Scaled WS는 초기화시 avg channel squared mean의 성장을 제거한다.
Summary of NF-ResNet
- 계산 후 예상 신호 분산인 $\beta_l^2$를 전파한다.
- $\beta_l^2$는 각 잔여 블록($\beta_0=1$) 후 $\alpha^2$에 의해 성장.
- $\beta_l$에 의해 각 잔차 가지(residual branch)에 대한 입력 축소(downscale)
- 추가로, $\beta_l$에 의한 전환 블록(transition block)에 있는 스킵 경로의 컨볼루션에 대한 입력 축소하고 전환 블록 뒤에 나오는 $\beta_{l+1}=1+\alpha^2$를 재조정
- 모든 컨볼루션 레이어에 Scaled WS를 사용하여 활성화 함수 g(x)에 특정한 이득(gain)을 계산한다.
5. Experiments
- 다음은 각 ResNet의 변형들에 대한 정확도를 비교한 결과이다.
- 배치 정규화의 한계로 장치당 배치 크기가 작을 때 성능이 하락한다는 것이다.
- 다음은 배치 크기에 따른 NF-ResNet과 BN-ReseNet의 성능 비교이다.
- 다음 그림은 다양한 FLOP budget 범위에서, NF-Nesnet은 ImageNet의 SOTA EfficientNets와 경쟁력 있는 성능을 달성한 것을 보여준다.
댓글남기기