I’ve recently been digging into the PyTorch FSDP implementation. It’s powerful and highly optimized, which naturally means the codebase is large and not always easy to navigate. Along the way, I wrote a minimal implementation based on what I found, mainly to show FSDP's different states and its pre-/post-forward and pre-/post-backward steps, all in a single place!
xnan•2h ago
Hope this helps others!