GPU 최적화를 TPU에 강제로 적용해보고 깨달은 하드웨어와 컴파일러의 본질
최근 Hacker News에서 꽤 흥미로운 글을 하나 읽었습니다. Archer Zhang이라는 엔지니어가 작성한 ‘Forcing Flash Attention onto a TPU and Learning the Hard Way’라는 포스트입니다. 재미있게도 HN 댓글란은 “Claude가 쓴 글 같다”, “마크다운이 이상해서 읽기 힘들다”며 문체를 물어뜯기 바빴습니다. 솔직히 요즘 AI로 글을 다듬는 게 무슨 큰 죄도 아니고, 기술 블로그에서 문학적 완성도를 찾는 HN 특유의 엘리트주의가 또 발동했구나 싶었습니다.
하지만 문체에 대한 불평을 걷어내고 보면, 이 글은 근래 제가 읽은 가장 훌륭한 엔지니어링 포스트모템 중 하나입니다. CUDA와 Triton으로 GPU 최적화에 익숙해진 엔지니어가 TPU와 JAX라는 완전히 다른 패러다임에 부딪히며 겪은 처절한 삽질기이자, 하드웨어 아키텍처에 대한 깊은 통찰을 담고 있기 때문입니다.
오늘은 이 글을 바탕으로, 우리가 왜 맹목적으로 ‘Flash Attention’이라는 은통알을 쫓으면 안 되는지, 그리고 JAX/XLA 환경에서 컴파일러와 소통하는 방법이 어떻게 달라야 하는지 제 개인적인 경험을 곁들여 파헤쳐 보겠습니다.
착각: GPU 커널을 JAX로 번역하면 끝이겠지?
GPU에서 Triton으로 Flash Attention 커널을 짜본 사람이라면 누구나 비슷한 유혹에 빠집니다. 알고리즘의 동작 원리도 완벽히 이해했고, 메모리 타일링(Tiling)도 해봤습니다. JAX는 그저 ‘컴파일되는 Numpy’일 뿐이니, Triton 코드를 JAX 배열 연산으로 번역해서 TPU에 올리면 당연히 빠를 것이라고 생각합니다.
저자 역시 동일한 접근을 했습니다. GPU에서 하던 방식 그대로, JAX의 jax.lax.fori_loop를 이용해 Q 블록과 KV 블록을 순회하며 Online Softmax 상태(running_max, running_sum)를 업데이트하는 코드를 짰습니다.
결과는 어땠을까요? 처참했습니다. n=4096(Sequence Length) 기준으로, 저자가 직접 짠 Flash Attention은 XLA가 알아서 최적화한 기본 Attention(Fused Standard Attention)보다 무려 35배나 느렸습니다. 최적화를 하겠다고 짠 코드가 오히려 시스템을 망가뜨린 전형적인 안티 패턴입니다.
왜 실패했는가: 프로그래머 vs 컴파일러
이 실패의 근본적인 원인은 프로그래밍 모델의 차이에 있습니다.
GPU 환경(Triton/CUDA)에서 프로그래머는 사실상 컴파일러입니다. 어떤 데이터를 언제 SRAM에 올리고, 언제 HBM으로 내릴지 포인터 연산을 통해 직접 통제합니다. 하지만 JAX 환경에서 프로그래머는 코드를 실행하는 주체가 아닙니다. XLA(Accelerated Linear Algebra) 컴파일러에게 ‘내가 무엇을 하고 싶은지’ 의도를 전달하는 명세서를 작성할 뿐입니다.
저자가 사용한 fori_loop는 XLA에게 다음과 같은 잘못된 시그널을 주었습니다.
- “이 루프는 상태를 들고 다니는 순차적인(Sequential) 작업이야.”
사실 Q 블록들은 서로 아무런 데이터 의존성이 없습니다. 각각 독립적으로 K/V를 읽고 자신의 결과를 출력하면 됩니다. 하지만 fori_loop라는 불투명한 껍데기에 갇힌 탓에, XLA는 이 루프들이 독립적이라는 사실을 추론하지 못했습니다. 결과적으로 메모리 로드 파이프라이닝도, 병렬 처리도 불가능해졌죠.
해결책: vmap을 통한 의도 전달
해결책은 더 복잡한 알고리즘이 아니었습니다. 단지 컴파일러에게 ‘이 작업들은 독립적이다’라고 알려주는 것뿐이었습니다. 저자는 바깥쪽 Q 루프를 jax.vmap으로 교체했습니다.
# 기존 방식: XLA의 눈을 가리는 순차적 루프
running_max, running_sum, acc = jax.lax.fori_loop(
0, num_k_blocks, k_body, (running_max, running_sum, acc)
)
# 개선된 방식: vmap을 통해 독립성을 명시
def one_q_block(q_tile: jax.Array, q_start: jax.Array) -> jax.Array:
# ... 독립적인 Q 블록 처리 로직 ...
return out_tile
# 컴파일러가 전체 배치를 한 번에 보고 최적화 수행
all_tiles = jax.vmap(one_q_block)(q_blocks, q_starts)
이 단순한 변경 하나로 성능은 45배나 뛰어올랐습니다. 똑같은 알고리즘, 똑같은 타일링, 똑같은 수학 연산이었습니다. 달라진 건 단 하나, JAX의 vmap 추상화를 통해 XLA 컴파일러와 제대로 소통했다는 점입니다.
저는 시니어 엔지니어들의 코드 리뷰를 할 때 이런 상황을 자주 봅니다. 스마트한 컴파일러(LLVM, XLA 등)를 믿지 못하고 C 시절의 버릇처럼 루프를 강제하거나 imperative한 최적화를 시도하다가, 오히려 컴파일러의 Dataflow Analysis를 방해해 성능을 깎아먹는 경우 말입니다. JAX를 쓴다면 XLA와 싸우지 말고, XLA가 일하기 좋은 형태로 의도를 던져주어야 합니다.
하드웨어의 현실: VMEM vs SRAM
하지만 이 스토리의 진짜 클라이맥스는 벤치마크 결과에 숨어있습니다. vmap으로 최적화한 Flash Attention조차도, n=4096까지는 XLA가 알아서 융합(auto-fuse)해버린 기본 Attention 로직을 이기지 못했습니다. n=8192가 되어서야 비로소 Flash Attention이 역전하기 시작하죠.
왜 그럴까요? 바로 하드웨어 아키텍처, 구체적으로는 온칩 메모리(On-chip Memory)의 크기 때문입니다.
- NVIDIA A100 GPU: SM당 Shared Memory가 약 164KB 수준입니다.
- Google TPU v5e: 칩 전체가 공유하는 VMEM이 무려 약 128MB에 달합니다.
GPU에서는 n*n 크기의 Attention Score 행렬이 SRAM에 절대 들어가지 않습니다. HBM(VRAM)으로 쫓겨나는 순간 엄청난 병목이 발생하기 때문에, 메모리를 쪼개서 계산하는 Flash Attention이 ‘필수 생존 기법’입니다.
반면 TPU는 다릅니다. VMEM이 128MB나 되기 때문에, n=4096 수준의 Score 행렬(약 64MB)은 그냥 칩 내부에 넉넉하게 들어갑니다. 하드웨어 공간이 텅텅 비어있는데 굳이 소프트웨어 레벨에서 타일링을 하며 오버헤드를 만들 이유가 없는 것입니다. n=8192(Score 행렬 256MB)가 되어서야 VMEM 용량을 초과하게 되고, 이때부터 타일링(Flash Attention)이 빛을 발합니다.
또한 TPU의 핵심 연산 유닛인 MXU(Matrix Multiply Unit)는 그 자체로 128x128 크기의 거대한 Systolic Array입니다. 가중치를 고정해두고 데이터를 흘려보내는(Weight-stationary) 구조이기 때문에, 타일 크기가 작으면(예: Flash Attention의 내부 연산) 파이프라인이 충분히 채워지지 않아 활용률(Utilization)이 20%대까지 떨어집니다. 반면 통짜 행렬을 때려 넣는 기본 Attention은 94%의 활용률을 보여줍니다.
결론: 맹목적인 최적화는 독이다
이 에피소드가 우리에게 주는 교훈은 명확합니다.
첫째, 하드웨어가 다르면 최적화의 상식도 바뀝니다. GPU에서 세상을 구원한 기법이 TPU에서는 불필요한 오버헤드일 수 있습니다. ‘Flash Attention은 무조건 빠르다’라는 맹목적인 믿음 대신, 내가 타겟팅하는 하드웨어의 L1/L2 캐시 크기와 레지스터 구조를 먼저 파악해야 합니다.
둘째, 현대의 컴파일러는 여러분의 생각보다 훨씬 똑똑합니다. JAX나 PyTorch 2.0(TorchDynamo) 같은 프레임워크를 사용할 때는, 로우레벨 제어권을 쥐려고 안달하기보다는 ‘어떻게 하면 이 연산 그래프의 병렬성과 독립성을 컴파일러에게 잘 떠먹여 줄 수 있을까’를 고민해야 합니다.
만약 TPU에서 정말로 극한의 최적화가 필요하다면, JAX 레벨이 아니라 Google의 Pallas(TPU용 Triton)를 사용해 DMA 더블 버퍼링과 비동기 파이프라이닝을 직접 제어해야 합니다. 하지만 99%의 프로덕션 환경에서는? 그냥 깔끔하게 수학적 의도를 코드로 작성하고, 나머지는 XLA에게 맡기는 것이 맞습니다.
HN 커뮤니티가 이 글의 문체를 비판한 것은 그들의 자유지만, 이 글에 담긴 아키텍처 레벨의 디버깅 과정은 시니어 엔지니어라면 반드시 곱씹어볼 만한 가치가 있습니다. 코드를 짜기 전에 하드웨어 스펙 시트부터 읽는 습관, 저부터 다시 다잡아야겠습니다.