What to Learn? CUDA vs. PyTorch vs. JAX vs. Triton/Pallas
1•bananc•2h ago
I'm a student getting into ML, and for no particular reason I chose to learn JAX instead of PyTorch. I've now reached the point where I need custom kernels. My only option is Pallas, because I can't use Triton from JAX and I've never used CUDA/C++. Did I make the wrong decision? Should I switch to learning PyTorch + Triton, or even CUDA? Also, what percentage of ML labs do you think write custom kernels in Triton at all?