본문 바로가기

논문 리뷰/Mamba

Efficiently Modeling Long Sequences with Structured State Spaces (S4)

[arXiv](2022/08/05 version v3)

 

영어 잘하시면 이거 보세요.

https://srush.github.io/annotated-s4/ 근데 솔직히 원어민도 이거 보고 이해 못 할 듯;

 

The Annotated S4

 

srush.github.io

 

시작하기 전에 말하자면 이 논문에 관련된 모든 수식을 이해하는 것은 저로서는 불가능한 일이었습니다...

그래서 최대한 수학을 빼고 개념적으로 설명해 보겠습니다. (그래도 많아요)

 

Abstract

State Space Model(SSM)에 대한 새로운 피라미터화를 기반으로 이전 접근법보다 효율적인 Structured State Space Sequence model(S4) 제안

 

 

 

Background: State Spaces

일단 state space에 대해 모른 다면 이 영상을 꼭 보고 와주세요.

 

우리가 알 수 있는 것은 연속 시간에서 SSM을 다음과 같이 정의할 수 있다는 것입니다.

 

Addressing Long-Range Dependencies with HiPPO

다른 분의 HiPPO 논문 리뷰

 

HiPPO는 모든 시간 t에 대해 0~t 시간에 대한 정보의 종합이라고 볼 수 있는 ct를 유지하기 위한 프레임워크이다.

ct와 t 시점의 출력 ft의 선형 조합을 통해 ct+1을 생성한다.

 

HiPPO를 행렬의 형태로 구현하여 SSM에 적용하면 SSM의 성능을 향상할 수 있다. (A = HiPPO matrix)

 

Discrete-time SSM: The Recurrent Representation

이산 시간에 대한 SSM.

 

수식은 몰라도 된다. 중요한 것은 이산 시간 간격 ∆에 대한 A를 Ā로 표기한다는 것이다.

D는 잔차 연결로 볼 수 있고 계산하기 쉽기 때문에 계속 생략.

 

주의할 점은, 연속 시간에서는 현재 상태 x(t)를 사용하여 y를 계산했지만 

이산 시간에서는 x를 업데이트한 후에 업데이트된 x를 사용하여 y를 계산한다는 점이다.

 

Training SSMs: The Convolutional Representation

초기 state를 0으로 가정하면 다음과 같다.

 

이 과정은 convolution kernel K̄로 벡터화할 수 있다.

 

 

 

Method: Structured State Spaces (S4)

결국 이 논문의 목적은 S4 피라미터화를 통해 Ā, B̄, C̄, K̄를 효율적으로 계산하는 것이다.

 

Motivation: Diagonalization

한 가지 문제는 행렬 A에 대한 반복 곱셈으로 인해 연산 비용이 늘어난다는 것이다.

 

Lemma 3.1

다음을 가정할 때,

 

다음과 같이 SSM의 구성 요소인 행렬들을 켤레화(conjugation) 해도, 동일한 입력 u에 대해 동일한 출력 y를 갖는다.

 

이는 적절한 기저 변환을 통해 A를 더 계산 효율적인 형태로 변환하여 연산을 수행할 수 있음을 의미한다.

 

But, HiPPO matrix의 대각화는 이론적으로는 가능하지만 실제 계산에서 문제가 발생할 수 있다.

 

The S4 Parameterization: Normal Plus Low-Rank

그래서 아무튼 어떤 기술들을 활용하여...

 

NPLR : A를 다루기 쉬운 정규 행렬과 row-rank 행렬의 합의 형태로 표현할 수 있다.

V = unitary matrix, Λ = diagonal matrix, P&Q = row-rank matrix

 

HiPPO matrix의 NPLR 형태의 경우 row-rank = 1로 표현 가능하다.

 

S4 Algorithms and Computational Complexity

Theorem 2 (S4 Recurrence)

그래서 아무튼 A의 NPLR 형태를 활용하면 Ā, B̄를 A, B에 대한 식으로 나타낼 수 있다.

 

Theorem 3 (S4 Convolution)

흠... 아무튼 K̄를 구할 수 있다...

 

Architecture Details of the Deep S4 Layer

S4는 피라미터화된 대각 행렬 Λ, 벡터 P, Q, B, C를 포함한 네트워크이며, 입력 u에 대한 출력 y를 생성한다.