freesolo-chalk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
chalk/__init__.py ADDED
File without changes
chalk/ops/__init__.py ADDED
@@ -0,0 +1,12 @@
1
"""
Low-level operators for Chalk.

``chalk.ops`` is the home of the raw Triton/CUDA kernels and the
``torch.autograd.Function`` wrappers around them (``@triton.jit`` functions, autograd
Functions, FP8 GEMM helpers). The split follows Liger Kernel: kernel implementations live
here, and the model-level installers that monkeypatch them into HuggingFace modules live in
``chalk.transformers``.

Nothing is re-exported from this namespace yet, on purpose: each kernel is landed
individually, together with the benchmark evidence that justifies it.
"""

# No public names are exported from the package itself yet.
__all__: list[str] = []
chalk/ops/embedding.py ADDED
@@ -0,0 +1,353 @@
1
+ """Fused gather + layer-0 RMSNorm Triton kernel for the Qwen3.5 LoRA worker.
2
+
3
+ WHAT THIS FUSES
4
+ ---------------
5
+ The very first thing a Qwen3.5 forward does is::
6
+
7
+ inputs_embeds = embed_tokens(input_ids) # [tokens, hidden] gather
8
+ # ... decoder layer 0:
9
+ residual = inputs_embeds
10
+ hidden = input_layernorm(inputs_embeds) # layer-0 RMSNorm
11
+
12
+ i.e. a memory-bound embedding gather immediately followed by the layer-0 RMSNorm.
13
+ The baseline (production) path materializes ``inputs_embeds`` to HBM and reads it
14
+ back for the RMSNorm; this kernel gathers the row, RMS-normalizes it in registers,
15
+ and writes the normalized result in ONE launch, so the layer-0 norm never round-
16
+ trips the embedding through HBM a second time.
17
+
18
+ baseline = F.embedding(ids, table) THEN (Liger/eager) RMSNorm
19
+ fused = gather row table[id] -> fp32 var -> x*rstd -> *(1+w) -> cast -> store
20
+
21
+ RMSNorm SEMANTICS (must match the real model EXACTLY)
22
+ -----------------------------------------------------
23
+ Qwen3.5's ``Qwen3_5RMSNorm.forward`` is::
24
+
25
+ output = x.float() * rsqrt(x.float().pow(2).mean(-1) + eps)
26
+ output = output * (1.0 + self.weight.float()) # fp32 multiply
27
+ return output.type_as(x) # cast to bf16 at the end
28
+
29
+ so the weight carries an implicit ``+1.0`` OFFSET and the weight multiply happens in
30
+ fp32 BEFORE the final cast. This is Liger ``casting_mode="gemma"`` + ``offset=1.0``
31
+ (which is exactly how ``apply_liger_kernel_to_qwen3_5`` patches every ``input_layernorm``
32
+ instance — verified on the pod). NOTE: the earlier ``bench/triton_embedding_rmsnorm.py``
33
+ prototype used ``casting_mode="llama"`` (cast-then-multiply, plain ``weight``); that is
34
+ WRONG for Qwen3.5. The kernel here uses the gemma+offset semantics and is self-tested
35
+ against the eager ``Qwen3_5RMSNorm`` math.
36
+
37
+ WHY A FORWARD-ONLY KERNEL IS CORRECT
38
+ ------------------------------------
39
+ Under LoRA the recipe targets ``all-linear`` with no ``modules_to_save`` /
40
+ ``trainable_token_indices`` and there is no full-fine-tune path, so ``embed_tokens``
41
+ AND ``input_layernorm.weight`` are both FROZEN: the embed -> layer-0-RMSNorm subgraph
42
+ builds NO backward graph. A forward-only fused kernel is therefore production-correct.
43
+ To stay safe even if a future config trained either, the patched RMSNorm takes the
44
+ fused path ONLY when grad is disabled OR the inputs are non-trainable (no backward
45
+ needed); otherwise it falls back to the original (Liger or eager) differentiable path.
46
+
47
+ HONEST IMPACT (see bench/embedding_result.md)
48
+ ---------------------------------------------
49
+ Per-op the fused kernel is ~1.8x vs the Liger two-op path on EVERY arch (it is purely
50
+ memory-bound, so the win is architecture-independent). But it runs ONCE per forward
51
+ (layer 0 only) and is ~0.04% of step time. It is a CORRECT, UNIVERSAL, TINY opt-in
52
+ win. There is no larger embedding-adjacent fusion available on this model: weight tying
53
+ (``tie_word_embeddings=true``) routes the big embedding-table cost through the
54
+ lm_head/cross-entropy GEMM, which Liger's fused-linear-CE already owns.
55
+
56
+ GATING / SAFETY
57
+ ---------------
58
+ Install-on-call (the Liger model): calling ``install_qwen35_fused_embedding()`` IS the opt-in —
59
+ there is no env flag — then it is gated by a live-GPU numeric self-test (ANY
60
+ import/compile/self-test failure leaves the Liger/eager path untouched). The installer
61
+ patches ONLY layer-0's ``input_layernorm`` and arranges the gather; it no-ops safely if the
62
+ model shape, the embedding module, or the Liger RMSNorm offset/casting does not match
63
+ what this kernel implements. Import-safe on a CPU control plane (triton/torch imported
64
+ lazily).
65
+ """
66
+
67
+ from __future__ import annotations
68
+
69
+ import contextlib
70
+
71
+
72
+ def _build_kernel():
73
+ """Import torch/triton and define the fused gather+RMSNorm kernel. Returns
74
+ ``fused_gather_rmsnorm`` or raises on any import/compile problem (the caller treats a
75
+ raise as "keep the Liger/eager path")."""
76
+ import torch
77
+ import triton
78
+ import triton.language as tl
79
+
80
+ @triton.jit
81
+ def _fused_gather_rmsnorm_kernel(
82
+ ids_ptr, # [n_tokens] int (flattened)
83
+ table_ptr, # [vocab, hidden] embedding table (== lm_head when tied)
84
+ w_ptr, # [hidden] RMSNorm weight (effective weight is 1+w)
85
+ out_ptr, # [n_tokens, hidden] normalized output
86
+ hidden,
87
+ eps,
88
+ OFFSET: tl.constexpr, # 1.0 for Qwen3.5 (weight is zeros-init, effective 1+w)
89
+ BLOCK_H: tl.constexpr,
90
+ ):
91
+ row = tl.program_id(0)
92
+ tok = tl.load(ids_ptr + row)
93
+ h_off = tl.arange(0, BLOCK_H)
94
+ mask = h_off < hidden
95
+ x = tl.load(table_ptr + tok * hidden + h_off, mask=mask, other=0.0)
96
+ xf = x.to(tl.float32)
97
+ var = tl.sum(xf * xf, axis=0) / hidden
98
+ rstd = 1.0 / tl.sqrt(var + eps)
99
+ # gemma casting_mode: the (offset+weight) multiply stays in fp32, cast LAST.
100
+ normed = xf * rstd
101
+ w = tl.load(w_ptr + h_off, mask=mask, other=0.0).to(tl.float32) + OFFSET
102
+ y = (normed * w).to(out_ptr.dtype.element_ty)
103
+ tl.store(out_ptr + row * hidden + h_off, y, mask=mask)
104
+
105
+ def fused_gather_rmsnorm(ids, table, weight, eps=1e-6, offset=1.0):
106
+ """Gather ``table[ids]`` and apply Qwen3.5 RMSNorm (gemma casting, ``1+weight``)
107
+ in one kernel.
108
+
109
+ ids: [n_tokens] int (any shape, flattened); table: [vocab, hidden]; weight:
110
+ [hidden]. Returns [n_tokens, hidden] in ``table.dtype``. Inputs are made
111
+ contiguous defensively (production buffers already are)."""
112
+ assert table.is_cuda
113
+ assert ids.is_cuda
114
+ assert table.ndim == 2
115
+ if not ids.is_contiguous():
116
+ ids = ids.contiguous()
117
+ ids = ids.view(-1)
118
+ if not table.is_contiguous():
119
+ table = table.contiguous()
120
+ if not weight.is_contiguous():
121
+ weight = weight.contiguous()
122
+ n_tokens, hidden = ids.numel(), table.shape[1]
123
+ out = torch.empty((n_tokens, hidden), device=table.device, dtype=table.dtype)
124
+ BLOCK_H = triton.next_power_of_2(hidden)
125
+ _fused_gather_rmsnorm_kernel[(n_tokens,)](
126
+ ids, table, weight, out, hidden, eps, OFFSET=float(offset), BLOCK_H=BLOCK_H, num_warps=8
127
+ )
128
+ return out
129
+
130
+ return fused_gather_rmsnorm
131
+
132
+
133
+ def _self_test(fused_gather_rmsnorm) -> None:
134
+ """Live-GPU numeric self-test vs the EXACT eager ``Qwen3_5RMSNorm`` math
135
+ (fp32 var, fp32 ``*(1+w)``, cast last). Raises on mismatch so the caller keeps the
136
+ Liger/eager path."""
137
+ import torch
138
+ import torch.nn.functional as F
139
+
140
+ torch.manual_seed(0)
141
+ dev, vocab, hidden, eps = "cuda", 4096, 2560, 1e-6
142
+ table = torch.randn(vocab, hidden, device=dev, dtype=torch.bfloat16)
143
+ # weight is zeros-init in Qwen3.5; use a small perturbation (effective 1+w).
144
+ weight = 0.02 * torch.randn(hidden, device=dev, dtype=torch.bfloat16)
145
+ for n in (256, 2048):
146
+ ids = torch.randint(0, vocab, (n,), device=dev)
147
+ emb = F.embedding(ids, table)
148
+ xf = emb.float()
149
+ normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
150
+ ref = (normed * (1.0 + weight.float())).to(emb.dtype).float()
151
+ got = fused_gather_rmsnorm(ids, table, weight, eps=eps, offset=1.0).float()
152
+ rel = (got - ref).abs().max().item() / (ref.abs().max().item() + 1e-9)
153
+ if not (rel < 2e-2):
154
+ raise RuntimeError(f"triton_embedding self-test failed at n={n}: rel={rel:.2e}")
155
+
156
+
157
def load_fused_embedding():
    """Build the fused kernel and gate it behind the live-GPU self-test.

    Returns ``fused_gather_rmsnorm`` when CUDA is available, the kernel builds, and the
    self-test passes. In every other situation (torch/triton missing, no CUDA, compile or
    self-test failure) the result is ``None``, meaning "stay on the Liger/eager path".
    This function never raises.
    """
    kernel = None
    try:
        import torch

        if torch.cuda.is_available():
            candidate = _build_kernel()
            _self_test(candidate)
            print("[embed] fused Triton gather+RMSNorm kernel enabled (self-test passed)", flush=True)
            # Only publish the kernel once everything above succeeded.
            kernel = candidate
    except Exception as e:  # pragma: no cover - defensive: any failure keeps baseline
        print(f"[embed] fused Triton embedding kernel disabled (build/self-test failed): {e}", flush=True)
        kernel = None
    return kernel
