Gemma-4 spec decoding on Apple Silicon
Qwen ships an MTP head inside every Qwen3.5 and Qwen3.6 checkpoint. OptIQ uses it to get 1.20x to 1.40x decode speedups (see last week's MTP post). Gemma-4 does not ship an MTP head. Google instead publishes a separate small model alongside the target, with the suffix -assistant, that you load as a draft model.
The first time we tried to wire it up we got nowhere and parked it. This time we kept going. The end state is a generic spec runtime in optiq/runtime/spec/, a working -assistant drafter for Gemma-4 E4B, and a "Spec drafter" picker in the OptIQ Lab Server page that flips it on. Headline numbers at greedy γ=1: 1.18x decode geomean across five prompt categories, at 31% acceptance.
This is the writeup of how we got there, including the two RMSNorm-shaped bugs that took us from 0% to 33% acceptance and the bf16 numeric artifact we still cannot route around.
Why the first attempt did not work
When we first looked at Gemma spec decoding, we tried to make it fit the same shape as Qwen MTP. That meant treating the drafter as a head we could attach to the target and calling mlx-lm's built-in speculative path. Neither premise survived contact with the model architecture.
The Gemma -assistant drafter is not a head bolted onto the target. It is a 4-layer Q-only transformer. It has its own attention layers, its own MLP layers, its own RMSNorms, its own RoPE, its own output projection. The thing that makes it Q-only is that those attention layers compute only the Q projection. K and V come from the target's KV cache, two specific donor layers, one sliding-window and one full-attention, picked by Google during training. The drafter is an EAGLE-family model in spirit, not an MTP one.
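To make "Q-only" concrete, here is a minimal single-head sketch in plain Python. It is not the drafter's actual code: the function names, the list-of-lists layout, and the textbook 1/sqrt(d) scaling are all assumptions for illustration. The point it shows is that the drafter owns only the query projection, while the K and V rows are read from the target's cache.

```python
import math


def matvec(matrix, vec):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vec)) for row in matrix]


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def q_only_attention(x, w_q, donor_keys, donor_values):
    """Single-head attention where the drafter computes only Q.

    x:            drafter hidden state for the current position
    w_q:          the drafter's own query projection, shape [d_head][d_model]
    donor_keys:   K rows borrowed from a target donor layer's cache
    donor_values: V rows borrowed from the same donor layer's cache
    """
    q = matvec(w_q, x)
    scale = 1.0 / math.sqrt(len(q))  # textbook scaling; an assumption here
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in donor_keys]
    weights = softmax(scores)
    d_head = len(donor_values[0])
    return [
        sum(w * v[j] for w, v in zip(weights, donor_values))
        for j in range(d_head)
    ]
```

Note that there is no K or V projection anywhere in the function; that is the whole difference from an ordinary attention layer.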
This means three things had to be built fresh:
- A loader for the drafter weights that knows about the centroid-clustering output head, the per-block RoPE variants, and the Gemma-specific normalization.
- A KV viewer for the target that can produce typed K and V tensors in chronological order, accounting for the fact that one donor uses a RotatingKVCache (ring buffer, sliding window) and the other uses a plain KVCache (append-only, with stale tail).
- A spec loop that knows to feed the drafter the last emitted token's embedding scaled by sqrt(hidden_size), plus the target's last hidden state, plus the shared KV.
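The KV viewer's two cases can be sketched with plain lists. This is a simplified illustration, not the OptIQ or mlx-lm implementation: the helper names are hypothetical, entries stand in for per-token K or V rows, and the ring buffer is assumed to have no pinned "keep" prefix.

```python
def ring_to_chronological(buffer, tokens_written):
    """Return the entries of a ring buffer oldest-first.

    Assumes slot i holds the entry written at step t where t % len(buffer) == i,
    and that nothing is pinned at the front of the buffer.
    """
    window = len(buffer)
    if tokens_written <= window:
        # The buffer has not wrapped yet; only the first slots are valid.
        return list(buffer[:tokens_written])
    start = tokens_written % window  # slot holding the oldest surviving entry
    return list(buffer[start:]) + list(buffer[:start])


def trim_stale_tail(buffer, offset):
    """Return only the valid prefix of an append-only, over-allocated cache."""
    return list(buffer[:offset])
```

For example, a window of 4 after 6 writes holds `["t4", "t5", "t2", "t3"]` in storage order, and `ring_to_chronological` returns `["t2", "t3", "t4", "t5"]`.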
mlx-lm does not provide any of these out of the box. We checked. There is an open issue tracking general drafter support but no implementation. Ollama merged a Gemma -assistant spec PR in February; their code is MIT-licensed and we used it as the algorithmic reference.
Building the drafter
The drafter weights load from the HuggingFace repo mlx-community/gemma-4-E4B-it-assistant-bf16 as a plain safetensors file. The model class lives in optiq/runtime/spec/drafters/gemma_assistant.py and reproduces the structure exactly: four blocks, three with sliding-window attention, one with full attention; an embedding-projection head that maps the target hidden into the drafter's smaller hidden; a centroid output head that splits the vocab into 2048 centroids of 128 tokens each and emits over top-K=32 centroids.
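As a rough picture of what a centroid head does, the sketch below scores centroids first and then scores only the tokens that belong to the top-K centroids. Everything about it is an assumption for illustration: the function names, dot-product scoring, and especially the contiguous token-to-centroid assignment (token t belongs to centroid t // cluster_size); the real checkpoint may store an explicit mapping.

```python
N_CENTROIDS = 2048
CLUSTER_SIZE = 128
TOP_K = 32

# Sizes implied by the numbers above.
VOCAB_SIZE = N_CENTROIDS * CLUSTER_SIZE    # 262144 tokens in total
CANDIDATES_PER_STEP = TOP_K * CLUSTER_SIZE  # 4096 tokens actually scored


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def centroid_head_argmax(hidden, centroid_vecs, token_vecs, cluster_size, top_k):
    """Pick the best token among the members of the top_k centroids.

    Assumes token t belongs to centroid t // cluster_size (hypothetical layout).
    """
    centroid_scores = [dot(hidden, c) for c in centroid_vecs]
    ranked = sorted(
        range(len(centroid_vecs)),
        key=centroid_scores.__getitem__,
        reverse=True,
    )
    candidates = [
        t
        for c in ranked[:top_k]
        for t in range(c * cluster_size, (c + 1) * cluster_size)
    ]
    return max(candidates, key=lambda t: dot(hidden, token_vecs[t]))
```

The trade-off is visible in the structure: a token whose centroid misses the top-K can never be emitted, in exchange for scoring 4096 candidates per step instead of the full vocabulary.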
The first version loaded the weights and emitted tokens. Acceptance was 0%. Not "low", actually zero. Across thousands of drafted tokens, none matched the target's argmax.
Bug one: RMSNorm formula
Gemma 1, Gemma 2, and Gemma 3 use the "scale plus one" RMSNorm convention. The norm operator is x_normed * (1 + weight), where weight is a learned per-channel parameter that drifts from zero. This is what mlx-lm's older Gemma classes do.
Gemma 4 changes this. The weight is stored as-is, no plus-one. The norm operator is x_normed * weight, full stop. We checked mlx-lm's gemma4_text.py after our forward passes were producing nonsense activations. The Gemma 4 weights are stored at typical scale, ranging from about 0.5 to 64, not the near-zero values that the plus-one convention assumes. Applying plus-one would multiply by 1.5 to 65 in places where the model expects 0.5 to 64. We instrumented the per-layer norm magnitudes and watched activations explode from 512 to 2181 across the four blocks.
Replacing the formula with mx.fast.rms_norm(x, self.weight, self.eps) directly took us from 0% acceptance to 3%. Not great, but no longer zero, which meant the rest of the math was at least plausibly right.
Bug two: layer scalar formula
Gemma 4 has a per-block scalar applied as a residual gate. The block returns residual + h * layer_scalar. The drafter's checkpoint stores layer_scalar as a single learned scalar per block.
We had it as h * (1 + layer_scalar), by analogy with the plus-one RMSNorm convention we just fixed. We were wrong twice in a row. The actual operation in mlx-lm's gemma4_text.py:386-387 is plain h * scalar.
Fixing this took us from 3% acceptance to 33%. Drafter output started looking like real Gemma output. Speedup landed at 1.21x on the math prompt.
Bug three (still open): bf16 precision drift in multi-token verify
Spec decoding's correctness check is byte-identity. With the drafter enabled, the output of the spec loop has to be identical, token for token, to a normal greedy run without it. We benchmarked five prompts and report the prefix length over which our spec output matched the baseline output:
| Prompt | Speedup | Acceptance | Prefix match |
|---|---|---|---|
| math | 1.29x | 37.5% | 26 / 200 |
| code | 1.25x | 34.0% | 45 / 200 |
| prose | 1.18x | 30.3% | 24 / 200 |
| dialogue | 1.11x | 29.5% | 69 / 200 |
| reasoning | 1.06x | 25.5% | 200 / 200 |
| Geomean | 1.18x | 31.4% | |
The reasoning prompt matched all 200 tokens. The others drifted partway through. We sat on this for a while because if our spec loop was correct, the matching prefix should be 200 for every prompt, not just one.
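The "prefix match" column is simply the length of the longest common prefix between the baseline token ids and the spec token ids. A minimal sketch, with a hypothetical helper name:

```python
from itertools import takewhile


def matching_prefix_len(baseline_ids, spec_ids):
    """Count leading positions where the two token sequences agree."""
    pairs = zip(baseline_ids, spec_ids)
    return sum(1 for _ in takewhile(lambda pair: pair[0] == pair[1], pairs))
```

A lossless greedy spec loop would return the full generation length for every prompt; anything shorter marks the first position where the two runs chose different tokens.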
The root cause is in mlx-lm itself. When you feed two tokens to a Gemma 4 forward pass at once and look at the logits for the second position, you do not get exactly the same logits you would have gotten by feeding the two tokens one at a time. We instrumented this with a short probe:
    import mlx.core as mx
    from mlx_lm import load
    from mlx_lm.models.cache import make_prompt_cache

    model, tok = load("mlx-community/gemma-4-E4B-it-4bit")
    prompt_ids = mx.array(tok.encode("Tell me about Apple Silicon."))[None]

    # Sequential: feed A, get logits at A; then feed B, get logits at B.
    cache_s = make_prompt_cache(model)
    _ = model(prompt_ids, cache=cache_s)
    log_A_s = model(mx.array([[100]]), cache=cache_s)
    log_B_s = model(mx.array([[200]]), cache=cache_s)

    # Together: feed [A, B] in one shot.
    cache_t = make_prompt_cache(model)
    _ = model(prompt_ids, cache=cache_t)
    log_AB_t = model(mx.array([[100, 200]]), cache=cache_t)

    print("diff at B:", mx.max(mx.abs(log_AB_t[0, 1] - log_B_s[0, 0])))
The output: diff at B: 0.679688. That is the maximum absolute difference between two logit vectors that should be identical. With top logit values around 51, 0.68 corresponds to about 1.3% relative drift. Most of the time the argmax still matches. Occasionally it does not.
This is bf16 attention precision interacting with how mlx-lm blocks the multi-token forward pass. We have not traced exactly which intermediate causes it (probably the softmax-over-attention numerator, but we did not confirm). What we know is that any greedy spec decoder built on mlx-lm's current Gemma 4 path will inherit this artifact. The reasoning prompt's 200-of-200 match is our evidence that the spec loop itself is algorithmically correct. The other four prompts each hit a position where the verify pass's argmax disagreed with what the sequential baseline would have produced, and from there the two runs diverge.
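Why a drift of this size only occasionally changes the output can be sketched with a margin check. This is illustrative only: the helper names are hypothetical, and 0.68 is the maximum difference seen in one probe, not a proven bound on the error.

```python
def top2_margin(logits):
    """Gap between the largest and second-largest logit."""
    top, second = sorted(logits, reverse=True)[:2]
    return top - second


def argmax_is_stable(logits, max_abs_err):
    """True if no perturbation of at most max_abs_err per logit can flip the argmax.

    The winner can drop by max_abs_err while the runner-up rises by the same
    amount, so the margin has to exceed twice the error.
    """
    return top2_margin(logits) > 2.0 * max_abs_err


observed_drift = 0.68  # from the probe above
relative_drift = observed_drift / 51.0  # about 0.013, i.e. roughly 1.3%
```

Positions where the top two logits sit more than about 1.36 apart are safe under this model; near-ties are where the batched verify and the sequential baseline can disagree.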
This is a real correctness caveat, and users deserve to hear it plainly. Output stays effectively identical for tens of tokens, then can branch. For chat and code generation this is invisible. For exact-reproducibility workflows it is not. We label this "greedy spec with bf16-precision drift", not "lossless".
What we shipped
The runtime lives at optiq/runtime/spec/. It is independent of the existing optiq/runtime/mtp/ path, by design. Qwen MTP is mature and shipped; we did not want to refactor it under a new abstraction just to add Gemma support. The new module hosts the generic outer loop (draft K tokens, target verifies K+1, accept prefix, commit) plus per-architecture drafter adapters; for now the only adapter is GemmaAssistantDrafter. Future adapters can move in alongside.
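The outer loop (draft, verify, accept prefix, commit) can be sketched with toy next-token functions standing in for the models. This is not the optiq/runtime/spec/ code: every name below is hypothetical, the "models" are plain functions of the context, and the verify step calls the target once per position where a real runtime would use one batched forward pass plus cache rollback.

```python
def greedy_generate(target_next, prompt, max_new_tokens):
    """Plain greedy decoding: one target call per emitted token."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(target_next(tokens))
    return tokens[len(prompt):]


def greedy_spec_generate(target_next, drafter_next, prompt, max_new_tokens, gamma=1):
    """Greedy speculative decoding; returns (new_tokens, acceptance_rate)."""
    tokens = list(prompt)
    produced = drafted = accepted = 0
    while produced < max_new_tokens:
        # 1. Draft gamma tokens autoregressively with the cheap model.
        draft = []
        for _ in range(gamma):
            draft.append(drafter_next(tokens + draft))
        # 2. Verify: the target's greedy choice at each of the gamma + 1 positions.
        verified = [target_next(tokens + draft[:i]) for i in range(gamma + 1)]
        # 3. Accept the longest draft prefix that matches the target.
        n_ok = 0
        while n_ok < gamma and draft[n_ok] == verified[n_ok]:
            n_ok += 1
        # 4. Commit the accepted drafts plus one target token
        #    (a correction on mismatch, a bonus token on full acceptance).
        commit = (draft[:n_ok] + [verified[n_ok]])[: max_new_tokens - produced]
        tokens.extend(commit)
        produced += len(commit)
        drafted += gamma
        accepted += n_ok
    rate = accepted / drafted if drafted else 0.0
    return tokens[len(prompt):], rate


def toy_target(context):
    """Deterministic stand-in for the target model's argmax."""
    return (7 * sum(context) + 3) % 11


def toy_drafter(context):
    """Stand-in drafter that agrees with the target except at every third position."""
    guess = toy_target(context)
    return guess if len(context) % 3 else (guess + 1) % 11
```

With exact arithmetic, as in these toy functions, the spec output equals the plain greedy output for any γ; the drafter only affects how many target calls are needed, never which tokens come out.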
From the command line, the drafter routes through OptIQ Serve via the new install_assistant_drafter hook:
    from optiq.serve import install_assistant_drafter

    install_assistant_drafter(
        target_model_path="mlx-community/gemma-4-E4B-it-4bit",
        drafter_id="mlx-community/gemma-4-E4B-it-assistant-bf16",
    )
    # then mlx_lm.server.main(); every /v1/chat/completions request now
    # routes through spec_generate transparently.
From the OptIQ Lab Server page, the "Spec drafter (Gemma-4 family)" picker drops down with the published drafter. Pick the target, pick the drafter, Apply. The supervisor swaps the running model and the chat page picks up the spec path automatically. MTP and the spec drafter are mutually exclusive per loaded model; the UI greys out the alternate when one is selected.
If you want to call it directly without going through Serve, here is the minimal Python path:
    from mlx_lm.utils import load_model, load_tokenizer
    from optiq.runtime.spec import GemmaAssistantDrafter, spec_generate, SpecConfig

    target, _ = load_model("mlx-community/gemma-4-E4B-it-4bit", lazy=False)
    tokenizer = load_tokenizer("mlx-community/gemma-4-E4B-it-4bit")
    drafter = GemmaAssistantDrafter.from_pretrained(
        "mlx-community/gemma-4-E4B-it-assistant-bf16")

    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": "Explain spec decoding."}],
        tokenize=False, add_generation_prompt=True,
    )

    for ev in spec_generate(target, drafter, tokenizer, prompt, SpecConfig(gamma=1)):
        if ev.kind == "token":
            print(ev.text, end="", flush=True)
What 1.18x buys you
The number is a geomean across five prompt categories on M4 Pro 24 GB with the 4-bit OptIQ quant. Math and code get the highest speedups, reasoning gets the lowest. The pattern matches what Qwen MTP shows: structured outputs with predictable continuations (operators, brackets, common code idioms) draft well; flowing prose drifts more.
For a 30 tokens/s baseline (Gemma-4 E4B on M4 Pro), 1.18x lifts decoding to roughly 35 tokens/s. That is the difference between "barely faster than reading speed" and "comfortably faster than reading speed" for interactive chat. It also moves a 200-token completion from about 6.7 seconds to about 5.7 seconds, which compounds over an agentic loop with multiple tool calls.
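The arithmetic is small enough to check by hand, but a sketch makes the compounding explicit. The ten-call agentic loop is a hypothetical example, not a measured workload.

```python
def decode_seconds(n_tokens, tokens_per_second):
    """Wall-clock decode time, ignoring prefill."""
    return n_tokens / tokens_per_second


baseline_tps = 30.0
speedup = 1.18
spec_tps = baseline_tps * speedup  # 35.4 tokens/s

saved_per_completion = decode_seconds(200, baseline_tps) - decode_seconds(200, spec_tps)

# Hypothetical agentic loop: ten tool calls, each followed by a 200-token completion.
saved_over_loop = 10 * saved_per_completion
```

Each 200-token completion saves about one second, so the hypothetical ten-call loop finishes roughly ten seconds sooner.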
The gap to Ollama's reported 80% acceptance on the same drafter is mostly the bf16 verify artifact, plus the fact that they verify against the fp16 target while we verify against the 4-bit OptIQ target. Quantization adds its own small drift that lowers acceptance further.
Where this leaves Gemma serving
Today, on M4 Pro 24 GB, with the 4-bit OptIQ quant: Gemma-4 E4B serving runs at 1.18x decode geomean with the -assistant drafter enabled. The drafter loads in about 3 seconds and adds about 700 MB to the model footprint. Output is greedy and very nearly byte-identical for tens of tokens, then drifts within the bounds of bf16 attention precision in mlx-lm. The Server page picks it up automatically.
γ>1 multi-token drafting is implemented (the runtime does the chained draft + batched verify + cache rollback dance) but ships defaulted to γ=1 because that is what wins on Metal. We measured γ=1 at 1.34x, γ=2 at 1.27x, γ=3 at 0.96x on the math prompt above. The K-token verify forward scales near-linearly with K on Apple Silicon, so the verify cost climbs faster than the accepted-tokens-per-cycle does, and γ=2's modest acceptance gain over γ=1 does not pay for the longer verify pass. Same finding as Qwen MTP.
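A back-of-the-envelope cost model shows why the best γ depends on how the verify pass scales. This is a sketch under stated assumptions, not a fit to the measurements above: it assumes each drafted token is accepted independently with probability alpha, measures cost in units of one single-token target forward, and uses hypothetical parameter names and values.

```python
def expected_tokens_per_cycle(alpha, gamma):
    """Expected committed tokens per cycle: 1 + alpha + ... + alpha**gamma."""
    return sum(alpha ** i for i in range(gamma + 1))


def modeled_speedup(alpha, gamma, draft_cost, verify_slope):
    """Modeled speedup over plain greedy decoding.

    draft_cost:   cost of one drafter step, relative to one target forward
    verify_slope: extra cost per additional verified token
                  (0 = batching is free, 1 = cost fully linear in tokens)
    """
    verify_cost = 1.0 + verify_slope * gamma  # gamma + 1 tokens are verified
    cycle_cost = gamma * draft_cost + verify_cost
    return expected_tokens_per_cycle(alpha, gamma) / cycle_cost


# Illustrative values only.
for g in (1, 2, 3):
    print(g, round(modeled_speedup(alpha=0.35, gamma=g, draft_cost=0.05, verify_slope=0.25), 3))
```

With these illustrative values the modeled speedup falls as γ grows, whereas setting verify_slope to 0 makes γ=2 beat γ=1. Where the crossover sits is determined by the verify slope, which matches the qualitative finding above.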
The MTP guide has the full reference table, methodology, and the new compatibility matrix.