Flash Attention2 논문을 읽다 보니 이해가 안 되는 부분이 많았다. 팀 내 ML 연구원들은 GPU 파트를 이해 못 했는데 난 오히려 GPU 파트 보단 ML 파트가 이해가 안 됐다;;

그 중에서도 Flash Attention 2 의 핵심 아이디어인 softmax 병렬 계산이 도통 이해가 안 돼서 관련 논문을 열심히 읽었다. 인용수가 22밖에 안돼서 그런가 자료가 매우 없어서 아쉬웠고, 그래서 나의 부족한 글이라도 누군가가에게 도움이 돼지 않을까 하여 올려본다.

논문 URL : https://arxiv.org/abs/1805.02867

tl;dr

  • 마지막에 리스케일링을 해주면 softmax 를 병렬로 계산 할 수 있다.

1. Motivation - Attention requires O(NxN) memory

image
수식 1. Attention

이 논문에서는 Attention을 다루진 않지만, 나는 Flash Attention 의 softmax 를 기준으로 이해하기 위해서 보는 것이므로, Flash Attention 논문에서 attention 을 가지고 왔다.

여기서 강조하고 싶은 것은 위의 attention matrix, S가 최근 long sequence가 유행하는 LLM에서는 점 점 커지고 있다는 것이다. 심지어 N x N 이기 때문에 (여기서 N은 인풋 길이, 토큰 수) 선형이 아니라 제곱으로 그 크기가 증가하고 있다.

본 논문에서는, 이러한 많은 메모리 및 연산 시간을 요구하는 비싼 softmax 연산을 어떻게 메모리, 시간 적으로 최적화 할지 다룬다.

2. Potential Problem of Standard(Naive) Softmax

먼저, 기본 Softmax 수식을 살펴보겠다. 아래 수식은 수학적으로는 아무런 문제가 없을 수 있으나, 실제로 컴퓨터에서 동작을 하면 분모가 너무 커져서 오버/언더플로우가 발생 할 수 있다. 따라서 TensorFlow, PyTorch 등은 뒤에서 설명할 Safe Softmax 를 사용한다. (편의상 앞으론 오버/언더플로우를 오버플로우라고만 표현하겠다.)

image
수식 2. (Naive) Softmax

여기서 한 가지 더 살펴볼 점은, 메모리 엑세스 횟수이다. 아래 <코드1> 에서 볼 수 있듯이 standard softmax 는 각각의 값을 구하기 위해서 메모리를 총 3번 엑세스 하는데,

  1. (3번째 줄) dj 를 구할때 모든 exj 에 대해 한 번 씩 (Load)
  2. (6번째 줄, 우항) yi 를 구할때 또 exi 한번씩 (Load)
  3. (6번째 줄, 좌항) 그리고 yi 에 값을 저장할때 또 한 번씩 (Store) 이렇게 softmax 를 구할때 총 O(3V) 번 메모리에 접근한다. (V=N)

image
코드 1. Softmax

이러한 메모리 엑세스 횟수는 상수이기 때문에 굳이 최적화를 해야 하나? 라고 생각 할 수 있다. 하지만 시간에 예민한 프로그램의 경우, O(3N)을 O(2N)으로 낮추면 실제론 1.5배의 성능 향상을 도출해내므로 충분히 살펴볼만한 포인트라고 말 할 수 있다.

3. Safe Softmax

앞서 기본 형태의 softmax는 오버플로우 문제가 있다고 언급했고, 또한 그것을 해결하기 위한 것이 이 safe softmax 라는 것을 말했다. 아래 수식을 보면 Safe softmax 는 오버플로우 문제를 해결하기 위해 단순히 최댓값을 빼주어서 scale down 해준다.

왜 최댓값일까? 사실 난 이 부분이 직관적으로 이해가 되진 않았다. 그런데 후에 곰곰이 생각해보니, 결국 오버플로우가 발생하는 이유는 상대적으로 큰 수치 때문이니까 그럴 것 같다는 생각이 들었다. (누군가에게는 당연 할 수 있지만, 난 헤맸어서 기록해 본다)

