I've been brainstorming ideas recently, and one paper that caught my attention was Yann LeCun's LeJEPA paper. It claims to solve a host of problems with joint-embedding model training, and it had me thinking...
What if you simply replaced the discrete tokenizer used by LLMs with joint embeddings, and made your autoregressive language model a "predict the next latent embedding" model?
For example:
- Write some software to convert text to images where every 8x8 block (or maybe 16x16?) contains a character or whitespace. This could incorporate augmentations like jitter and font changes (see the rendering sketch after this list).
- Train a LeJEPA ViT model on these generated text "images" using SSL, so it produces embeddings for them.
- Freeze the LeJEPA-trained ViT embedding model and use it as a frozen embedding layer for an autoregressive transformer-based model that "predicts the next embedding".
- With the embedding model and the autoregressive latent predictor frozen, train a decoder that translates embeddings back into discrete tokenized text (a rough wiring sketch follows the list).
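
To make the rendering step concrete, here's a minimal sketch of one way to turn text into fixed-grid "images" using Pillow. Everything here is an assumption for illustration: the `render_text_image` function, the 16x16 `CELL` size, the grid dimensions, and the default font are stand-ins for whatever you'd actually use, and the jitter/font augmentations are left out.

```python
# Minimal sketch: render text into a grid image, one character per CELL x CELL block.
from PIL import Image, ImageDraw, ImageFont

CELL = 16            # pixels per character block (8 would also work)
COLS, ROWS = 64, 32  # characters per row / rows per image

def render_text_image(text: str) -> Image.Image:
    """Render text into a grayscale image with one character per block."""
    img = Image.new("L", (COLS * CELL, ROWS * CELL), color=255)
    draw = ImageDraw.Draw(img)
    font = ImageFont.load_default()  # swap in a real monospaced TTF in practice
    for i, ch in enumerate(text[: COLS * ROWS]):
        row, col = divmod(i, COLS)
        draw.text((col * CELL, row * CELL), ch, fill=0, font=font)
    return img

# Usage: produces one training "image" for the SSL/LeJEPA stage.
sample = render_text_image("What if the tokenizer were an image renderer?")
sample.save("sample.png")
```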
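
And here's a rough PyTorch sketch of how the three stages might be wired together: a frozen embedder standing in for the LeJEPA-trained ViT, an autoregressive latent predictor, and a latent-to-text decoder trained last. All module names, layer choices, and dimensions (`D_LATENT`, `VOCAB`) are placeholders of my own, not anything prescribed by the LeJEPA paper.

```python
# Rough sketch of the three-stage pipeline: frozen embedder -> latent predictor -> text decoder.
import torch
import torch.nn as nn

D_LATENT, VOCAB = 768, 50_000

class FrozenViTEmbedder(nn.Module):
    """Stands in for the LeJEPA-trained ViT; its parameters are frozen."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16 * 16, D_LATENT)  # placeholder for a real ViT
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, image_patches):          # (B, N, 256) flattened blocks
        return self.backbone(image_patches)    # (B, N, D_LATENT)

class LatentPredictor(nn.Module):
    """Autoregressive transformer that predicts the next latent embedding."""
    def __init__(self):
        super().__init__()
        layer = nn.TransformerEncoderLayer(D_LATENT, nhead=8, batch_first=True)
        self.core = nn.TransformerEncoder(layer, num_layers=4)
        self.head = nn.Linear(D_LATENT, D_LATENT)

    def forward(self, latents):                # (B, T, D_LATENT)
        mask = nn.Transformer.generate_square_subsequent_mask(latents.size(1))
        return self.head(self.core(latents, mask=mask))  # next-latent predictions

class LatentToTextDecoder(nn.Module):
    """Trained last, with the other two stages frozen: latent -> token logits."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(D_LATENT, VOCAB)

    def forward(self, latents):                # (B, T, D_LATENT)
        return self.proj(latents)              # (B, T, VOCAB) logits

# Usage: embed rendered text images, predict next latents, decode to tokens.
patches = torch.rand(2, 10, 16 * 16)
latents = FrozenViTEmbedder()(patches)
predicted = LatentPredictor()(latents)
logits = LatentToTextDecoder()(predicted)
```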
I can see the following benefits:
- No discrete tokenizer for input
- The autoregressive latent predictor outputs full image-scale concepts rather than individual discrete tokens, and it can run asynchronously and much faster than the embedding -> discrete-text decoder (a toy sketch of this decoupling follows the list)
- Cohesive multimodality built in... text-free images are still images that can be mapped to latents, perhaps with fine-tuning on pure-image datasets.
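
To illustrate that decoupling, here's a toy producer/consumer sketch: the latent predictor rolls out latents on one thread while the latent-to-text decoder drains them on another. It assumes the placeholder `LatentPredictor`, `LatentToTextDecoder`, and `D_LATENT` from the earlier sketch are in scope; real asynchronous serving would look very different, this just shows the two stages don't have to run in lockstep.

```python
# Toy sketch: latent rollout (producer) and text decoding (consumer) run concurrently.
import queue, threading
import torch

def rollout_latents(predictor, seed_latents, steps, out_q):
    """Producer: autoregressively extend the latent sequence."""
    seq = seed_latents                       # (1, T, D_LATENT)
    with torch.no_grad():
        for _ in range(steps):
            nxt = predictor(seq)[:, -1:, :]  # predicted next latent
            seq = torch.cat([seq, nxt], dim=1)
            out_q.put(nxt)
    out_q.put(None)                          # sentinel: rollout finished

def decode_latents(decoder, in_q):
    """Consumer: turn latents into token ids as they arrive."""
    with torch.no_grad():
        while (latent := in_q.get()) is not None:
            tokens = decoder(latent).argmax(dim=-1)
            print(tokens.tolist())

q = queue.Queue()
predictor, decoder = LatentPredictor(), LatentToTextDecoder()
seed = torch.zeros(1, 1, D_LATENT)
threading.Thread(target=rollout_latents, args=(predictor, seed, 8, q)).start()
decode_latents(decoder, q)
```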
In my mind this would be more akin to how humans think: we have far better image recall than text-sequence recall, and we think abstractly before speaking or typing language.