본문 바로가기

논문 리뷰/etc.

Is Flash Attention Stable?

Abstract

훈련 불안정의 잠재적인 원인인 수치 편차(Numeric Deviation)를 정량화하는 방법을 제안하고 flash attention을 분석

제목 어그로인 것으로 밝혀져...

 

[arXiv](2024/05/05 version v1)

 

 

 

Background

Flash-Attention 논문 리뷰:

시퀀스를 일정한 tile로 나누고 online-softmax trick을 사용하여 전체 행렬을 메모리에 올리지 않는다.

Online softmax를 수행하기 위한 재조정 인자가 필요하다.

Flash attention 논문의 그림
본문의 그림

 

 

 

Experimental Methodology

  • Attention 호출 시 기존 attention과 flash attention을 모두 계산하고 출력을 비교한다.
  • 각 모델을 독립적으로 훈련하여 모델 가중치 차이를 정량화한다.

 

 

 

Quantifying Numeric Deviation Through Microbenchmark

Sweep Numerical Precision

비트 수가 많을수록 편차가 적다.

 

FP64에서 기존 attention의 출력을 'Golden Value'로 설정하고 이 값과 각 attention 출력의 비교.

원본이 화질구지임

 

시퀀스 길이가 길수록 재조정 인자가 더 많이 필요하기에 (더 많은 online-softmax 계산) 수치 편차가 늘어난다.


Sweep Algorithm Changes

큰 tile을 사용하면 재조정 인자의 추가 계산이 줄어드므로 수치 편차가 줄어든다.

 

 

 

Contextualizing Numeric Deviation via Weight Differences

Wasserstein Distance를 통해 가중치의 차이를 측정한다.

 

사실 flash attention의 수치적 차이(빨강)는 무작위 모델 초기화로 인한 편차(파랑)와 비슷하며, 저정밀도 사용으로 인한 편차보다 훨씬 작다. 제목 어그로;