– Core Mamba2 block with LM (Mamba2ForCausalLM) and time-series (Mamba2Forecaster) heads (see the usage sketch after this list)
– Pure JAX/Flax (no Triton/custom CUDA); runs on CPU / CUDA / TPU via standard JAX backends
– Small CPU-only parity test vs mamba2-torch: similar loss curves, final MSE diff ≈ 0.012, prediction correlation ≈ 0.99; after JIT warmup, JAX was ≈ 2× faster per step
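Here's a minimal usage sketch. Mamba2ForCausalLM is the real class name, but the constructor arguments and shapes are placeholders, and I'm assuming the standard Flax init/apply pattern rather than quoting the package's exact interface:

```python
# Illustrative sketch: constructor args and the import path are placeholders,
# assuming a standard Flax Module interface.
import jax
import jax.numpy as jnp
from mamba2_jax import Mamba2ForCausalLM  # import path assumed

model = Mamba2ForCausalLM(vocab_size=32000, d_model=512, n_layers=8)  # placeholder config

rng = jax.random.PRNGKey(0)
tokens = jnp.ones((1, 128), dtype=jnp.int32)   # (batch, seq_len)

params = model.init(rng, tokens)               # standard Flax init
logits = jax.jit(model.apply)(params, tokens)  # first call compiles, later calls are fast
print(logits.shape)                            # expect (1, 128, vocab_size)
```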
I’d really appreciate feedback on:
– API design, especially for streaming/stateful inference (one possible shape is sketched below)
– Performance gotchas you hit if you try it
– Any hooks you’d want exposed for research use
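To make the streaming question concrete, here's a strawman of the kind of carry-state decode loop I mean: an SSM generates token-by-token by threading a small recurrent state through each step instead of re-running the whole prefix. The `step_fn` signature is hypothetical, not the current package API, and `dummy_step` just stands in so the loop runs end to end:

```python
# Strawman streaming-decode loop; step_fn(params, state, token) -> (state, logits)
# is a hypothetical single-step interface, not mamba2-jax's actual API.
import jax
import jax.numpy as jnp

def dummy_step(params, state, token):
    state = state + params["w"][token]  # fake per-token state update
    logits = state @ params["out"]      # fake state -> logits head
    return state, logits

def generate(step_fn, params, state, last_token, n_new):
    """Greedy streaming decode: carry (state, token) through lax.scan."""
    def body(carry, _):
        state, token = carry
        state, logits = step_fn(params, state, token)
        next_token = jnp.argmax(logits, axis=-1)
        return (state, next_token), next_token
    _, out = jax.lax.scan(body, (state, last_token), None, length=n_new)
    return out

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {"w": jax.random.normal(k1, (100, 16)),
          "out": jax.random.normal(k2, (16, 100))}
tokens = generate(dummy_step, params, jnp.zeros(16), jnp.int32(1), n_new=8)
print(tokens)  # eight greedily decoded (dummy) token ids
```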
PyPI: https://pypi.org/project/mamba2-jax/
Thanks, Cosmo