image
수식 3. Safe Softmax

여기서 주목해야 할 점은 최댓값을 구해줘야 함으로써 메모리 접근 수가 한 번 더 늘었다는 것이다. 아래 <코드2> 에서 볼 수 있듯이이 3번째 줄에서 최댓값을 구하기 위해 O(V)만큼 메모리를 더 접근해야 한다. 따라서 하나의 softmax 값을 구하기 위해 이전에 3번만 접근하면 됐다면 이젠 4번 접근을 해야 하는 상황인 것이다.

image
코드 2. Safe Softmax

앞서 설명했듯이 메모리 접근 횟수가 한 번 늘어난 상황이지만 (3번에서 4번), 성능은 25% 저하 됐다. 특히나 GPU 에서는 이러한 성능 저하가 더욱 두드러진다.

아래 <그림1>은 FlashAttention 논문에서 가져온 그림인데, GPU는 작지만 빠른 SRAM크지만 느린 HBM 으로 메모리가 구성되어있다. 이때, 우리가 다루고 있는 softmax matrix는 너무 커서 SRAM에 모두 두기 힘들다. 따라서 크지만 느린 HBM 으로의 메모리 연산이 하나 더 증가하는 것은 전체 성능에 많은 영향을 끼친다. 참고로 HBM은 SRAM 보다 대역폭은 크지만 접근 속도는 약 8배 드리다. (Hopper 기준)

image
그림 1. GPU/CPU 메모리 계층 및 대역폭, 사이즈

4. Online Safe Softmax

이러한 문제를 해결하기 위해 제안 한 것이 저자들의 Online Safe Softmax 이다. (이 코드 처음 이해 했을때 저자들 진짜 천재 같았음..)

image
코드 3. Online Safe Softmax

위의 <코드3>에 본 논문의 핵심 아이디어가 서술 되어있다. 주목 할 부분은 3-6번째 줄인데, <코드2>의 2-8줄과 비교해보면 for loop 이 하나 없다는 걸 파악 할 수 있다.

기존 <코드2>에서는 max 값을 구하고, 그 값을 이용해서 dj 를 구했다면, 여기서는 실시간으로 현재 최댓값을 이용해서 dj 를 구하여 오버플로우를 억제해주고, 나중에 그 값을 정상화(리스케일링) 해준다.

이렇게 해주면 메모리 접근 횟수가 한 번 줄어들게 된다. 좀 더 자세히 설명하자면 기존 <코드2>에서는 max 값을 구하려고 모든 xj 에 접근하면서 메모리 접근 한 번씩, 그리고 dj 를 구하느라 다시 한 번 xj에 접근하는데, 위의 <코드3>에서는 xj 에 한 번 엑세스하고 dj 를 구해 버리므로 메모리 접근 횟수가 한 번 줄어든다. (총 4에서 3으로 감소)

반면 연산 자체는 증가한 것 아닌가? 이런식으로 생각 할 수 있다. 하지만 GPU는 HBM 메모리 접근보다 연산이 싸다!

이 알고리즘에서의 핵심은, 모든 xmax 값을 구하지 않고 dj 값을 구하는 것인데, 어떻게 가능한지 아래에서 설명하겠다.

<수식3> 으로 돌아가보자. <수식3>의 분모를 l 이라하면, 아래와 같다. (편의상 ji로, x의 최댓값을 max(x) 로 표현하였다.)

\[l = \displaystyle\sum_{i=1}^N e^{x_i-max(x)} = \displaystyle\sum_{i=1}^N {e^{x_i} \over e^{max(x)}}\] \[= {e^{x_1} \over e^{max(x)} } + {e^{x_2} \over e^{max(x)} } + ... + {e^{x_N} \over e^{max(x)} } = { e^{x_1} + e^{x_2} + ... + e^{x_N} \over e^{max(x)}}\]

수식4. Safe softmax의 분모 풀어쓰기

