What to Learn? CUDA vs. PyTorch vs. Jax vs. Triton/Pallas
1•bananc•2mo ago
I am a student getting into ML and for some unknown reason chose to learn JAX instead of PyTorch. I've gotten to a point where I need custom kernels, and so my only option is Pallas, because I can't use Triton in JAX and have never used CUDA/C++. Did I make a wrong decision, and should I switch to learning PyTorch + Triton, or even CUDA? Also, what % of ML labs do you think write custom kernels in Triton at all?
Comments
t-vi•2mo ago
- I don't think it hurts to learn PyTorch (and having learned JAX is good, too). I don't know if JAX + Triton is as impossible as you make it out to be, but PyTorch's integration does seem quite good for many things.
- For Pallas, Triton, and CUDA/C++, you probably want to know a bit about how GPUs work. There is the GPU-Mode discord / lectures / resources if you are looking for material https://github.com/gpu-mode/ .
- In my experience, how well Triton works varies with how well the programming model fits the task. If it does fit, it is quite nice for getting something reasonably fast, reasonably fast. PyTorch (in the inductor torch.compile backend) has made many things work well, so you could check that out if you run out of examples elsewhere.
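To make the Pallas option concrete: here is a minimal elementwise-add kernel sketch (assumptions: `interpret=True` runs the kernel on CPU so you can experiment without a TPU/GPU; a real kernel would drop that flag and add a grid/block spec for tiling).

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Read the whole input blocks, add them, write the output block.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # CPU interpreter mode; remove to compile for TPU/GPU
    )(x, y)

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)
print(add(x, y))
```

The nice part is that the kernel body is just jax.numpy-style indexing on refs, so the jump from ordinary JAX code is smaller than the jump to CUDA/C++.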