Hacker News

I rebuilt FlashAttention in Triton to understand the performance archaeology

80 points by amindiro ago | 17 comments

amindiro

I’ve spent the last few weeks deconstructing FlashAttention. While the original paper is brilliant, I found that just reading it didn't give me a "gut feeling" for why certain engineering choices were made (the transition from v1 to v2).

I decided to rebuild it from scratch using Triton. This post is a chronicle of that journey—moving beyond the high-level algorithm and into the "performance archaeology" of the GPU:

- Profiling with Nsight Compute to find the real bottlenecks.

- Looking at the generated PTX and SASS code.

- Debugging shared memory bank conflicts and MIO bottlenecks.

- Iterating through the logic to see why tiling and online softmax are hardware-necessitated, not just mathematical tricks.
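The claim in the last bullet can be illustrated with a minimal pure-Python sketch of the online-softmax recurrence for a single query row. It assumes scalar values and a hypothetical `tile_size`, and it shows the textbook recurrence rather than the post's Triton kernel. Between tiles, the only state kept is a running max `m`, a running denominator `l`, and an accumulator `acc`. That is what lets a kernel stream K/V tiles through shared memory without ever storing a full row of scores.

```python
import math


def online_softmax_weighted_sum(scores, values, tile_size=4):
    """Return sum_i softmax(scores)_i * values_i, processing one tile at a time.

    Only three running scalars are kept between tiles, so the full
    softmax row is never materialized.
    """
    m = -math.inf  # running max of the scores seen so far
    l = 0.0        # running sum of exp(score - m)
    acc = 0.0      # running sum of exp(score - m) * value
    for start in range(0, len(scores), tile_size):
        s_tile = scores[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        m_new = max(m, max(s_tile))
        # Rescale the old statistics from the old max to the new max.
        # On the first tile m is -inf, so scale is exp(-inf) == 0.0.
        scale = math.exp(m - m_new)
        l = l * scale + sum(math.exp(s - m_new) for s in s_tile)
        acc = acc * scale + sum(math.exp(s - m_new) * v for s, v in zip(s_tile, v_tile))
        m = m_new
    return acc / l


def reference_weighted_sum(scores, values):
    """Two-pass version that materializes every softmax weight."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


scores = [0.5, 2.0, -1.0, 3.0, 0.0, 1.5, -0.5, 2.5, 1.0, 0.2]
values = [float(i) for i in range(len(scores))]
print(online_softmax_weighted_sum(scores, values))
print(reference_weighted_sum(scores, values))
```

Both calls print the same value, up to floating-point rounding. In a real kernel, `values` would be D_h-wide rows of V. Deferring the division by `l` to the very end, as done here, avoids rescaling the output on every tile.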

I’ve tried to keep it in the spirit of Simon Boehm’s matmul deep dive. Would love to hear from any GPU engineers on whether my interpretations of the SASS/bank conflict behavior match what you've seen in production.

liuliu

I hope you finish this one, though. It starts strong (I particularly liked how you looked into ncu and showed what each recommendation means; this is very helpful for beginners), but the ending is unsatisfying. You didn't explore tensor cores (particularly fp16 / tf32 / bf16), swizzling (which is the right way to solve the K transpose issue, especially given that Triton itself provides a few ways to do this), or async loading (pipelining).
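Here is a simplified pure-Python model of shared-memory banks that makes the swizzling point concrete. It assumes 32 banks of 4-byte words and a 32x32 fp32 tile stored row-major, and it uses a plain XOR swizzle (`col ^ row`). Real layouts are more involved: fp16 values are packed two per word, and Triton's compiler chooses its own shared-memory encodings. Nothing here is a Triton API. Without swizzling, reading one logical column (roughly what a transposed K access does) sends every thread to the same bank.

```python
from collections import Counter

NUM_BANKS = 32  # assumed: 32 banks, each serving one 4-byte word per cycle
TILE = 32       # assumed: 32x32 tile of 4-byte (fp32) elements, row-major


def bank_of(row, col, swizzle):
    """Return the shared-memory bank that holds tile element (row, col)."""
    # XOR swizzle: permute columns within each row so that elements of
    # the same logical column land in different banks.
    phys_col = (col ^ row) if swizzle else col
    word_index = row * TILE + phys_col
    return word_index % NUM_BANKS


def conflict_degree(banks):
    """Largest number of requests mapping to one bank (all addresses here are distinct)."""
    return max(Counter(banks).values())


def column_read(col, swizzle):
    # 32 threads each read one row of the same logical column.
    return conflict_degree([bank_of(r, col, swizzle) for r in range(TILE)])


def row_read(row, swizzle):
    # 32 threads each read one column of the same logical row.
    return conflict_degree([bank_of(row, c, swizzle) for c in range(TILE)])


print("column read, no swizzle:", column_read(0, swizzle=False))  # 32
print("column read, XOR swizzle:", column_read(0, swizzle=True))  # 1
print("row read, XOR swizzle:", row_read(7, swizzle=True))        # 1
```

The XOR keeps row-wise reads conflict-free while spreading each logical column across all 32 banks. This is why swizzling is usually preferred over padding the tile with an extra column.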

Do you have trouble accessing an H100 or similar chips? I'm wondering if there's anything that could help you finish this write-up.

hyperbovine

I still don't understand why certain performance aspects of the CUDA platform are so poorly documented. Why is successfully pushing the hw to its performance envelope considered a novel research result? Shouldn't I be able to look this stuff up on the Nvidia website?

amindiro

One reason is clearly the fast pace at which Nvidia is evolving the hardware. I would consider CUDA a very well-documented platform in general. What's lacking is low-level tutorials, but this is where posts like this one can be a good resource.

sheepscreek

What’s with GPU engineers using such unreadable variable names (to anyone outside the immediate domain)?

It’s the equivalent of doing this for compound interest rate calculation:

  # A = P * (1 + r/n)^(nt)
  P = 10000
  r = 0.06
  n = 12
  t = 5
  A = P * (1 + r / n) ** (n * t)

Compared to this:

  principal = 10_000
  annual_interest_rate = 0.06
  compounds_per_year = 12
  years = 5

  future_value = principal * (1 + annual_interest_rate / compounds_per_year) ** (compounds_per_year * years)

My question is partly rhetorical - I know the answer lies with the tight research and mathematical origins. But that makes it research code IMO, not what I would consider high quality software code.

tornikeo

I think it's a combination of multiple factors. I've worked with GPU kernel code before, and the code you write tends never to be updated or modified. Once it works, it works perfectly and you don't change it. If you get new hardware, you're going to fully rewrite it anyway, so readability typically just isn't useful. Also, you're never working with variables that make sense to humans. It's never something tangible; it's always tiles, offsets, and indices. At least when I was writing GPU code, I didn't think wasting visual space on better variable naming was worthwhile.

ljlolel

PhD dropout here: when you're implementing a math algorithm, you can't really self-document. You have the PDF of the paper and a clear formula, so it's best to link to it and implement the formula exactly, using the same variables.

fny

I'm a former Ruby guy who ended up in stats/ML for a time. I think it's all about information density.

Let's use your example of `A = P * (1 + r / n) ** (n * t)` -- I can immediately see the shape of the function and how all the variables interrelate. If I'm comfortable in the domain, I also know what the variables mean. Finally, this maps perfectly to how the math is written.

If you look at everything in the post, all of the above apply. Everyone in the domain has seen Q = query, K = key, V = value a billion times, and some variation of (B, N_h, T, D_h). Frankly, I've had enough exposure that after I see (B, N_h, T, D_h) once, I can parse (32, 8, 16, 16) without thinking.

Like you, I found this insane when I started studying stats, but over time I realized there's a lot to be gained once you've trained yourself to speak the language.
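As an illustration of that notation, here is a naive single-head attention in pure Python that keeps the conventional names (`Q`, `K`, `V`, `S`, `P`, `O`, `T`, `D_h`). It is a sketch for one (batch, head) slice of a (B, N_h, T, D_h) tensor, not code from the post. Note that it builds the full T x T score matrix `S`, which is exactly what FlashAttention avoids.

```python
import math


def attention(Q, K, V):
    """O = softmax(Q K^T / sqrt(D_h)) V for one (batch, head) slice.

    Q, K, V are T x D_h lists of lists; names follow the usual paper notation.
    """
    T, D_h = len(Q), len(Q[0])
    # S = Q K^T / sqrt(D_h): a T x T score matrix
    S = [[sum(Q[i][d] * K[j][d] for d in range(D_h)) / math.sqrt(D_h)
          for j in range(T)] for i in range(T)]
    # P = row-wise softmax(S), subtracting the row max for numerical stability
    P = []
    for row in S:
        m = max(row)
        e = [math.exp(s - m) for s in row]
        z = sum(e)
        P.append([x / z for x in e])
    # O = P V: a T x D_h output
    return [[sum(P[i][j] * V[j][d] for j in range(T)) for d in range(D_h)]
            for i in range(T)]


# Identical K rows give uniform attention weights, so each output row
# is the mean of the V rows: [2.0, 3.0].
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
print(attention(Q, K, V))
```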

pryelluw

Bad programmers. Researchers are usually (though not always) bad at programming. That's why I don't do projects for academia.

fancy_pantser

When OpenAI announced the Triton language, I was worried I'd be confused one day while reading something because of Nvidia's open-source Triton inference server. I made it quite a long time, but it finally happened today! I was so intrigued for the first few pages and then deeply confused.

rishabhaiover

I ran an experiment on FlashAttention in Triton to measure the impact of caching tiles in shared memory. Surprisingly, performance had a non-monotonic relationship with prefetching these tiles, and the effect was kernel-dependent: the attention kernel benefits from prefetching, while MLP W1 doesn't.

amindiro

Very interesting! I would love to see the experiments. Quick question: what do you mean by kernel-dependent?

rishabhaiover

Sorry for not being clear. We had two different CUDA functions, one was for Attention and one was for the MLP. Here's the kernel code: https://github.com/sankirthk/GPT2-Kernel-Fusion/blob/main/ke...

We saw different results from pipelining in the attention kernel versus the MLP kernel (since MLP W1 projects the attention output into a much higher dimension, its arithmetic intensity shifts toward compute-bound behavior).
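A rough roofline-style estimate shows why a wider output dimension pushes a matmul toward being compute-bound. The sketch below makes three assumptions: ideal caching (each operand is read or written from DRAM exactly once), fp16/bf16 operands, and hypothetical GPT-2-small-like sizes (d_model = 768, a 4x MLP expansion, 1024 tokens). It does not model the attention kernel or reproduce the linked experiments.

```python
def gemm_arithmetic_intensity(M, N, K, bytes_per_elem=2):
    """FLOPs per byte of DRAM traffic for C[M, N] = A[M, K] @ B[K, N].

    Assumes each matrix is read or written exactly once (ideal caching)
    and 2 bytes per element (fp16/bf16).
    """
    flops = 2 * M * N * K  # one multiply and one add per multiply-accumulate
    bytes_moved = bytes_per_elem * (M * K + K * N + M * N)
    return flops / bytes_moved


tokens, d_model = 1024, 768  # hypothetical GPT-2-small-like sizes
square = gemm_arithmetic_intensity(tokens, d_model, d_model)       # d_model -> d_model
up_proj = gemm_arithmetic_intensity(tokens, 4 * d_model, d_model)  # W1: d_model -> 4*d_model
print(f"{square:.1f} FLOP/byte vs {up_proj:.1f} FLOP/byte")  # 279.3 FLOP/byte vs 384.0 FLOP/byte
```

The further a kernel's intensity sits above the GPU's ratio of peak FLOP/s to memory bandwidth, the less memory latency is left for pipelining to hide. That is consistent with, though not proof of, the observation above.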

amindiro

Agreed, this observation holds true for both decode and prefill. Thanks for sharing the code.

raphaelty

Very interesting. I'm wondering whether there are other heavily used algorithms that could benefit a lot from a "Flash" version but don't have one today.

npalli

Seems very detailed and comprehensive. Did I miss it, or was there a performance comparison to the PyTorch version at the top?

amindiro

Hi, thanks for the feedback! That's a good point. I did compare against torch, but at a high enough sequence length (~1024) the torch version starts to OOM because it has to materialize the full S^2 score matrix (sequence length squared) in global memory. At small sequence lengths, torch does win, solely thanks to optimised cuBLAS matmuls.
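Some back-of-the-envelope arithmetic makes the quadratic growth easy to see. The sketch assumes a naive implementation that stores the full fp32 score matrix S with shape (batch, heads, seq_len, seq_len). It borrows the hypothetical batch=32, heads=8 from the (32, 8, 16, 16) shape example earlier in the thread. The softmax output and any activations saved for the backward pass would add more on top.

```python
def score_matrix_bytes(batch, heads, seq_len, bytes_per_elem=4):
    """Bytes needed to materialize the attention scores S for every head.

    S has shape (batch, heads, seq_len, seq_len), so memory grows with seq_len**2.
    """
    return batch * heads * seq_len * seq_len * bytes_per_elem


for seq_len in (256, 512, 1024, 2048):
    gib = score_matrix_bytes(batch=32, heads=8, seq_len=seq_len) / 2**30
    print(f"seq_len={seq_len:5d}: {gib:6.3f} GiB for S alone")
```

Doubling `seq_len` quadruples the memory for S, whereas a FlashAttention-style kernel only ever holds tiles of S in on-chip memory.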