mlx-optiq
Documentation · methodology

How sensitivity works

mlx-optiq's central idea fits in two sentences. First: not every layer in a transformer is equally fragile under quantization — some take a 4-bit hit cleanly, others fall apart. Second: you can measure which is which by perturbing one layer at a time and watching the model's output distribution.

Measure once, allocate everywhere. The same per-layer sensitivity number informs weight bit-width, KV bit-width, and LoRA rank.

The measurement

For each pair (layer L, candidate bits b):

  1. Forward-pass calibration data through the model with all weights in their reference precision. Record the output logits.
  2. Replace just L's weights with a simulated-quantized version at b bits (round-trip: quantize → dequantize). All other layers stay at reference precision.
  3. Forward-pass the same calibration data again. Record the perturbed logits.
  4. Compute KL divergence between the reference and perturbed logit distributions, averaged across calibration samples and tokens.
  5. Restore L to its reference precision and move on to the next layer.

The result is a table: for every layer, the KL cost of dropping it from reference to each candidate bit-width. This is mlx-optiq's per-layer sensitivity signal.
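In code, that loop looks roughly like the sketch below. It is not mlx-optiq's actual API: the helper names and the (path, module) interface are illustrative. The round-trip fake quantization is built on MLX's mx.quantize / mx.dequantize.

import mlx.core as mx
import mlx.nn as nn

def quantize_dequantize(w, bits, group_size=64):
    # Round-trip fake quantization: quantize, then immediately dequantize,
    # so the weight carries b-bit error but stays a dense float tensor.
    wq, scales, biases = mx.quantize(w, group_size=group_size, bits=bits)
    return mx.dequantize(wq, scales, biases, group_size=group_size, bits=bits)

def mean_kl(ref_logits, probe_logits):
    # KL(ref || probe) per token, averaged over batch and sequence.
    ref_logp = nn.log_softmax(ref_logits, axis=-1)
    probe_logp = nn.log_softmax(probe_logits, axis=-1)
    return (mx.exp(ref_logp) * (ref_logp - probe_logp)).sum(axis=-1).mean()

def sensitivity_table(model, linear_layers, candidate_bits, calib):
    # Step 1: reference logits with every layer at reference precision.
    ref = [model(batch) for batch in calib]
    table = {}
    for path, layer in linear_layers:      # illustrative: [(path, nn.Linear), ...]
        original = layer.weight
        for bits in candidate_bits:
            layer.weight = quantize_dequantize(original, bits)  # steps 2-3
            probe = [model(batch) for batch in calib]
            kls = [mean_kl(r, p) for r, p in zip(ref, probe)]   # step 4
            table[(path, bits)] = (sum(kls) / len(kls)).item()
        layer.weight = original            # step 5: restore, move on
    return table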

Two reference modes

What does "reference precision" mean concretely? Two options, picked automatically by --reference auto based on whether the bf16 weights fit in your Mac's RAM:

1. bf16 reference — gold standard

Load the original bf16 model into RAM. Each sensitivity probe swaps a single layer between bf16 and a quantized copy. Highest-fidelity measurement. Required RAM ≈ model size in bf16 (2 bytes per parameter, so ≈ 2 GB per billion parameters).

Used automatically when bf16 fits in ~70% of available RAM. On a 36 GB Mac, this means models up to ~10 B parameters.

2. uniform_4bit reference — for big models

Build a uniform-4-bit MLX baseline first. Load that as the running model (~25% of bf16 size). Stream bf16 weights off disk, one layer at a time, swapping each in for its sensitivity probe. The signal is slightly weaker, since you're measuring KL relative to uniform-4 rather than bf16, but it still gives 27 B+ models a calibration-driven mixed-precision allocation on a 36 GB Mac.

Used automatically when bf16 doesn't fit. The bf16 weights still need to be on disk for the streaming probes; only RAM is the constraint.
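Reading one layer's bf16 tensors off disk without materializing the whole checkpoint can be done lazily with the safetensors library. A sketch of the idea, not mlx-optiq's actual loader; it assumes a safetensors version with MLX framework support:

import glob
from safetensors import safe_open

def load_layer_bf16(checkpoint_dir, layer_prefix):
    # Read only the tensors whose names start with layer_prefix, e.g.
    # "model.layers.17." (illustrative); the rest of each shard is skipped.
    tensors = {}
    for shard in sorted(glob.glob(f"{checkpoint_dir}/*.safetensors")):
        with safe_open(shard, framework="mlx") as f:  # needs safetensors MLX support
            for key in f.keys():
                if key.startswith(layer_prefix):
                    tensors[key] = f.get_tensor(key)
    return tensors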

Auto-routing in practice

# auto picks bf16 if it fits, else uniform_4bit
$ optiq convert Qwen/Qwen3.5-9B \
    --target-bpw 4.5 --candidate-bits 4,8 \
    --reference auto

# force bf16 (will OOM if model doesn't fit)
$ optiq convert Qwen/Qwen3.5-9B --reference bf16

# force uniform-4-bit reference (works on 27 B+ on 36 GB)
$ optiq convert Qwen/Qwen3.5-27B --reference uniform_4bit
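Under the hood, auto's choice reduces to the RAM check described above. A sketch under assumptions: psutil is my stand-in for querying available memory, and the 70% headroom figure comes from the bf16 section; the real check may differ.

import psutil

def pick_reference(param_count, headroom=0.70):
    # bf16 is 2 bytes per parameter; use it when it fits in ~70% of available RAM.
    bf16_bytes = 2 * param_count
    available = psutil.virtual_memory().available
    return "bf16" if bf16_bytes <= headroom * available else "uniform_4bit"

# A 9 B-parameter model needs ~18 GB in bf16, which fits on a 36 GB Mac;
# a 27 B model needs ~54 GB, so it routes to the uniform_4bit reference.
print(pick_reference(9_000_000_000))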

Calibration data

mlx-optiq uses WikiText-2 validation by default — 32 sequences of 128 tokens each. Generic web text is sufficient because we're measuring relative layer sensitivity, not absolute accuracy. Smaller calibration sets work but give noisier sensitivity estimates; larger sets cost linearly more time without much signal improvement past ~32 samples.
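To make the defaults concrete, here is one way to build that calibration set with the Hugging Face datasets library. How mlx-optiq actually loads WikiText-2 is an implementation detail, so treat this as a sketch:

import mlx.core as mx
from datasets import load_dataset

def calibration_batches(tokenizer, num_seqs=32, seq_len=128):
    # Concatenate the WikiText-2 validation split and slice it into
    # num_seqs non-overlapping sequences of seq_len tokens each.
    text = "\n\n".join(load_dataset("wikitext", "wikitext-2-raw-v1",
                                    split="validation")["text"])
    ids = tokenizer.encode(text)
    return [mx.array(ids[i * seq_len:(i + 1) * seq_len])[None]  # add batch dim
            for i in range(num_seqs)]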

The allocator

Once you have the sensitivity table, allocating per-layer bits is a knapsack problem. mlx-optiq uses a greedy heuristic that is effectively optimal in practice for the small number of layers and bit-widths involved (a sketch follows the list):

  1. Start every layer at the lowest candidate bit-width (e.g. 4-bit).
  2. Compute the average bit-budget so far.
  3. If under target BPW: find the layer where upgrading by one bit-width tier buys the largest KL reduction per extra bit. Upgrade it.
  4. Repeat until the average BPW reaches the target.

Some layers are protected — they always get the highest bit-width regardless of the knapsack. By default these are lm_head, embed_tokens, the first attention block and the last attention block. They're cheap (small parameter share) and pathological to quantize.
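Putting the loop and the protected set together, a sketch of the greedy allocator. Here sensitivity is the (layer, bits) → KL table from the measurement step, sizes maps each layer to its parameter count, and all names are illustrative:

def allocate_bits(sensitivity, sizes, target_bpw, tiers=(4, 8), protected=()):
    layers = list(sizes)
    # Protected layers pin to the top tier; everything else starts at the bottom.
    bits = {l: (tiers[-1] if l in protected else tiers[0]) for l in layers}
    total_params = sum(sizes.values())

    def avg_bpw():
        # Parameter-weighted average bit-width of the current assignment.
        return sum(bits[l] * sizes[l] for l in layers) / total_params

    while avg_bpw() < target_bpw:
        best, best_gain = None, 0.0
        for l in layers:
            i = tiers.index(bits[l])
            if i + 1 == len(tiers):
                continue                  # already at the top tier
            nxt = tiers[i + 1]
            # KL reduction bought per extra bit spent on this layer.
            gain = (sensitivity[(l, bits[l])] - sensitivity[(l, nxt)]) / (nxt - bits[l])
            if gain > best_gain:
                best, best_gain = l, gain
        if best is None:
            break                         # no upgrade helps; stop early
        bits[best] = tiers[tiers.index(bits[best]) + 1]
    return bits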

The output

mlx-optiq hands the per-layer bit map to mlx_lm.convert as a quant_predicate. The output is a standard MLX checkpoint that mlx_lm.load treats exactly like a uniform-quantized one; the only difference is that some layers are at 8-bit and others at 4-bit.
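For illustration, a predicate wired from the allocator's bit map might look like this. It assumes the mlx_lm version on hand accepts a quant_predicate callable taking (path, module, config) and returning per-layer quantization parameters; that API has shifted across releases, so verify against yours.

from mlx_lm import convert

bit_map = {"model.layers.0.mlp.up_proj": 8}  # illustrative; really the allocator's output

def predicate(path, module, config):
    # Per-layer quantization parameters; anything not in the map falls back to 4-bit.
    return {"bits": bit_map.get(path, 4), "group_size": 64}

convert("Qwen/Qwen3.5-9B", mlx_path="qwen3.5-9b-mixed",
        quantize=True, quant_predicate=predicate)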

Why KL and not perplexity? Perplexity is a scalar — too coarse to discriminate between layers that fail in different ways. KL divergence on the full output distribution captures shifts in which tokens get mass, not just how confident the model is on the chosen one. The KL signal correlates strongly with downstream task accuracy (GSM8K, MMLU) and is much cheaper to compute than running a full eval.
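Written out (notation mine, matching the measurement steps above), the per-token quantity is the standard KL divergence over the vocabulary V, and the sensitivity entry averages it over N calibration samples and T tokens:

$$D_{\mathrm{KL}}\left(P_{\mathrm{ref}} \,\|\, P_{L,b}\right) = \sum_{v \in V} P_{\mathrm{ref}}(v) \, \log \frac{P_{\mathrm{ref}}(v)}{P_{L,b}(v)}, \qquad S_{L,b} = \frac{1}{NT} \sum_{n=1}^{N} \sum_{t=1}^{T} D_{\mathrm{KL}}\left(P^{(n,t)}_{\mathrm{ref}} \,\|\, P^{(n,t)}_{L,b}\right)$$

where $P^{(n,t)}_{L,b}$ is the next-token distribution at token t of sample n with layer L fake-quantized to b bits.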

What about MoE models?

Sparse mixture-of-experts models (Qwen3.5-35B-A3B, Qwen3.6-35B-A3B, gemma-4-26B-A4B) need special handling: each expert is its own weight tensor and could in principle carry its own bit-width. mlx-optiq walks the MoE structure (different layouts for Gemma's switch_glu vs Qwen's switch_mlp) and treats each fused expert tensor as a single layer for sensitivity purposes; the bit assignment for that fused tensor then comes out of the same knapsack.
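A sketch of that structure walk. The attribute names come from the paragraph above, but the exact module paths are model-specific, so treat them as assumptions to verify per model family:

def find_fused_expert_layers(model):
    # Collect fused expert tensors as probe targets: Qwen MoE blocks are
    # assumed to expose them under switch_mlp, Gemma under switch_glu.
    return [(path, module)
            for path, module in model.named_modules()
            if path.endswith("switch_mlp") or path.endswith("switch_glu")]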

Next up: see the algorithm in action in our research write-up, or get started with a model family.