[arXiv](Current version v2)
Abstract
간단하고 확장 가능한 multi-modal multi-task 훈련 및 모델링 접근 방식인 Integrated Multimodal Perception(IMP) 제안
Introduction
- 기존 데이터셋을 최대한 활용하고
- 작업 또는 손실 함수의 모든 조합에 대해 훈련할 수 있으며
- 새로운 데이터셋, 작업, 손실 함수를 추가해도 속도가 느려지지 않도록 하나의 multi-modal model을 훈련할 수 있는 방법을 탐색한다.
최근에 개발된 JAX primitives를 통해 AGD(Alternating Gradient Descent)와 MoE(Mixture of Experts)를 구현함으로써 유사한 배치 크기에서 2~8배의 계산이 필요한 여러 modality를 추가했음에도 불구하고 기존 계산 비용과 메모리의 일부만을 사용한다.
Method
Alternating Gradient Descent (AGD)
Multi-modal의 핵심 중 하나는 작업 확장성이다. 데이터와 손실 목표의 다른 조합은 교육 전반에 걸쳐 상호 교환 가능해야 하며, 새로운 데이터나 목표를 추가해도 메모리나 계산 오버헤드가 발생하지 않아야 한다.
Accelerated graph compilation API는 분산 환경에서 각 작업에 대한 I/O 서명 비용이 발생한다. 한 가지 해결책은 혼합 배치를 사용하는 것이지만, 작업이 추가됨에 따라 작업당 배치 크기가 감소한다.
jax.jit API를 통해 AGD를 구현하여 해결한다. 해당 논문에서 각각의 최적화 단계가 개별적으로 볼록한 경우 AGD가 수렴으로 이어질 수 있음을 입증했다.
20개의 작업에 대해 컴파일은 총 훈련 시간의 0.34%만 차지했다.
AGD-Specific Efficiency Considerations
jax.checkpoint를 통해 순전파 도중 모든 체크포인트를 제거함으로써 메모리 절약, jax.lax.scan으로 컴파일 시간을 줄이고, jax.pjit을 통해 컴파일 분산.
Objectives
각 modality에서 강력한 것으로 알려진 손실들을 사용하며, AGD를 통해 개별적으로 역전파된다.
Architecture
아키텍처는 Embedder, MoE Encoder, Heads로 구성된다.
Embedder는 사전 훈련된 모델 사용. (VATT, AudioMAE, T5)
각 임베딩은 modality에 관계없이 공유 인코더에 입력되고 인코더는 LIMoE의 설계를 따른다. 해당 설계는 특정 modality를 추가해도 인코더를 변경할 필요 없음.
헤드는 각 작업에 따라 필요한 modality 별 헤드를 사용한다.
Multi-Resolution Training
비디오 데이터에 대한 계산 복잡성을 줄이기 위해 해상도를 줄이고 프레임을 늘리는 방식으로 차원을 인수분해하거나 프레임 당 일정 비율의 토큰을 drop하는 등의 방법을 사용할 수 있다.
이 또한 큰 배치 크기나 작은 배치 크기가 필요한 각 작업에 따라 다르게 적용한다.