We saw different results from pipelining with the attention kernel vs. the MLP kernel (since the MLP W1 has to project the attention output into a much higher dimension, the arithmetic intensity shifts towards compute-bound characteristics).
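To make that concrete, here's a rough back-of-the-envelope sketch of why the W1 projection lands at a much higher arithmetic intensity than the attention-score matmul. The dimensions and the fp16 assumption are mine for illustration, not taken from the post:

def matmul_intensity(M, K, N, bytes_per_elem=2):  # assuming fp16 storage
    flops = 2 * M * K * N  # one multiply-add per (m, k, n) triple
    bytes_moved = bytes_per_elem * (M * K + K * N + M * N)  # read A and B, write C
    return flops / bytes_moved

T, D_h, E, E_ff = 2048, 64, 1024, 4096  # assumed seq len, head dim, model dim, MLP hidden dim
print(matmul_intensity(T, D_h, T))   # Q @ K^T: small inner dim, low intensity -> leans memory bound
print(matmul_intensity(T, E, E_ff))  # MLP W1: large inner and outer dims, much higher intensity -> leans compute bound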
It’s the equivalent of doing this for compound interest rate calculation:
# A = P * (1 + r/n)^(nt)
P = 10000
r = 0.06
n = 12
t = 5
A = P * (1 + r / n) ** (n * t)
Compared to this:
principal = 10_000
annual_interest_rate = 0.06
compounds_per_year = 12
years = 5

future_value = principal * (1 + annual_interest_rate / compounds_per_year) ** (compounds_per_year * years)
My question is partly rhetorical - I know the answer lies in its tight research and mathematical origins. But that makes it research code IMO, not what I would consider high-quality software code.
Let's use your example of `A = P * (1 + r / n) ** (n * t)` -- I can immediately see the shape of the function and how all the variables interrelate. If I'm comfortable in the domain, I also know what the variables mean. Finally, this maps perfectly to how the math is written.
If you look at everything in the post, all of the above apply. Everyone in the domain has seen Q = query, K = key, V = value a billion times, and some variation of (B, N_h, T, D_h). Frankly, I've had enough exposure that after I see (B, N_h, T, D_h) once, I can parse (32, 8, 16, 16) without thinking.
Like you, I found this insane when I started studying stats, but over time I realized there's a lot to be gained once you've trained yourself to speak the language.
B, T, E = x.size() # batch size, sequence length, embedding dimensionality
q, k, v = self.qkv(x).split(self.embedding, dim=-1)
q, k, v = map(lambda y: y.view(B, T, self.heads, E // self.heads).transpose(1, 2), (q, k, v))
attention = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
...
vs.
B, T, E = bteX.size()
iHeadSize = E // self.heads
bteQ, bteK, bteV = self.qkv_E_3E(bteX).split(E, dim=-1)
bhtiQ, bhtiK, bhtiV = map(lambda y: y.view(B, T, self.heads, iHeadSize).transpose(1, 2), (bteQ, bteK, bteV))
bhttAttention = (bhtiQ @ bhtiK.transpose(-2, -1)) * (1.0 / math.sqrt(iHeadSize))
Looks uglier but might be easier to reason about.
amindiro•1mo ago
I decided to rebuild it from scratch using Triton. This post is a chronicle of that journey—moving beyond the high-level algorithm and into the "performance archaeology" of the GPU:
- Profiling with Nsight Compute to find the real bottlenecks.
- Looking at the generated PTX and SASS code.
- Debugging shared memory bank conflicts and MIO bottlenecks.
- Iterating through the logic to see why tiling and online softmax are hardware-necessitated, not just mathematical tricks.
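In case it helps anyone skimming the thread, here's a minimal NumPy sketch of the online softmax recurrence that last bullet refers to; the block size and variable names are mine, not from the post:

import numpy as np

def online_softmax(scores, block=4):
    # running max and running normalizer, updated one block at a time
    m, l = -np.inf, 0.0
    for start in range(0, len(scores), block):
        chunk = scores[start:start + block]
        m_new = max(m, chunk.max())
        # rescale the old sum to the new max before folding in the new block
        l = l * np.exp(m - m_new) + np.exp(chunk - m_new).sum()
        m = m_new
    # final pass over the scores just for this demo; a fused kernel would instead
    # rescale a running output accumulator so each tile only has to be read once
    return np.exp(scores - m) / l

x = np.random.randn(10)
assert np.allclose(online_softmax(x), np.exp(x - x.max()) / np.exp(x - x.max()).sum())

That rescaling trick is what lets each tile stay resident in shared memory/registers instead of requiring a second pass over global memory, which I take to be the "hardware-necessitated" part.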
I’ve tried to keep it in the spirit of Simon Boehm’s matmul deep dive. Would love to hear from any GPU engineers on whether my interpretations of the SASS/bank conflict behavior match what you've seen in production.
liuliu•1mo ago
Do you have problems accessing an H100 or similar chips? Wondering if there's anything that can help finish this write-up.
amindiro•1mo ago
You've hit the nail on the head regarding the missing pieces. I actually hit a bit of a wall with my current hardware; using an RTX 2070 made it difficult to meaningfully explore the async loading (TMA) and pipelining optimizations that were used in FA3 and FA4. I also felt the write-up was already pushing the limits of a single post's length, so I decided to "ship it" as a first part.
I would love to dive into TMA for Part 2. If I can get my hands on an H100 (or even an A100), that would be hugely appreciated on my end! If you have any leads on hardware access, please let me know—I’d love to finish the story!