173
+
174
+
175
+ def _resolve_text_model(model):
176
+ """Return the inner Qwen3.5 text model (the one that owns ``embed_tokens`` + the
177
+ decoder ``layers``), or None if the structure does not match. Handles the
178
+ ForConditionalGeneration (``model.model.language_model``), ForCausalLM
179
+ (``model.model``), and bare-TextModel layouts, and unwraps a PEFT ``PeftModel``
180
+ (``trainer.model`` is the PEFT wrapper during training) via ``get_base_model()``."""
181
+ roots = [model]
182
+ # PEFT wrapper: the real HF model is under get_base_model() (the bare LoraModel
183
+ # walk does NOT expose embed_tokens/layers at the expected paths).
184
+ getter = getattr(model, "get_base_model", None)
185
+ if callable(getter):
186
+ with contextlib.suppress(Exception):
187
+ roots.append(getter())
188
+ for root in roots:
189
+ for path in (
190
+ ("model", "language_model"),
191
+ ("model",),
192
+ (),
193
+ ):
194
+ obj = root
195
+ ok = True
196
+ for attr in path:
197
+ if not hasattr(obj, attr):
198
+ ok = False
199
+ break
200
+ obj = getattr(obj, attr)
201
+ if ok and hasattr(obj, "embed_tokens") and hasattr(obj, "layers"):
202
+ return obj
203
+ return None
204
+
205
+
206
def install_qwen35_fused_embedding(model) -> bool:
    """Wire the fused gather + layer-0-RMSNorm kernel onto a loaded Qwen3.5 model — IFF
    the live-GPU self-test passes. Install-on-call: calling this IS the opt-in (the Liger
    model); there is no env flag.

    Patches ONLY layer-0's ``input_layernorm`` (the single RMSNorm fed by the embedding):
    a stashing ``embed_tokens.forward`` records the ``input_ids`` AND the embedding tensor
    it produced for the current step, and the patched ``input_layernorm`` re-runs the
    gather fused with the RMSNorm. The raw embedding (needed for the residual) is still
    produced by the normal ``embed_tokens`` gather, so the residual path is exact.

    SAFE NO-OP CONDITIONS (any -> return False, leave the model untouched):
      * kernel disabled / build / self-test failure;
      * model structure (text model / embed_tokens / layers / layer-0 input_layernorm)
        does not match;
      * the layer-0 RMSNorm offset/casting does not match what the kernel implements
        (``casting_mode`` other than gemma/None, ``offset`` other than 1.0) or the
        per-instance numeric check against the module's OWN forward fails or errors.

    CORRECTNESS GATES for taking the fused path on a given call (all must hold, otherwise
    the original ``input_layernorm`` forward runs):
      * ``embed_tokens`` ran since the previous layer-0 norm call and recorded CUDA ids;
      * the tensor handed to the norm is the SAME data ``embed_tokens`` returned (same
        object, or same storage pointer/shape/stride/dtype) with an unchanged version
        counter. This rejects embeddings that were replaced or edited between the gather
        and layer 0 (for example a wrapper that scatters other features into them), for
        which ``table[ids]`` would be the wrong input;
      * no backward graph needs to flow through the norm (grad disabled, or both the
        embedding table and the norm weight are frozen).

    Args:
        model: the loaded model, or a wrapper around it that ``_resolve_text_model``
            can unwrap.

    Returns:
        True iff the patch is installed on the model (including "was already installed").
        Never raises out of the guarded install: any failure keeps the original path and
        returns False.
    """
    # Idempotence fast path: a model that is already patched needs neither a kernel
    # rebuild nor another self-test. Any surprise here just falls through to the full path.
    with contextlib.suppress(Exception):
        prior = _resolve_text_model(model)
        if prior is not None and getattr(prior.layers[0].input_layernorm, "_chalk_embed_patched", False):
            return True

    fn = load_fused_embedding()
    if fn is None:
        return False
    try:
        import torch

        tm = _resolve_text_model(model)
        if tm is None:
            print("[embed] no matching Qwen3.5 text model (embed_tokens/layers); keeping baseline", flush=True)
            return False
        embed = tm.embed_tokens
        layers = tm.layers
        if not hasattr(embed, "weight") or len(layers) == 0:
            print("[embed] embedding/layers shape mismatch; keeping baseline", flush=True)
            return False
        layer0 = layers[0]
        ln = getattr(layer0, "input_layernorm", None)
        if ln is None or not hasattr(ln, "weight"):
            print("[embed] no layer-0 input_layernorm; keeping baseline", flush=True)
            return False
        table = embed.weight
        if table.ndim != 2 or ln.weight.numel() != table.shape[1]:
            print("[embed] embed/norm dim mismatch; keeping baseline", flush=True)
            return False
        if getattr(ln, "_chalk_embed_patched", False):
            return True

        # Resolve eps + offset. ``variance_epsilon`` wins over ``eps``; only a MISSING
        # (None) value falls through, so an explicit 0.0 is honoured rather than being
        # treated as absent. If the norm carries a NON-1.0 offset or a non-gemma casting
        # mode, we cannot match it -> bail.
        eps_attr = getattr(ln, "variance_epsilon", None)
        if eps_attr is None:
            eps_attr = getattr(ln, "eps", None)
        eps = 1e-6 if eps_attr is None else float(eps_attr)
        offset = float(getattr(ln, "offset", 1.0))
        casting = getattr(ln, "casting_mode", "gemma")
        if casting not in ("gemma", None) or offset != 1.0:
            print(
                f"[embed] layer-0 norm casting/offset unsupported (casting={casting}, offset={offset}); keeping baseline",
                flush=True,
            )
            return False

        _orig_ln_forward = ln.forward

        # Per-instance numeric check: the fused gather+RMSNorm must match the layer-0
        # norm's OWN forward applied to the real embedding, on the real table/weight.
        # This catches semantic drift (offset/casting/eps) the attribute check above
        # missed, on the ACTUAL module — before we swap it in.
        try:
            with torch.no_grad():
                n = 64
                dev = table.device
                # Private generator: sampling the probe ids must not advance the global
                # RNG stream of the process that is installing the patch.
                gen = torch.Generator(device=dev)
                gen.manual_seed(0)
                ids_chk = torch.randint(0, table.shape[0], (n,), device=dev, generator=gen)
                emb_chk = torch.nn.functional.embedding(ids_chk, table)
                ref = _orig_ln_forward(emb_chk).float()
                got = fn(ids_chk, table, ln.weight, eps=eps, offset=offset).float()
                rel = (got - ref).abs().max().item() / (ref.abs().max().item() + 1e-9)
            if not (rel < 2e-2):
                print(f"[embed] per-instance norm check failed (rel={rel:.2e}); keeping baseline", flush=True)
                return False
        except Exception as e:
            print(f"[embed] per-instance norm check errored ({type(e).__name__}: {e}); keeping baseline", flush=True)
            return False

        def _tensor_version(t):
            # Version counter of ``t``, or None when it cannot be read (some tensors do
            # not track one). None compares equal to None, so such tensors are matched
            # on data identity alone.
            try:
                return t._version
            except Exception:
                return None

        _orig_embed_forward = embed.forward

        def _embed_forward(input_ids, *args, **kwargs):
            # Drop any leftover stash first so a failing gather cannot leave one behind.
            tm._chalk_embed_stash = None
            out = _orig_embed_forward(input_ids, *args, **kwargs)
            # Record (ids, produced embedding, its version) for the layer-0 norm. The
            # embedding reference is what lets the norm prove its input is THIS gather.
            with contextlib.suppress(Exception):
                tm._chalk_embed_stash = (input_ids, out, _tensor_version(out))
            return out

        embed_table = table  # capture the table the kernel gathers from

        def _ln_forward(hidden_states, *args, **kwargs):
            # SINGLE-USE STASH: read it into a local and clear it on the model before
            # branching, so a later forward in which embed_tokens never ran cannot reuse
            # a previous step's ids.
            stash = getattr(tm, "_chalk_embed_stash", None)
            tm._chalk_embed_stash = None
            try:
                if stash is not None:
                    ids, emb_out, emb_version = stash
                    # The norm input must be exactly what embed_tokens returned: same
                    # object, or an alias of the same storage with identical layout.
                    same_data = hidden_states is emb_out or (
                        isinstance(hidden_states, torch.Tensor)
                        and isinstance(emb_out, torch.Tensor)
                        and hidden_states.data_ptr() == emb_out.data_ptr()
                        and hidden_states.shape == emb_out.shape
                        and hidden_states.stride() == emb_out.stride()
                        and hidden_states.dtype == emb_out.dtype
                    )
                    if (
                        same_data
                        # An in-place edit after the gather bumps the version counter.
                        and _tensor_version(hidden_states) == emb_version
                        and isinstance(ids, torch.Tensor)
                        and ids.is_cuda
                        and hidden_states.ndim >= 2
                        and hidden_states.shape[-1] == embed_table.shape[1]
                        and hidden_states.numel() // hidden_states.shape[-1] == ids.numel()
                        # The kernel writes table.dtype; never hand back another dtype.
                        and hidden_states.dtype == embed_table.dtype
                        and (
                            not torch.is_grad_enabled()
                            or (not embed_table.requires_grad and not ln.weight.requires_grad)
                        )
                    ):
                        out = fn(ids, embed_table, ln.weight, eps=eps, offset=offset)
                        return out.reshape(hidden_states.shape)
            except Exception:
                # Deliberate best-effort: any runtime hiccup on the fused path falls
                # through to the exact original forward below.
                pass
            return _orig_ln_forward(hidden_states, *args, **kwargs)

        embed.forward = _embed_forward
        ln.forward = _ln_forward
        ln._chalk_embed_patched = True
        print(
            "[embed] fused Triton gather+RMSNorm installed on layer-0 input_layernorm "
            "(forward-only, frozen-embed path)",
            flush=True,
        )
        return True
    except Exception as e:  # pragma: no cover - defensive
        print(f"[embed] install failed ({type(e).__name__}: {e}); keeping baseline", flush=True)
        return False
349
+
350
+
351
if __name__ == "__main__":  # manual self-test / smoke
    # Run as a script: try to build + self-test the kernel and report whether it loaded.
    print("fused gather+rmsnorm loaded:", load_fused_embedding() is not None)