freesolo-chalk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chalk/__init__.py +0 -0
- chalk/ops/__init__.py +12 -0
- chalk/ops/embedding.py +353 -0
- chalk/ops/fp8_base.py +349 -0
- chalk/ops/lora.py +608 -0
- chalk/ops/mlp.py +947 -0
- chalk/ops/qkv.py +636 -0
- chalk/ops/rope.py +455 -0
- chalk/transformers/__init__.py +38 -0
- chalk/transformers/apply.py +160 -0
- chalk/transformers/embedding.py +5 -0
- chalk/transformers/fp8_base.py +5 -0
- chalk/transformers/lora.py +5 -0
- chalk/transformers/mlp.py +6 -0
- chalk/transformers/qkv.py +5 -0
- chalk/transformers/rope.py +5 -0
- chalk/utils.py +35 -0
- freesolo_chalk-0.1.0.dist-info/METADATA +104 -0
- freesolo_chalk-0.1.0.dist-info/RECORD +23 -0
- freesolo_chalk-0.1.0.dist-info/WHEEL +5 -0
- freesolo_chalk-0.1.0.dist-info/licenses/LICENSE +25 -0
- freesolo_chalk-0.1.0.dist-info/licenses/NOTICE +26 -0
- freesolo_chalk-0.1.0.dist-info/top_level.txt +1 -0
chalk/__init__.py
ADDED
|
File without changes
|
chalk/ops/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""
Chalk operators — raw Triton/CUDA kernels and their ``torch.autograd.Function`` wrappers.

The layout follows Liger Kernel:

* ``chalk.ops`` holds the low-level kernel implementations (``@triton.jit`` functions,
  autograd Functions, FP8 GEMM helpers);
* ``chalk.transformers`` holds the model-level installers that monkeypatch these kernels
  into HuggingFace modules.

