How does the target model validate the draft tokens without running the inference as normal?
Because if it is doing just that, I don't get the point as you can't trust the draft tokens before they are validated, so you're still stuck waiting for the target model.
So your draft model can decode N new tokens, then the real model does one inference pass to score the N new drafted tokens.
Prefill is computation bound whereas decode is bandwidth bound, so in practice doing one prefill over N tokens is cheaper than doing N decode passes.
Say the model so far has "The capital of France". The small model generates "is Paris.", which let's say is 5 tokens.
You feed the large model "The capital of France is Paris." to validate all 5 of those tokens in a single forward pass.
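To make that concrete, here is a rough sketch of how the target can score all the drafted tokens in a single forward pass. It is not any particular library's API: `target_model`, `prompt_ids`, and `draft_ids` are placeholders, and `target_model` is assumed to return next-token logits for every input position. It also uses greedy acceptance for simplicity; real speculative decoding uses a rejection-sampling rule so the output distribution matches the target model exactly.

    # Sketch only, assuming a causal LM `target_model` that returns logits of
    # shape [batch, seq_len, vocab]. Greedy acceptance for simplicity.
    import torch

    def verify_drafts(target_model, prompt_ids, draft_ids):
        # One forward pass over prompt + all drafted tokens, like a prefill.
        input_ids = torch.cat([prompt_ids, draft_ids]).unsqueeze(0)  # [1, T]
        logits = target_model(input_ids)[0]                          # [T, vocab]

        accepted = []
        start = prompt_ids.shape[0] - 1   # logits[start + i] predict draft_ids[i]
        for i, draft_tok in enumerate(draft_ids):
            target_tok = logits[start + i].argmax()
            if target_tok == draft_tok:
                accepted.append(draft_tok)    # target agrees, keep the draft token
            else:
                accepted.append(target_tok)   # disagreement: take the target's token
                break                         # everything drafted after this is invalid
        return accepted

Either way you only pay for one pass of the large model per round of drafting, and even on a mismatch that pass still gives you the target's own next token, so the work is not wasted.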
Also, if the small model were sufficiently more often "correct" than "wrong", wouldn't it be more efficient to get rid of the large model at that point?
It is about keeping the large model's quality while getting the small model's speed most of the time. The tradeoff is that you consume more memory from having two models loaded instead of just one of them.
If you only cared about speed, then it would make sense to reduce memory usage and just run the smaller model.
Unsurprisingly, gpt-oss has both larger and smaller models that behave very similarly! Because the two are so alike, even getting a few draft tokens wrong doesn't slow things down to the point of merely matching the speed of the larger model alone (which is the worst case with this setup). We want the speed of the smaller model as much of the time as possible. That is all.
Or is this a scenario where computation is expensive but validation is cheap?
Running f2(f1(x)) serially with the big model takes 2 seconds, assuming 1 second for every pass.
What I instead do is kick off f1(x) in another thread, and then run f2(g1(x)) where g1 is one pass through GPT-nano.
This takes 0.1 + 1 seconds, assuming GPT-nano takes 0.1s for every pass. Within this 1.1 seconds, the f1(x) that we kicked off in the 2nd thread has already finished (it takes 1 second).
So in 1.1 seconds we have f1(x) and f2(g1(x)) available to us, and we store the intermediate g1(x) as well.
We compare g1(x) and f1(x).
If they are equal, i.e. g1(x) = f1(x), then we have our answer f2(g1(x)) in just 1.1s.
If they are not, we compute f2 of the f1(x) output from the 2nd thread, which takes one further second, bringing our total to 2.1s.
If the small model matches the big model in, say, 2/3 of cases, you will spend 2/3 * 1.1 + 1/3 * 2.1 ≈ 1.43s on average for this computation. Without speculative decoding, it is always 2s.
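Here's a toy simulation of that two-thread trick, with time.sleep standing in for the compute: 1s per big-model pass (f1, f2) and 0.1s for the GPT-nano pass (g1). The functions themselves are made up just so there is something to compute.

    import time
    from concurrent.futures import ThreadPoolExecutor

    def f1(x): time.sleep(1.0); return x + 1   # big model, first pass
    def f2(x): time.sleep(1.0); return x * 2   # big model, second pass
    def g1(x): time.sleep(0.1); return x + 1   # GPT-nano's guess at f1(x)

    def speculative(x):
        with ThreadPoolExecutor(max_workers=1) as pool:
            f1_future = pool.submit(f1, x)   # kick off f1(x) on another thread
            guess = g1(x)                    # 0.1s
            spec = f2(guess)                 # 1.0s, overlaps with f1(x)
            real_f1 = f1_future.result()     # already finished by now
        if real_f1 == guess:
            return spec                      # ~1.1s total
        return f2(real_f1)                   # mispredicted: ~2.1s total

    t0 = time.time()
    print(speculative(3), f"in {time.time() - t0:.1f}s")   # g1 guesses right here, ~1.1s

In real speculative decoding the "guess" is a drafted token and the equality check is against the target model's own prediction, but the latency arithmetic is exactly the one above.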
Now I see the obvious thing they were pointing out, which is to predict multiple tokens ahead, not just two as in your example.
It does run the inference as normal, just in parallel with the other inferences
> if it is doing just that, I don't get the point
Running inferences in parallel allows you to read the model weights out of memory only once for N parallel inferences, as opposed to reading them out of memory N times for N serial inferences. Inference is massively bottlenecked by memory bandwidth, to the tune of one or two orders of magnitude compared to compute, so this helps a lot.
Nitpick: it's only bottlenecked by memory bandwidth if the batch size is too low (that is: if you don't have many users calling the same model in parallel).
Speculative decoding is just a way of running a single query as if it was parallel queries.
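A back-of-the-envelope model of that bandwidth argument (all numbers below are made up for illustration, not measurements): each pass costs roughly the max of streaming the weights once and doing the matmuls, so scoring N draft tokens in one pass is about N times cheaper than N serial decode steps as long as you stay bandwidth-bound.

    # Illustrative numbers only: ~7B model at fp16, 1 TB/s of memory bandwidth,
    # 100 TFLOP/s of compute, ~2 * params FLOPs per token.
    WEIGHT_BYTES  = 14e9
    BANDWIDTH     = 1.0e12
    FLOPS         = 100e12
    FLOP_PER_TOK  = 2 * 7e9

    def pass_time(tokens_in_pass):
        mem_time  = WEIGHT_BYTES / BANDWIDTH                # weights streamed once per pass
        comp_time = tokens_in_pass * FLOP_PER_TOK / FLOPS   # grows with tokens per pass
        return max(mem_time, comp_time)

    n = 5
    serial  = n * pass_time(1)   # N decode steps: weights streamed N times (~70 ms here)
    batched = pass_time(n)       # one verify pass over N drafts (~14 ms here)
    print(f"serial: {serial*1e3:.0f} ms, batched verify: {batched*1e3:.0f} ms")

It also shows the nitpick above: once tokens_in_pass (from genuinely batched users) gets large enough that the compute term dominates, the free lunch from speculation shrinks.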
For home use, Gemma27B QAT is king. It's almost as good as Deepseek R1.