본문 바로가기

논문 리뷰/Language Model

Self-attention Does Not Need O(n^2) Memory

[arXiv](Current version v3)

 

참고: 다른 분의 논문 리뷰

(저분의 리뷰에 중요한 부분이 빠져 있어서 요약할 겸 작성함)

 

Abstract

Self-attention의 메모리 복잡도 줄이기

 

 

Algorithm

기존 attention:

길이가 n인 스퀀스에서 단일 쿼리에 대해 모든 si를 계산하고 기억해야 하므로 시간 및 메모리 복잡도는 O(n).

Self-attention은 모든 쿼리에 대해 계산되므로 O(n2).

 

단일 쿼리에 대한 알고리즘을 다음과 같이 변경한다.

그리고 한 번의 QKV 연산이 진행될 때마다 v*, s*를 업데이트한 후 다른 중간 계산값은 다 버린다.

모든 연산이 다 진행된 후 v*/s*를 구하기만 하면 된다. 이렇게 하면 두 개의 상수 메모리 밖에 소모되지 않는다.

 

 

 

Numerical Stability

보통 지수 함수의 불안정성 때문에 실제로는 original softmax가 아닌 safe softmax(이 파트 이해하려면 무조건 봐야 함) 사용한다. 그러나 중간 값을 다 버리기 때문에 max값을 계산할 수가 없다.

 

따라서 max값을 저장하는 변수 m을 추가하고

QKV 연산마다 다음과 같이 v*, s*, m*를 업데이트한다.

 

 

 

Experiments

계산 과정을 건드리지 않았기 때문에 당연하지만 성능 변화 없음.