본문 바로가기

논문 리뷰/Vision Transformer

Scalable Diffusion Models with Transformers (DiT)

확산 모델에 트랜스포머 백본. JAX로 구현됨.

 

Arxiv

Github

Project Page

 

 

Abstract

확산 모델에서 일반적으로 사용되는 U-Net 백본을 잠재 패치에서 작동하는 트랜스포머로 대체한다. 

트랜스포머의 깊이/폭 증가 또는 입력 토큰의 증가가 더 낮은 FID를 갖는다는 것을 발견하였으며 class-conditional ImageNet 512, 256에서 이전의 모든 확산 모델을 능가한 2.27 FID를 달성하였다.

 

 

 

Introduction

본 연구에서는 U-Net의 유도 편향이 확산 모델의 성능에 중요하지 않으며, 트랜스포머와 같은 표준 설계로 쉽게 대체될 수 있음을 보여준다. 또한 트랜스포머를 기반으로 한 새로운 확산 모델인 Diffusion Transformer(DiT)를 소개한다.

 

LDM(=stable diffusion) 프레임워크 하에서 DiT 설계 공간을 구성하고 벤치마킹함으로써 U-Net을 성공적으로 대체할 수 있다는 것을 보여준다.

 

네트워크 복잡성과 샘플 품질 사이에는 강한 상관관계가 있다.

 

 

 

Diffusion Transformers

Preliminaries

Classifier-free Guidance

Latent Diffusion Model(LDM, Stable Diffusion)

 

Diffusion Transformer Design Space

adaLN을 통한 조건화가 가장 좋았다고 함.

 

Patchify

오토인코더의 출력인 z(32x32x4)를 d차원 토큰으로 변환하고 위치임베딩

 

DiT block design

In-context conditioning : t와 c의 벡터 임베딩을 일반 토큰으로 추가 (맨 오른쪽)

Cross-attention block : LDM을 포함해 조건 입력을 위해 범용적으로 사용되는 cross attention. 15% 오버헤드. (중간)

Adaptive layer norm (adaLN) block : 적응형 정규화. 가장 적은 오버헤드 추가.

 

AdaLN-Zero block:

확산 모델의 U-Net에서 잔차 연결 직전 각 블록의 최종 컨볼루션 레이어를 0으로 초기화한다. 이는 잔차 블록을 ID로 만들고 훈련을 가속화하는 데 효과가 있다. DiT 블록에서 똑같이 함. 또한 잔차 연결 직전 적응형 스케일링 매개변수 α 추가.

모든 α에 대해 영벡터를 출력하도록 MLP 초기화. 이는 DiT 블록이 ID로 초기화되도록 함.

 

Transformer decoder

각 토큰을 p x p x 2C 형태로 디코딩하고 최종 예측 노이즈와 공분산을 얻기 위해 재배열.

 

 

 

Experimental Setup

p = 2,4,8

Adam W

학습 속도 0.0001 고정

0.9999의 감쇠로 EMA유지

배치 256

증강은 hflip만

학습 속도 워밍업이나 특별한 트릭 없이도 안정적으로 훈련됐다고 함

 

ADM의 확산 하이퍼 피라미터 유지

 

JAX로 구현

 

 

Experiments