Standard exact cross-entropy instantly OOMs on a 16GB GPU at that vocabulary size.
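To make the blow-up concrete, here is the rough arithmetic for the logit tensor alone at a 262,144-token vocabulary (the batch size and sequence length are hypothetical numbers picked for illustration, not from the repo's benchmarks):

```python
# Back-of-the-envelope memory for the full logit matrix in fp32.
# batch and seq_len are hypothetical; vocab matches the 262k vocabulary.
batch, seq_len, vocab = 8, 4096, 262_144
logits_bytes = batch * seq_len * vocab * 4   # 4 bytes per fp32 element
print(logits_bytes / 2**30)                  # → 32.0 (GiB), double a 16GB card
```

And that is just the forward logits, before softmax intermediates and gradient buffers, so exact cross-entropy cannot fit.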
To get around this, I implemented MAXIS Loss. Rather than materializing the full 262k-wide logit matrix, it uses a "Ghost Logit" to mathematically simulate the probability mass of the tokens left out of the sample.
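For readers unfamiliar with the idea, here is a minimal sketch of the general technique: a sampled cross-entropy where one extra "ghost" logit stands in for the unsampled vocabulary mass. This is an illustrative importance-sampling version, not the repo's actual MAXIS implementation; the function name, the uniform negative sampling, and the `(V - 1) / num_neg` scaling are all my assumptions.

```python
import math
import torch

def ghost_logit_ce(hidden, weight, target, num_neg=4096, generator=None):
    """Illustrative sampled cross-entropy (NOT the repo's MAXIS Loss).

    Computes logits only for the target token plus `num_neg` uniformly
    sampled negatives, then adds one "ghost" logit estimating
    log sum_{j not sampled} exp(z_j), so the (N, V) matrix is never built.
    """
    N, d = hidden.shape
    V = weight.shape[0]
    # Uniform negatives, sampled with replacement for simplicity.
    neg = torch.randint(0, V, (num_neg,), generator=generator,
                        device=hidden.device)
    pos_logit = (hidden * weight[target]).sum(-1, keepdim=True)   # (N, 1)
    neg_logits = hidden @ weight[neg].T                           # (N, num_neg)
    # Ghost logit: scale the sampled negatives' mass up to the full
    # vocabulary (importance-sampling estimate of the missing mass).
    ghost = torch.logsumexp(neg_logits, -1, keepdim=True) \
        + math.log((V - 1) / num_neg)                             # (N, 1)
    z = torch.cat([pos_logit, ghost], dim=-1)                     # (N, 2)
    # Cross-entropy with the target fixed in column 0.
    return (torch.logsumexp(z, -1) - z[:, 0]).mean()
```

Peak activation memory in the loss is then O(N x num_neg) instead of O(N x V), which is where the VRAM savings come from.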
Benchmarks on a 16GB VRAM card (T4):

- 17.5x faster in the loss layer than the Triton-optimized Liger Kernel.
- ~39% VRAM reduction in the objective calculation.

The repo also includes RandNLA Attention, which uses Causal Kronecker Sketching to keep attention memory flat as sequence length grows.
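To illustrate the memory-flat idea behind sketched attention in general, here is a toy randomized-sketch attention. It is deliberately simplified and is NOT the repo's RandNLA Attention: it uses a dense Gaussian sketch rather than a Kronecker-structured one, and it ignores causal masking, which the repo's Causal Kronecker Sketching is specifically designed to preserve.

```python
import torch

def sketched_attention(q, k, v, m=64, generator=None):
    """Toy sketched attention (NOT the repo's RandNLA Attention).

    Compresses keys and values along the sequence axis with a random
    sketch matrix, so the score matrix is (n, m) instead of (n, n) and
    memory stays flat in m as the sequence length n grows. Non-causal
    and unstructured Gaussian sketch, purely for illustration.
    """
    n, d = k.shape
    # Random sketch matrix, scaled so the projection is roughly
    # norm-preserving in expectation.
    S = torch.randn(m, n, generator=generator, device=q.device) / m ** 0.5
    k_s, v_s = S @ k, S @ v              # compressed keys/values: (m, d)
    scores = q @ k_s.T / d ** 0.5        # (n, m), never (n, n)
    return torch.softmax(scores, -1) @ v_s
```

With a fixed sketch size m, the score matrix no longer grows quadratically with sequence length, which is the property the benchmark above is exercising; the hard part (and the repo's contribution) is doing this without breaking causality.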
I’ve included technical reports with the formal math in the repository. I would love any technical feedback on the partition function simulation or the sketching approach.