수학적으로만 본다면 ex1 + ex2 + … + exN 를 모두 더한 뒤에 마지막에 emax(x) 를 빼주어도 된다. 물론 앞서 말했듯이 최댓값을 빼주는 이유가 오버플로우 발생 예방이기 때문에 이렇게 마지막에 빼주려고 한다면 이미 오버플로우는 발생했겠지만 말이다.

다시 말하면, 적당한 숫자로 나눠주면서 오버플로우만 방지한다면, 정확한 값을 얻기 위한 마지막에만 emax(x) 로 나눠주면 된다는 말이다. (아래 수식에선 적당한 숫자를 esome_of(x) 라 표현 함)

\[l = { e^{x_1} + e^{x_2} + ... + e^{x_N} \over e^{max(x)}} = { e^{x_1} + e^{x_2} + ... + e^{x_N} \over e^{some\_of(x)}} { e^{some\_of(x)} \over e^{max(x)}}\]

수식5. 리스케일링

오버플로우를 방지 할 수 있는 적당한 숫자란 무엇일까? 최댓값이다. 단, 전체 x 의 최댓값은 모두 탐색해야 알 수 있으니까, 단순히 여태 까지 본 x 중에 최댓값이면 된다. (<코드3>의 4-5번째 줄). 아래 <수식6>의 마지막 항을 보면, dj-1emj-1 (이전까지 최댓값) 을 곱하여 없애주고 다시 emj (현재까지의 최댓값) 로 분모를 취해주는 것을 볼 수 있다.

\[d_j = d_{j-1} \times e^{ {m_j-1}-{m_j}} + e^{ {x_j}-{m_j}} = d_{j-1} \times {e^{m_j-1} \over {e^{m_j}}} + {e^{x_j} \over e^{m_j}}\]

수식6. 코드3의 5번째 줄 수식화

N이 4이고 x = {3,4,2,5} 일 때의 예제를 아래에 풀어서 설명해놓았다.

최종적으로 d4 가 아래와 같은 형태이면 되는데,

\[d_4 = {e^{3} + e^{4} +e^{2} +e^{5} \over e^{5}}\]

d1은 직관적이니 넘어가고,

\[d_1 = {e^{3} \over e^{3}}\] \[d_2 = {e^{3} \over e^{3}} \times {e^{3} \over e^{4} } + {e^{4} \over e^{4}} = {e^3 + e^{4} \over e^{4}}\] \[d_3 = {e^{3} +e^{4} \over e^{4}} \times {e^{4} \over e^{4} } + {e^{2} \over e^{4}} = {e^3 + e^4 + e^2 \over e^4}\] \[d_4 = {e^{3} +e^{4} +e^{2} \over e^{4}} \times {e^{4} \over e^{5} } + {e^{5} \over e^{5}} = {e^3 + e^4 + e^2 + e^5 \over e^5}\]

d2 의 경우, d1에 이전까지의 최댓값(e3)을 곱해주고, 현재의 최댓값(e4)으로 나누어 준다. 그리고 현재의 최댓값(e4) 분의 현재값(e4)을 더해주면 된다.

예제1. N이 4이고 x = {3,4,2,5} 일 때

이런식으로 수행하면 전체의 max 를 구하지 않고도 위의 d4 를 만들 수 있어서 메모리 엑세스를 줄일 수 있다.

5. Parallel Online Safe Softmax

이를 이용해서 local softmax를 구하고, 마지막에 리스케일링(전체 최댓값으로 나눠주는 행위)를 함으로써 선형 메모리를 사용하며 safe softmax 를 병렬로 진행 할 수 있다.

6. Summary

  • Softmax 연산은 시간, 공간 복잡도가 N의 제곱이다. → 매우 비싸고 느림
  • Online softmax 를 사용하면 softmax 를 병렬화 할 수 있다. → 매우 비싸고 빠름
  • GPU 메모리 엑세스 최적화로 더욱 효율적이게 softmax 를 할 수 있다. → 비싸고 빠름 (매우가 빠짐. 중요)

참고

  • https://www.youtube.com/watch?v=lpBJHUU4w6k 내가 하고 있는 고민은 항상 (인도의) 누군가가 먼저 하고 있다!

댓글남기기