본문 바로가기

논문 리뷰/etc.

Online normalizer calculation for softmax

[Github]

[arXiv](Current version v2)

 

Abstract

더 적은 메모리 액세스로 softmax 계산

 

 

Original softmax

일반적인 softmax:

벡터 당 총 3번의 메모리 액세스가 발생한다.

(정규화 항 dV 계산, 출력값 계산, 저장)

 

하지만 현재 대부분의 딥러닝 작업에서는 지수함수의 오버플로우 위험 때문에 safe softmax를 사용한다.

하지만 safe softmax는 최댓값을 구하는 과정이 추가되어 총 4번의 메모리 액세스가 발생한다.

 

 

 

Online normalizer calculation

온라인으로 최댓값과 정규화 항을 업데이트하는 방식으로 메모리 액세스를 3번으로 줄일 수 있다.

(이 논문에 있는 수학은 증명이 어렵지 않으니 한 번 써 보면서 읽는 걸 추천함)

 

Parallel online normalizer calculation

병렬 스레드에서 online softmax를 계산하는 방법:

 

Softmax and top-k fusion

Beam search와 같이 softmax → TopK 순서로 진행되는 경우, 일반적으로 5번의 메모리 액세스가 필요하다.

 

하지만 online으로 메모리 효율적으로 처리할 수 있다.

 

 

 

Benchmarking