앞선 글들에서 GPU 구조최적화, 소프트맥스 병렬화 등을 다루어 보았다. 이러한 글들을 다루게 된 계기는 여럿 있었지만 그 중 하나는 GPU-aware 한 딥러닝 최적화 논문들을 리뷰하기 위함이었다.

GPU-aware한 최적화 논문들은 GPU도 알아야 하고, 딥러닝도 알아야 해서 복잡한 편이다. 그래서 그런지 전체를 아우르는 설명은 찾기 힘들고, 더더욱이나 한글로된 자료는 없었다.

본 글에서는 앞서 정리한 GPU 및 딥러닝 관련 글을 기반으로, GPU-aware 딥러닝 최적화 논문 리뷰를 해보겠다.

오늘 리뷰할 논문은, 아주 유명한 논문 중 하나 인 FlashAttention 1, 2 이다.

논문의 코드나 그래프를 하나하나 설명하기 보단, 주요 아이디어 들을 공유하는 방식으로 서술하였다.

FlashAttention 1 (22.06.24)

TMI

FlashAttention 1 의 풀네임은 “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness” 이다. 지금은 프린스턴 대학의 조교수로 취임한 Tri Dao가 박사일때 2022 년에 낸 논문이며 500회 이상의 인용수를 기록하고 있다. (내 전체 논문 인용수를 합쳐야 500이 넘는데..눈물..)

논문 자체는 ICML workshop 등, 워크샵에만 등재되어 있지만 실제 코드를 구현하여 대부분의 상황에서 잘 작동하게 만들었기 때문에 다양하게 활용되며 입소문이 퍼지게 된 케이스이다. 요즘 MIT의 song han 교수도 그렇고 tri dao도 github 에 코드를 공개 [Link] [Link] 하고 활동을 열심히 하는 모습을 볼 수 있다. 내 전공인 시스템 분야와는 다르게 ‘내부적으로 어떻게 돌아가는지는 모르겠고, 그래서 내가 써 볼 수 있어? 쉽게?’가 중요한 분야 같다.

사설이 길었는데, 이 논문은 제목에서 알 수 있다시피 IO, read/write 관련 최적화 논문이며, 그 중에서도 GPU의 메모리 구조를 파악하고 최적화 한 논문이다. 앞선 글 에서 설명했듯이 GPU의 메모리는 (단순화하면) 작지만 빠른 on-chip과 느리지만 큰 off-chip 으로 구성되어 있다. 저자는 이러한 구조를 이해하지 않고 GPU 프로그래밍을 하면 최대 성능을 이끌어 낼 수 없다고 하며, 이러한 구조에 최적화된 형태의 알고리즘을 제안한다.

문제

  • 트랜스포머 기반 LLM 모델의 경우 문자열의 길이가 매우 제한적이다.
  • 문자열의 길이가 제한적이라 더 긴 문장이나 이미지 등을 학습 할 수 없다.

원인

  • 어텐션 행렬의 시간, 공간 복잡도가 N2 (N = 토큰 갯 수) 라서, 문자열의 길이를 늘리기 어렵다. → Long latency, OOM

Pasted image 20240309102820

해결

  • Tiling (소프트맥스 병렬화를 통한 속도 향상)
    • SRAM 의 사이즈에 맞게 어텐션 행렬을 자른 후 여러 개의 쓰레드 블럭으로 병렬 수행
    • 각 블럭 병렬 수행 후 리스케일링을 통해 정확한 소프트맥스 값 도출
    • 이러한 방식을 통해 HBM 접근 횟수를 최소화하고, 여러 개의 GPU 코어를 최대한 활용

Pasted image 20240309105731

  • 아래 그래프는 블럭 사이즈(갯수 같음)를 늘리며 HBM 접근 횟수가 얼마나 줄어드는가를 실험한 그래프
    • 블럭 갯수를 늘리면, SM 활용도가 올라가고, SRAM 활용도도 같이 올라감.
    • 하지만 SM이 갯수가 제한적이고, 결국 리스케일링 커뮤니케이션이 필요하므로 특정 수 이상에선 성능이 좋아지진 않음. Pasted image 20240313174625
  • Recomputation (어탠션 행렬을 저장하지 않고, backward 일때 다시 계산)
    • We store the softmax normalization factor from the forward pass to quickly recompute attention on-chip in the backward pass, which is faster than the standard approach of reading the intermediate attention matrix from HBM
  • One more, kernel fusion
    • 첫 번째로 지적하는 것은, performance breakdown
    • 두 번째는 위의 알고리즘이 커널 퓨전을 가능하게 하여 최적화

Pasted image 20240309103011

결과

  • 스피드업
  • 메모리 세이빙 → 더 긴 문자열
  • 등등 Pasted image 20240309110044 Pasted image 20240309110117 Pasted image 20240309110131 Pasted image 20240309110150 Pasted image 20240309110203 Pasted image 20240309110209 Pasted image 20240309110230

참고

  • https://www.youtube.com/watch?v=FThvfkXWqtE
  • https://www.youtube.com/watch?v=gMOAud7hZg4

FlashAttention 2 (23.07.18)

문제

  • FlashAttention 이 아직 최적화가 덜 되었다.
    • 아직 GPU 를 많이 못 쓰고 있다.
      • Forward pass 이론 최대 FLOPs/s의 30-50%, backward pass는 25-35%

원인

  • 알고리즘적, 하드웨어(GPU)적 최적화가 덜 진행 됨
  • 자세한 건 아래에서 설명
  • 매우 큰 A100의 matmul/non-matmul 성능 차이
    • 312 TFLOPS/s of FP16/BE16 matmul
    • 19.5 TFLOPS/s of non-matmul FP32

해결

1. Tweak algorithm

non-matmul 연산 최소화 Pasted image 20240313204524 첫 번째 tweak FlashAttention 1

Pasted image 20240313202855

FlashAttention 2

Pasted image 20240313202910

두 번째 tweak

Pasted image 20240313211551

2. Parallelism

image

3. Work Partitioning Between Warps

  • 간단하지만 효과적인 최적화 방법
  • Warp 1 입장에서 담당한 row에 대해서 소프트맥스 값을 구하고 V를 곱해야 하는데, 기존 방식(왼쪽)은 warp 1 이 연산을 끝냈더라도, warp 2,3,4 가 모두 K에 대한 연산을 끝내야 하므로 기다렸어야 함
  • 하지만 오른쪽은 아 Pasted image 20240313211646

결과

  • A100, H100에서 실험
    • H100은 코드 변경 없이 그냥 수행만
    • GPU는 각 1개씩 쓴듯
  • Triton 은 Nvidia Triton 이 아님.
    • Harvard에서 나온 OpenAI 꺼
  • Causal mask 유무, head dimension 수 차이에 따른 실험함

Graph reading guide

  • X축은 문자열 길이
  • Y축은 학습 성능. 클 수록 좋음

    A100

Pasted image 20240314082805 image

H100

Pasted image 20240314082823

결론

  • GPU의 on-chip 메모리를 최대한 활용해야 한다.
  • non-matmul 연산 줄어야 한다.
  • 아직도 할 거 많다.

참고 https://www.youtube.com/watch?v=IoMSGuiwV3g https://www.youtube.com/watch?v=foG0ebzuw34&t=4080s https://www.youtube.com/watch?v=IoMSGuiwV3g

댓글남기기