무제한 길이 토큰을 처리할 수 있는 LLM framework
Abstract
LLM이 무제한 길이 context를 처리할 수 있도록 확장한 LongMem 프레임워크 제안.
Introduction
GPT-3은 GPT-2에서 입력 토큰 수를 2k로 늘림. 하지만 dense attention으로 계산 복잡성 증가.
따라서 sparse attention을 사용하는 연구가 있었음.
MemTRM에서는 메모리에서 검색된 memorized token과 in-context token 사이의 dense attention을 통해 계산.
하지만 단일 모델을 사용하여 메모리에 캐시 된 이전 표현이 현재 모델에 문제를 일으킬 수 있는 메모리 부실 문제가 있음.
LongMem에서는 이전 context를 메모리 뱅크에 캐싱하고 메모리 부실 문제를 해결하기 위해 메모리 모듈을 분리하여 잔차 side network(SideNet)를 설계.
메모리 뱅크에 이전 context의 key, value를 저장하고 sidenet에서 현재 context와 융합.
LongMem의 이점:
- 메모리 검색, 융합 과정을 백본과 분리하여 LLM은 장기 context 인코더로만 작동함.
- LLM을 동결하여 사전 훈련된 지식을 활용하며 치명적인 망각을 피할 수 있음.
Methods
Language Models Augmented with Long-Term Memory
Transformer 기반의 LongMem은 LLM, SideNet, memory bank로 구성.
이전 입력의 경우 LLM의 m번째 레이어의 self attention의 key, value 값이 메모리 뱅크에 저장되고
현재 입력은 LLM 통과 후 SideNet에 전송되어 융합.
융합 후 메모리 뱅크의 가장 오래된 캐시를 제거하고 현재 시퀀스 추가.
SideNet은 메모리 뱅크, 현재 입력에 대한 LLM의 모든 은닉 상태 출력 {HLLMl'}l'=1L'와 embedding layer의 출력 H0를 입력으로 받고(LLM의 은닉 상태는 나중에 잔차 연결에 쓰임)
(L-1) 일반 디코더 레이어, 하나의 special memory-augmented decoder layer로 구성.
최종 토큰 확률은 LLM과 SideNet이 공유하는 임베딩 가중치 W와 SideNet의 최종 은닉 상태 HL을 사용하여 계산.
해당 논문과 같이 자기회귀적 언어 모델링을 목표로 하고 pre-training text corpus에서 무작위로 샘플링한 데이터로 학습.
Residual SideNet
SideNet Architecture and Initialization
SideNet의 레이어 수는 LLM의 절반.(L' = 2L)
동일한 깊이의 LLM 레이어 가중치로 초기화.
모델 헤드 레이어도 LLM의 것을 재사용.
Cross-Network Residual Connections
LLM의 2l번째 레이어와 (2l-2) 레이어 출력의 차이를 SideNet의 l번째 출력에 대한 잔차 연결로 추가.
Memory Retrieval and Fusion
Token-to-Chunk Memory Retrieval
가속 및 무결성을 위해 token-to-token 대신 token-to-chunk 검색 수행.
Text chunk는 csz 크기의 n-gram 구조를 나타냄.
메모리 뱅크를 csz로 나누고 평균 풀링하여 key vector를 얻고 현재 context의 attention query와의 내적으로 상위 chunk 선정. 해당 chunk는 token-level로 차원을 압축하고 평면화됨.
중요한 것은 SideNet의 (L-1) 출력의 각 토큰에 대해 각각 다른 chunk가 선택된다는 것.
Memory Fusion
(L-1) 출력과 검색된 key-value 쌍 K̃, Ṽ에 대한 joint attention.
Q, K, V는 이전 출력에 대한 self attention, g는 각 헤드에 대한 학습 가능한 벡터.
Experiments