[Github]
[arXiv](2024/03/06 version v1)
data:image/s3,"s3://crabby-images/ba955/ba955ce70c70f03a9e5f3e704087463dcdbdacd0" alt=""
Abstract
Gradient를 low-rank로 투영하여 메모리 집약적인 계산을 수행하는, LoRA 보다 메모리 효율적인 GaLore (Gradient Low-Rank Projection) 제안
GaLore: Gradient Low-Rank Projection
- Background
- Low-Rank Property of Weight Gradient
- Gradient Low-rank Projection (GaLore)
이 챕터 선 한 줄 요약: 훈련이 진행될수록 gradient의 rank가 낮아지며, 이를 이용해 메모리 집약적인 계산을 low-rank에서 수행한다.
Background
Regular full-rank training
Timestep t, optimizer (e.g. Adam) p, 역전파 행렬 G에 대해
data:image/s3,"s3://crabby-images/232c3/232c368899e2ca661b61ee9b5f05a5f3d4cc15b1" alt=""
가중치 업데이트는 다음과 같다.
data:image/s3,"s3://crabby-images/84e21/84e21cc2afbcff81ddc2db0caa14d9482d0cda1d" alt=""
이때 p의 state는 메모리 집약적일 수 있으며 Adam의 경우 G̃는 다음과 같이 계산된다.
data:image/s3,"s3://crabby-images/589ed/589eda780629e81fe9ae925210de79f8d0977c4c" alt=""
data:image/s3,"s3://crabby-images/fd0a5/fd0a5ab7a17eec8e31f230865ab5ef1b21417f75" alt=""
Low-Rank Property of Weight Gradient
Lemma 3.1 훈련이 진행될수록 gradient의 rank가 낮아진다.
Gradient update를 다음과 같은 형태로 가정하면, (B, C는 PSD matrix이다.)
data:image/s3,"s3://crabby-images/75ccf/75ccf7af86a6e4620adcc5f80b75980814e908d8" alt=""
t가 증가함에 따라 G는 rank-1로 수렴한다.
data:image/s3,"s3://crabby-images/9b1aa/9b1aaa102f6dbb44e9d6db344ead97025a4e1a8d" alt=""
Theorem 3.2 가역 신경망의 gradient form
선형 레이어와 일부 활성화함수는 가역적이다.
다음과 같은 L2 목표를 갖는 가역 신경망 N에 대해
data:image/s3,"s3://crabby-images/9302d/9302d612b31971f4f47f280dadfe9cdf28a96f1b" alt=""
배치 크기 1에서 layer l의 가중치 행렬 W는 다음과 같은 형태의 gradient를 가진다.
data:image/s3,"s3://crabby-images/fedb4/fedb4b36cfd78e12887eeb43b3d959422015fa4f" alt=""
data:image/s3,"s3://crabby-images/03d94/03d94fe73599da8d07d7a25a733cde9fbddbb061" alt=""
Lemma 3.3 Softmax loss를 사용하는 가역 신경망 또한 같은 형태의 gradient form을 갖는다.
data:image/s3,"s3://crabby-images/d5f5a/d5f5a431caecfab4bd7ff537443ff29faefa072b" alt=""
Gradient Low-rank Projection (GaLore)
Definition 3.4 Gradient Low-rank Projection (GaLore)
G를 row-rank로 투영하고 p를 처리한 뒤 원래 차원으로 복구한다.
data:image/s3,"s3://crabby-images/862a3/862a3f1da43e4afaf58d92db8fb741026831527d" alt=""
data:image/s3,"s3://crabby-images/0057d/0057df07f559ba7a1c960f9a01e968065e70491a" alt=""
Definition 3.5 L-continuity (Lipschitz continuity)
data:image/s3,"s3://crabby-images/e8a32/e8a32d18dad9b513660d05652e0aa667877a8878" alt=""
이를 만족한다는 것은 안정적인 수렴이 보장된다는 것을 의미한다.
Theorem 3.6 Fixed projection을 이용한 GaLore의 수렴
배치 크기가 1보다 클 때의 G:
data:image/s3,"s3://crabby-images/d5116/d5116decd826804d433e96c14edf8b9c4910a48f" alt=""
다음과 같은 조건에서 A, B, C에 대한 L-continuity를 만족한다고 한다.
data:image/s3,"s3://crabby-images/20a46/20a46d648ffdb80c94ef7ed184c25fdef8e2446a" alt=""
각 변수 정의:
data:image/s3,"s3://crabby-images/f719e/f719ef9ed8e082b218f786834e21f15e63864d5c" alt=""
data:image/s3,"s3://crabby-images/15013/15013c12b4ef0ff3e5e1f520f258c13bd6375446" alt=""
λmin은 해당 행렬이 가진 가장 작은 고윳값을 말한다.
그러므로 조건을 만족하기 위해서는
data:image/s3,"s3://crabby-images/d9e6d/d9e6dec061c639ef4033dd92cc070a680bb9babe" alt=""
는 B, C의 가장 큰 몇 개의 고유 부분공간만을 포함하는 것이 좋다.
한 가지 쉬운 방법은 G에 대한 특이값 분해를 사용하는 것이다.
data:image/s3,"s3://crabby-images/adb47/adb4719f49be065cd56be95c19f9bfd17482b7d6" alt=""
GaLore for Memory-Efficient Training
LLM training과 같은 복잡한 최적화 문제의 경우 단일 subspace로 전체 gradient 궤적을 포착하는 것이 불가능하다.
Composition of Low-Rank Subspaces
훈련 도중 일정 시점에서 현재 G에 대한 특이값 분해를 수행하여 P, Q를 초기화함으로써 다른 subspace로의 전환을 허용한다.
data:image/s3,"s3://crabby-images/787bc/787bc707974e5e61ef86200feb45249ce85e8d58" alt=""
data:image/s3,"s3://crabby-images/29a40/29a40f7cb2b9400199cff4d79945e2c5cc6d9be4" alt=""
Memory-Efficient Optimization
메모리와 성능 간의 균형을 위해 실제로는 PGQ를 사용하지 않고 PG 또는 GQ로만 투영한다.
data:image/s3,"s3://crabby-images/0f790/0f790b35e1db2377f4607953ead0ea8e37c751fd" alt=""
Combining with Existing Techniques
추가적으로 QLoRA, Per-layer weight update를 채택.
data:image/s3,"s3://crabby-images/b269b/b269bba1a17ab7a67f51991b13aa196aeab043c1" alt=""
Experiments
Perplexity
data:image/s3,"s3://crabby-images/e076a/e076a3916588fd4484e8ae20276e1d2316550897" alt=""
data:image/s3,"s3://crabby-images/42204/42204ecc30854990de199e7ac8a3da2e975ba6cd" alt=""
Memory
data:image/s3,"s3://crabby-images/fd68b/fd68b999499e2e3c9a75e485b6657851ced8365d" alt=""