When the Mamba-3 paper dropped (currently under ICLR review), I wanted to understand how the new math actually worked. But official implementations of these architectures usually rely on heavily optimized custom Triton/CUDA kernels: incredibly fast for training, but almost impossible to read if you just want to follow the matrix math.
I spent the last few days reverse-engineering the paper to build mamba3-minimal: a pure-PyTorch, single-file implementation that runs natively on Mac (MPS), CPU, and CUDA.
It implements the three core innovations of the paper without any C++:
1. Trapezoidal Discretization & The "Shift" Hack: The new discretization rule introduces a strict sequential dependency across chunk boundaries (the beta term), which breaks standard PyTorch chunking. I solved this by shifting the sequences once at the global level before passing them into the chunked State Space Duality (SSD) algorithm.
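The idea behind the shift can be sketched in a few lines. This is a simplified illustration, not the repo's code: `Bx` stands for the per-step contributions `B_t * x_t`, and `beta`/`gamma` are hypothetical names for the trapezoid weights on the previous and current step. Shifting the whole sequence right by one position turns the cross-step term into an ordinary per-step input, which any chunked scan can then consume:

```python
import torch
import torch.nn.functional as F

def fold_trapezoid_inputs(Bx, beta, gamma):
    # Bx:          (batch, seq, d) per-step contributions B_t * x_t
    # beta, gamma: (batch, seq, 1) trapezoid weights (hypothetical names)
    # Shift Bx right by one step so position t sees B_{t-1} x_{t-1};
    # step 0 has no predecessor and is zero-padded.
    Bx_prev = F.pad(Bx, (0, 0, 1, 0))[:, :-1]
    # u_t = gamma_t * B_t x_t + beta_t * B_{t-1} x_{t-1} now depends only
    # on precomputed tensors at index t, so a chunked SSD pass over u no
    # longer needs to peek across chunk boundaries.
    return gamma * Bx + beta * Bx_prev
```

The key point is that the shift happens once, globally, before chunking, so every chunk sees a self-contained input tensor.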
2. Complex-Valued SSMs (Data-Dependent RoPE): Mamba-2 famously failed at state-tracking (scoring ~50%, i.e. chance, on parity tasks). This repo includes a test script showing the RoPE fix works: 100% accuracy on parity, extrapolating to length 64.
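To see why rotations fix parity, here is a toy sketch (not the repo's test script): treat the real state as complex pairs and rotate each pair by a per-step angle, exactly as RoPE does. With a rotation of pi for every 1-bit and 0 for every 0-bit, the state's sign tracks the running parity at any length; the actual fix makes the angle a learned, input-dependent quantity:

```python
import torch
from math import pi

def apply_state_rotation(h, theta):
    # h:     (..., 2k) real state, viewed as k complex pairs h1 + i*h2
    # theta: (..., k) per-step rotation angles (the "RoPE" part)
    h1, h2 = h[..., 0::2], h[..., 1::2]
    cos, sin = theta.cos(), theta.sin()
    out = torch.empty_like(h)
    out[..., 0::2] = h1 * cos - h2 * sin  # Re of e^{i*theta} * (h1 + i*h2)
    out[..., 1::2] = h1 * sin + h2 * cos  # Im
    return out

# Toy parity run: three 1-bits rotate the state by 3*pi in total, so the
# first component ends at cos(3*pi) = -1, i.e. "odd".
h = torch.tensor([1.0, 0.0])
for bit in [1, 0, 1, 1]:
    h = apply_state_rotation(h, torch.tensor([pi * bit]))
```

A purely real, positive decay can only shrink the state monotonically; the rotation gives it a sign (a complex eigenvalue), which is what parity needs.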
3. MIMO (Multi-Input Multi-Output): Standard decoding is memory-bound. I implemented the rank-expansion formulation from Appendix D, which shifts the state update from a memory-bound outer product to a compute-bound matrix multiplication, all through clean einsum operations.
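The rank-expansion point can be made concrete with two einsums. Sizes below are hypothetical (`n` = state dim, `p` = head dim, `r` = MIMO rank); this is a sketch of the shape change, not the repo's kernel:

```python
import torch

n, p, r = 16, 8, 4
B = torch.randn(n, r)   # rank-expanded input projection
X = torch.randn(r, p)   # r input channels for one timestep

# Rank-1 (Mamba-2 style) update: an outer product, roughly one
# multiply-add per element loaded, hence memory-bound at decode time.
h_rank1 = torch.einsum('n,p->np', B[:, 0], X[0])

# Rank-r MIMO update: a matmul, r multiply-adds per state entry, so the
# same state traffic amortizes r times more compute.
h_mimo = torch.einsum('nr,rp->np', B, X)
```

Both produce an `(n, p)` state update; only the arithmetic intensity per byte of state moved changes.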
The repo includes self-tests verifying that the O(T) chunked training pass produces the same logits as the O(1)-per-token sequential autoregressive step (max_diff < 1e-6).
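The flavor of that equivalence check is easy to show on a scalar linear recurrence h_t = a_t h_{t-1} + u_t (a stand-in for the real model, not the repo's test): run it token by token, then run it chunk by chunk with a closed-form scan inside each chunk, and compare:

```python
import torch

def sequential_scan(a, u):
    # Decode path: h_t = a_t * h_{t-1} + u_t, one token at a time.
    h, out = torch.zeros(()), []
    for a_t, u_t in zip(a, u):
        h = a_t * h + u_t
        out.append(h)
    return torch.stack(out)

def chunked_scan(a, u, chunk=4):
    # Training path: closed form inside each chunk, state carried across
    # chunk boundaries -- the structure the SSD algorithm exploits.
    h, outs = torch.zeros(()), []
    for s in range(0, len(a), chunk):
        ac, uc = a[s:s + chunk], u[s:s + chunk]
        P = torch.cumprod(ac, dim=0)              # P_t = a_s * ... * a_t
        L = torch.tril(P[:, None] / P[None, :])   # L[t, j] = a_{j+1}...a_t
        outs.append(L @ uc + P * h)               # in-chunk sums + carried state
        h = outs[-1][-1]
    return torch.cat(outs)
```

With decays kept in (0.5, 0.95) the two paths agree to float32 precision, which is the same max-diff style of check the repo's self-tests run on full logits.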
If you've been wanting to read the math behind Mamba-3, I built this to be a readable Rosetta Stone. Would love any feedback on the implementation or the PyTorch optimizations!