The package namespace itself deliberately re-exports nothing (``__all__`` is empty): every
kernel lives in its own submodule, is landed one at a time with its own benchmark
evidence, and is imported from that submodule explicitly.
"""

# Intentionally empty: nothing is re-exported at package level.
__all__: list[str] = []
|
chalk/ops/embedding.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
1
|
+
"""Fused gather + layer-0 RMSNorm Triton kernel for the Qwen3.5 LoRA worker.
|
|
2
|
+
|
|
3
|
+
WHAT THIS FUSES
|
|
4
|
+
---------------
|
|
5
|
+
The very first thing a Qwen3.5 forward does is::
|
|
6
|
+
|
|
7
|
+
inputs_embeds = embed_tokens(input_ids) # [tokens, hidden] gather
|
|
8
|
+
# ... decoder layer 0:
|
|
9
|
+
residual = inputs_embeds
|
|
10
|
+
hidden = input_layernorm(inputs_embeds) # layer-0 RMSNorm
|
|
11
|
+
|
|
12
|
+
i.e. a memory-bound embedding gather immediately followed by the layer-0 RMSNorm.
|
|
13
|
+
The baseline (production) path materializes ``inputs_embeds`` to HBM and reads it
|
|
14
|
+
back for the RMSNorm; this kernel gathers the row, RMS-normalizes it in registers,
|
|
15
|
+
and writes the normalized result in ONE launch, so the layer-0 norm never round-
|
|
16
|
+
trips the embedding through HBM a second time.
|
|
17
|
+
|
|
18
|
+
baseline = F.embedding(ids, table) THEN (Liger/eager) RMSNorm
|
|
19
|
+
fused = gather row table[id] -> fp32 var -> x*rstd -> *(1+w) -> cast -> store
|
|
20
|
+
|
|
21
|
+
RMSNorm SEMANTICS (must match the real model EXACTLY)
|
|
22
|
+
-----------------------------------------------------
|
|
23
|
+
Qwen3.5's ``Qwen3_5RMSNorm.forward`` is::
|
|
24
|
+
|
|
25
|
+
output = x.float() * rsqrt(x.float().pow(2).mean(-1) + eps)
|
|
26
|
+
output = output * (1.0 + self.weight.float()) # fp32 multiply
|
|
27
|
+
return output.type_as(x) # cast to bf16 at the end
|
|
28
|
+
|
|
29
|
+
so the weight carries an implicit ``+1.0`` OFFSET and the weight multiply happens in
|
|
30
|
+
fp32 BEFORE the final cast. This is Liger ``casting_mode="gemma"`` + ``offset=1.0``
|
|
31
|
+
(which is exactly how ``apply_liger_kernel_to_qwen3_5`` patches every ``input_layernorm``
|
|
32
|
+
instance — verified on the pod). NOTE: the earlier ``bench/triton_embedding_rmsnorm.py``
|
|
33
|
+
prototype used ``casting_mode="llama"`` (cast-then-multiply, plain ``weight``); that is
|
|
34
|
+
WRONG for Qwen3.5. The kernel here uses the gemma+offset semantics and is self-tested
|
|
35
|
+
against the eager ``Qwen3_5RMSNorm`` math.
|
|
36
|
+
|
|
37
|
+
WHY A FORWARD-ONLY KERNEL IS CORRECT
|
|
38
|
+
------------------------------------
|
|
39
|
+
Under LoRA the recipe targets ``all-linear`` with no ``modules_to_save`` /
|
|
40
|
+
``trainable_token_indices`` and there is no full-fine-tune path, so ``embed_tokens``
|
|
41
|
+
AND ``input_layernorm.weight`` are both FROZEN: the embed -> layer-0-RMSNorm subgraph
|
|
42
|
+
builds NO backward graph. A forward-only fused kernel is therefore production-correct.
|
|
43
|
+
To stay safe even if a future config trained either, the patched RMSNorm takes the
|
|
44
|
+
fused path ONLY when grad is disabled OR the inputs are non-trainable (no backward
|
|
45
|
+
needed); otherwise it falls back to the original (Liger or eager) differentiable path.
|
|
46
|
+
|
|
47
|
+
HONEST IMPACT (see bench/embedding_result.md)
|
|
48
|
+
---------------------------------------------
|
|
49
|
+
Per op, the fused kernel is ~1.8x faster than the Liger two-op path on EVERY arch (it is purely
|
|
50
|
+
memory-bound, so the win is architecture-independent). But it runs ONCE per forward
|
|
51
|
+
(layer 0 only) and is ~0.04% of step time. It is a CORRECT, UNIVERSAL, TINY opt-in
|
|
52
|
+
win. There is no larger embedding-adjacent fusion available on this model: weight tying
|
|
53
|
+
(``tie_word_embeddings=true``) routes the big embedding-table cost through the
|
|
54
|
+
lm_head/cross-entropy GEMM, which Liger's fused-linear-CE already owns.
|
|
55
|
+
|
|
56
|
+
GATING / SAFETY
|
|
57
|
+
---------------
|
|
58
|
+
Install-on-call (the Liger model): calling ``install_qwen35_fused_embedding()`` IS the opt-in —
|
|
59
|
+
there is no env flag — then it is gated by a live-GPU numeric self-test (ANY
|
|
60
|
+
import/compile/self-test failure leaves the Liger/eager path untouched). The installer
|
|
61
|
+
patches ONLY layer-0's ``input_layernorm`` (plus an ``input_ids`` stash on ``embed_tokens``) and no-ops safely if the
|
|
62
|
+
model shape, the embedding module, or the Liger RMSNorm offset/casting does not match
|
|
63
|
+
what this kernel implements. Import-safe on a CPU control plane (triton/torch imported
|
|
64
|
+
lazily).
|
|
65
|
+
"""
|
|
66
|
+
|
|
67
|
+
from __future__ import annotations
|
|
68
|
+
|
|
69
|
+
import contextlib
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _build_kernel():
    """Import torch/triton and define the fused gather+RMSNorm kernel.

    Returns:
        ``fused_gather_rmsnorm``, the Python launcher for the Triton kernel.

    Raises:
        Any import/compile problem propagates unchanged; the caller treats a raise as
        "keep the Liger/eager path".
    """
    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _fused_gather_rmsnorm_kernel(
        ids_ptr,  # [n_tokens] int token ids (flattened)
        table_ptr,  # [vocab, hidden] embedding table (== lm_head when tied)
        w_ptr,  # [hidden] RMSNorm weight (effective weight is OFFSET + w)
        out_ptr,  # [n_tokens, hidden] normalized output
        hidden,
        eps,
        OFFSET: tl.constexpr,  # 1.0 for Qwen3.5 (weight is zeros-init, effective 1+w)
        BLOCK_H: tl.constexpr,
    ):
        # One program per token row. Promote the row index and the token id to int64
        # BEFORE multiplying by ``hidden``: program_id is int32 (and ids may be int32),
        # so ``row * hidden`` / ``tok * hidden`` can exceed 2**31 on large inputs and
        # wrap, yielding out-of-bounds pointer offsets.
        row = tl.program_id(0).to(tl.int64)
        tok = tl.load(ids_ptr + row).to(tl.int64)
        h_off = tl.arange(0, BLOCK_H)
        mask = h_off < hidden
        x = tl.load(table_ptr + tok * hidden + h_off, mask=mask, other=0.0)
        xf = x.to(tl.float32)
        # Masked lanes load 0.0, so they add nothing to the sum; divide by the true width.
        var = tl.sum(xf * xf, axis=0) / hidden
        rstd = 1.0 / tl.sqrt(var + eps)
        # gemma casting_mode: the (offset+weight) multiply stays in fp32, cast LAST.
        normed = xf * rstd
        w = tl.load(w_ptr + h_off, mask=mask, other=0.0).to(tl.float32) + OFFSET
        y = (normed * w).to(out_ptr.dtype.element_ty)
        tl.store(out_ptr + row * hidden + h_off, y, mask=mask)

    def fused_gather_rmsnorm(ids, table, weight, eps=1e-6, offset=1.0):
        """Gather ``table[ids]`` and apply Qwen3.5 RMSNorm (gemma casting, ``offset+weight``)
        in one kernel launch.

        Args:
            ids: integer token ids on CUDA, any shape (flattened to [n_tokens]).
            table: [vocab, hidden] embedding table on CUDA.
            weight: RMSNorm weight with exactly ``hidden`` elements.
            eps: epsilon added to the mean square before the reciprocal square root.
            offset: constant added to ``weight`` in fp32 (1.0 for Qwen3.5).

        Returns:
            A [n_tokens, hidden] tensor in ``table.dtype``. Inputs are made contiguous
            defensively (production buffers already are).

        Raises:
            ValueError: non-CUDA ``ids``/``table``, a table that is not 2-D, or a weight
                whose element count differs from ``hidden``.
        """
        # Explicit checks rather than ``assert`` so they survive ``python -O``.
        if not (table.is_cuda and ids.is_cuda):
            raise ValueError("fused_gather_rmsnorm requires CUDA tensors for ids and table")
        if table.ndim != 2:
            raise ValueError(f"table must be 2-D [vocab, hidden], got shape {tuple(table.shape)}")
        # ``.contiguous()`` is a no-op (returns self) for already-contiguous tensors.
        ids = ids.contiguous().view(-1)
        table = table.contiguous()
        weight = weight.contiguous()
        n_tokens, hidden = ids.numel(), table.shape[1]
        if weight.numel() != hidden:
            # The kernel reads ``hidden`` weight elements; a smaller weight would be read OOB.
            raise ValueError(f"weight has {weight.numel()} elements, expected hidden={hidden}")
        out = torch.empty((n_tokens, hidden), device=table.device, dtype=table.dtype)
        if n_tokens == 0:
            return out  # nothing to gather; skip the zero-size launch
        BLOCK_H = triton.next_power_of_2(hidden)
        _fused_gather_rmsnorm_kernel[(n_tokens,)](
            ids, table, weight, out, hidden, eps, OFFSET=float(offset), BLOCK_H=BLOCK_H, num_warps=8
        )
        return out

    return fused_gather_rmsnorm
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _self_test(fused_gather_rmsnorm) -> None:
    """Live-GPU numeric self-test vs the EXACT eager ``Qwen3_5RMSNorm`` math
    (fp32 mean-square, fp32 ``*(1+w)``, cast last).

    Args:
        fused_gather_rmsnorm: the launcher returned by ``_build_kernel``.

    Raises:
        RuntimeError: if the max relative error vs the reference reaches 2e-2 for any
            tested token count, so the caller keeps the Liger/eager path.
    """
    import torch
    import torch.nn.functional as F

    dev, vocab, hidden, eps = "cuda", 4096, 2560, 1e-6
    # Private fixed-seed generator: reproducible test inputs WITHOUT calling
    # ``torch.manual_seed``, which would reseed the process-global CPU+CUDA RNG and
    # perturb every later random draw of the caller (dropout, shuffling, ...).
    gen = torch.Generator(device=dev)
    gen.manual_seed(0)
    table = torch.randn(vocab, hidden, device=dev, dtype=torch.bfloat16, generator=gen)
    # weight is zeros-init in Qwen3.5; use a small perturbation (effective 1+w).
    weight = 0.02 * torch.randn(hidden, device=dev, dtype=torch.bfloat16, generator=gen)
    for n in (256, 2048):
        ids = torch.randint(0, vocab, (n,), device=dev, generator=gen)
        emb = F.embedding(ids, table)
        xf = emb.float()
        normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
        ref = (normed * (1.0 + weight.float())).to(emb.dtype).float()
        got = fused_gather_rmsnorm(ids, table, weight, eps=eps, offset=1.0).float()
        rel = (got - ref).abs().max().item() / (ref.abs().max().item() + 1e-9)
        if not (rel < 2e-2):
            raise RuntimeError(f"triton_embedding self-test failed at n={n}: rel={rel:.2e}")
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def load_fused_embedding():
    """Build and validate the fused gather+RMSNorm kernel.

    Returns:
        ``fused_gather_rmsnorm`` when torch imports, CUDA is available, the kernel
        builds, and its live-GPU self-test passes; ``None`` in every other case.

    Never raises: any failure (missing torch/triton, no CUDA, compile or self-test
    error) is reported and mapped to ``None``, meaning "keep the Liger/eager path".
    """
    kernel = None
    try:
        import torch

        if torch.cuda.is_available():
            kernel = _build_kernel()
            _self_test(kernel)
            print("[embed] fused Triton gather+RMSNorm kernel enabled (self-test passed)", flush=True)
    except Exception as e:  # pragma: no cover - defensive: any failure keeps baseline
        print(f"[embed] fused Triton embedding kernel disabled (build/self-test failed): {e}", flush=True)
        return None
    return kernel
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _resolve_text_model(model):
|
|
176
|
+
"""Return the inner Qwen3.5 text model (the one that owns ``embed_tokens`` + the
|
|
177
|
+
decoder ``layers``), or None if the structure does not match. Handles the
|
|
178
|
+
ForConditionalGeneration (``model.model.language_model``), ForCausalLM
|
|
179
|
+
(``model.model``), and bare-TextModel layouts, and unwraps a PEFT ``PeftModel``
|
|
180
|
+
(``trainer.model`` is the PEFT wrapper during training) via ``get_base_model()``."""
|
|
181
|
+
roots = [model]
|
|
182
|
+
# PEFT wrapper: the real HF model is under get_base_model() (the bare LoraModel
|
|
183
|
+
# walk does NOT expose embed_tokens/layers at the expected paths).
|
|
184
|
+
getter = getattr(model, "get_base_model", None)
|
|
185
|
+
if callable(getter):
|
|
186
|
+
with contextlib.suppress(Exception):
|
|
187
|
+
roots.append(getter())
|
|
188
|
+
for root in roots:
|
|
189
|
+
for path in (
|
|
190
|
+
("model", "language_model"),
|
|
191
|
+
("model",),
|
|
192
|
+
(),
|
|
193
|
+
):
|
|
194
|
+
obj = root
|
|
195
|
+
ok = True
|
|
196
|
+
for attr in path:
|
|
197
|
+
if not hasattr(obj, attr):
|
|
198
|
+
ok = False
|
|
199
|
+
break
|
|
200
|
+
obj = getattr(obj, attr)
|
|
201
|
+
if ok and hasattr(obj, "embed_tokens") and hasattr(obj, "layers"):
|
|
202
|
+
return obj
|
|
203
|
+
return None
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def install_qwen35_fused_embedding(model) -> bool:
    """Wire the fused gather + layer-0-RMSNorm kernel onto a loaded Qwen3.5 model — IFF
    the live-GPU self-test passes. Install-on-call: calling this IS the opt-in (the Liger
    model); there is no env flag.

    Patches ONLY layer-0's ``input_layernorm`` (the single RMSNorm fed by the embedding):
    a stashing ``embed_tokens.forward`` records the ``input_ids`` of its latest call plus a
    weak reference to (and the version counter of) the embedding tensor it returned, and
    the patched ``input_layernorm`` re-runs the gather fused with the RMSNorm so the
    layer-0 norm avoids the embedding HBM round-trip. The raw embedding (needed for the
    residual) is still produced by the normal ``embed_tokens`` gather, so the residual
    path is exact.

    SAFE NO-OP CONDITIONS (any -> return False, leave the model untouched):
      * kernel disabled / build / self-test failure;
      * model structure (text model / embed_tokens / layers / layer-0 input_layernorm)
        does not match;
      * the layer-0 RMSNorm offset/casting does not match what the kernel implements
        (Liger present but NOT patched with offset=1.0/gemma, OR an unexpected weight
        shape) — verified by a per-instance numeric check against the module's OWN
        forward before swapping it in.

    CORRECTNESS GATE: the fused path is taken only when
      * the norm input IS the exact, unmodified tensor ``embed_tokens`` just returned
        (same object, same version counter). Anything that replaced or edited the
        embedding in between (e.g. a forward hook adding noise, a multimodal feature
        merge, caller-supplied ``inputs_embeds``) falls back, because re-gathering
        ``table[ids]`` would silently discard those changes; and
      * no backward graph needs to flow through the layer-0 norm: grad is disabled, or
        the embedding table, the norm weight AND the norm input are all frozen (the
        production LoRA recipe freezes the table and the weight).
    Otherwise it falls back to the original ``input_layernorm`` forward. Never raises: any
    failure keeps the original path.

    Returns:
        True iff the patch is installed (including when it was already installed).
    """
    fn = load_fused_embedding()
    if fn is None:
        return False
    try:
        import weakref

        import torch

        tm = _resolve_text_model(model)
        if tm is None:
            print("[embed] no matching Qwen3.5 text model (embed_tokens/layers); keeping baseline", flush=True)
            return False
        embed = tm.embed_tokens
        layers = tm.layers
        if not hasattr(embed, "weight") or len(layers) == 0:
            print("[embed] embedding/layers shape mismatch; keeping baseline", flush=True)
            return False
        layer0 = layers[0]
        ln = getattr(layer0, "input_layernorm", None)
        if ln is None or not hasattr(ln, "weight"):
            print("[embed] no layer-0 input_layernorm; keeping baseline", flush=True)
            return False
        table = embed.weight
        if table.ndim != 2 or ln.weight.numel() != table.shape[1]:
            print("[embed] embed/norm dim mismatch; keeping baseline", flush=True)
            return False
        if getattr(ln, "_chalk_embed_patched", False):
            return True

        # Resolve eps + offset. Qwen3.5 norms carry offset=1.0; Liger sets ``offset`` /
        # ``variance_epsilon`` on the instance, eager has ``eps``. If Liger patched it with
        # a NON-1.0 offset or a non-gemma casting mode, we cannot match it -> bail.
        # ``is not None`` (not a truthiness ``or``) so an explicit 0.0 epsilon is honoured.
        variance_eps = getattr(ln, "variance_epsilon", None)
        eps = float(variance_eps if variance_eps is not None else getattr(ln, "eps", 1e-6))
        offset = float(getattr(ln, "offset", 1.0))
        casting = getattr(ln, "casting_mode", "gemma")
        if casting not in ("gemma", None) or offset != 1.0:
            print(
                f"[embed] layer-0 norm casting/offset unsupported (casting={casting}, offset={offset}); keeping baseline",
                flush=True,
            )
            return False

        _orig_ln_forward = ln.forward

        # Per-instance numeric check: the fused gather+RMSNorm must match the layer-0
        # norm's OWN forward (Liger or eager) applied to the real embedding, on the real
        # table/weight. This catches any semantic drift (offset/casting/eps) that the
        # attribute check above missed, on the ACTUAL module — before we swap it in.
        # A private generator keeps the global RNG stream untouched by installation.
        try:
            with torch.no_grad():
                n = 64
                dev = table.device
                gen = torch.Generator(device=dev)
                gen.manual_seed(0)
                ids_chk = torch.randint(0, table.shape[0], (n,), device=dev, generator=gen)
                emb_chk = torch.nn.functional.embedding(ids_chk, table)
                ref = _orig_ln_forward(emb_chk).float()
                got = fn(ids_chk, table, ln.weight, eps=eps, offset=offset).float()
                rel = (got - ref).abs().max().item() / (ref.abs().max().item() + 1e-9)
            if not (rel < 2e-2):
                print(f"[embed] per-instance norm check failed (rel={rel:.2e}); keeping baseline", flush=True)
                return False
        except Exception as e:
            print(f"[embed] per-instance norm check errored ({type(e).__name__}: {e}); keeping baseline", flush=True)
            return False

        def _tensor_version(t):
            """Return ``t._version``, or None when unavailable (e.g. inference-mode tensors
            carry no version counter and raise on access)."""
            try:
                return t._version
            except Exception:
                return None

        # Stash the latest embed call on the text model so the patched layer-0 norm can
        # re-gather fused: (input_ids, weakref to the returned embedding, its version).
        # The weakref avoids pinning the embedding if the norm never consumes the stash.
        # embed_tokens still produces the raw embedding (residual).
        _orig_embed_forward = embed.forward

        def _embed_forward(*args, **kwargs):
            tm._chalk_embed_stash = None  # never leave a stale stash if the gather raises
            out = _orig_embed_forward(*args, **kwargs)
            try:
                input_ids = args[0] if args else kwargs.get("input", kwargs.get("input_ids"))
                tm._chalk_embed_stash = (input_ids, weakref.ref(out), _tensor_version(out))
            except Exception:
                tm._chalk_embed_stash = None
            return out

        embed_table = table  # capture (tied weight, frozen)

        def _ln_forward(hidden_states, *args, **kwargs):
            # SINGLE-USE STASH: consume it unconditionally at entry and clear it on the
            # model, so it can only serve the ONE norm call following THIS gather.
            stash = getattr(tm, "_chalk_embed_stash", None)
            tm._chalk_embed_stash = None
            try:
                if stash is not None:
                    ids, out_ref, out_version = stash
                    # Fused path only if hidden_states IS the untouched embed_tokens output
                    # (identity + same version counter: no hook, merge, or in-place edit in
                    # between) AND no backward has to flow through this norm. Anything
                    # else -> original differentiable forward (exact, keeps gradients).
                    if (
                        isinstance(ids, torch.Tensor)
                        and ids.is_cuda
                        and out_ref() is hidden_states
                        and _tensor_version(hidden_states) == out_version
                        and hidden_states.ndim >= 2
                        and hidden_states.shape[-1] == embed_table.shape[1]
                        and hidden_states.numel() // hidden_states.shape[-1] == ids.numel()
                        and (
                            not torch.is_grad_enabled()
                            or (
                                not embed_table.requires_grad
                                and not ln.weight.requires_grad
                                and not hidden_states.requires_grad
                            )
                        )
                    ):
                        out = fn(ids, embed_table, ln.weight, eps=eps, offset=offset)
                        return out.reshape(hidden_states.shape)
            except Exception:
                pass  # any runtime hiccup -> exact original path below
            return _orig_ln_forward(hidden_states, *args, **kwargs)

        embed.forward = _embed_forward
        ln.forward = _ln_forward
        ln._chalk_embed_patched = True
        print(
            "[embed] fused Triton gather+RMSNorm installed on layer-0 input_layernorm "
            "(forward-only, frozen-embed path)",
            flush=True,
        )
        return True
    except Exception as e:  # pragma: no cover - defensive
        print(f"[embed] install failed ({type(e).__name__}: {e}); keeping baseline", flush=True)
        return False
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
if __name__ == "__main__":  # manual smoke check: build + self-test the kernel on this GPU
    loaded_kernel = load_fused_embedding()
    print("fused gather+rmsnorm loaded:", loaded_kernel is not None)
|