본문 바로가기

논문 리뷰/etc.

PipeDream: Fast and Efficient Pipeline Parallel DNN Training

1F1B 알고리즘을 통해 유휴시간 없는 파이프라인 병렬화

 

[Github]

[arXiv](Current version v1)

 

 

Abstract

DNN에서 파이프라인 병렬화 training system인 PipeDream 제안

 

 

Introduction

파이프라인, 모델 병렬 처리, 데이터 병렬처리의 조합을 pipeline-parallel training이라고 부른다. (PP)

데이터 병렬 훈련보다 worker 간 통신이 95% 적다고 한다.

 

모든 레이어를 몇 개의 stage로 나눈다.

 

PipeDream의 설계:

  • 알고리즘을 통해 잘못된 partitioning으로 stage 간 작업량이 편향되는 것을 방지
  • 모든 작업자를 바쁜 상태로 유지하기 위한 양방향 훈련
  • 특정 미니배치에 대한 역방향 전달이 해당 순방향 전달에 사용된 것보다 더 최신 피라미터를 사용할 수 없도록 버전 유지

 

 

 

Parallel Training in PipeDream

Pipeline Parallelism

4 Stage 예시

(아래 : 통신과 계산의 시간적 중복 강조)

 

1개의 미니배치에 대한 예시: 유휴 시간이 많음

 

유휴 시간을 줄이기 위해 여러 개의 미니배치를 사용한다.

 

PP의 장점:

  • 데이터 병렬에 비해 훨씬 적은 통신 횟수
  • 계산과 통신 시간을 겹침으로써 하드웨어 효율성이 올라감(순전파 계산 중 역전파 데이터가 들어오고 역전파 계산 중 순전파 데이터가 들어옴으로써 통신 오버헤드가 없음.)

 

Partitioning Layers Across Machines

분할 알고리즘은 각 stage의 작업량을 공평하게 하고 통신되는 데이터의 양을 최대한 적게 보장해야 한다.

N개의 layer와 M개의 device가 있을 때, 먼저 단일 device에서 모델을 profiling 한 후 분할 알고리즘 실행.

 

Profiling the DNN Model

각 레이어 l에서 1000개의 미니 배치에 대해 다음을 기록:

  • 총 계산 시간
  • 출력 활성화의 크기와 역전파에서 입력 gradient의 크기
  • 피라미터의 크기

모든 통신은 송신자 GPU → 송신자 CPU → 수신자 CPU → 수신자 GPU 순서로 이루어진다.

 

PipeDream’s Partitioning Algorithm

위에서 계산한 단일 device에 대해 profiling 된 데이터로 stage를 적절히 나눈다.

양심 고백: 내가 이해력이 안 좋은 편이 아닌데 솔직히 이해 못 하겠음... 딥러닝을 사용한 방법도 아니고 이후에 리뷰할 개선 버전에서도 사용되지 않는 알고리즘이므로 그냥 pass~

 

Work Scheduling

Device가 유휴 상태가 되지 않도록 순전파와 역전파를 번갈아 수행하는 1F1B(one-forward-one-backward) 제안.

첫 번째 역전파가 완료된 이후로 순전파와 역전파가 유휴시간 없이 계속된다.

 

실제로는 순전파보다 역전파가 더 오래 걸림.

PipeDream-2BW paper에서 가져옴

 

Effective Learning

예를 들어, 미니배치 5의 순전파는 미니배치 1의 업데이트가 적용된 후에 진행되지만, 역전파는 미니 배치 4의 업데이트가 적용된 후에 진행된다. 이러한 버전 불일치로 수렴이 되지 않을 수 있다. 또한 출력 stage는 순전파 후 바로 역전파를 수행하지만, 다른 stage는 순전파와 역전파 사이에 다른 미니배치에 의한 업데이트가 있고, 이러한 비대칭도 문제가 된다.

 

Weight Stashing

각 device에서 활성 미니배치 당 하나씩 각 버전의 가중치를 저장하여 역전파 시 활용한다. 순전파 시에는 가장 최신 버전의 가중치가 사용된다.

 

Vertical Sync

미니배치 5는 순전파와 역전파 모두에서 5의 입력 시 가장 최신 버전의 가중치였던 미니배치 1에 의해 업데이트된 가중치 W(1)을 사용하며, 역전파 후에는 W(1)을 삭제하고 버퍼에 W(5)를 추가한다.