freesolo-flash-dev 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flash/__init__.py +29 -0
- flash/_channel.py +23 -0
- flash/_fileio.py +35 -0
- flash/_logging.py +49 -0
- flash/_update_check.py +266 -0
- flash/catalog.py +253 -0
- flash/cli/__init__.py +1 -0
- flash/cli/main/__init__.py +227 -0
- flash/cli/main/__main__.py +6 -0
- flash/cli/main/commands.py +636 -0
- flash/cli/main/envpush.py +317 -0
- flash/cli/main/render.py +599 -0
- flash/cli/main/training_doc.py +455 -0
- flash/client/__init__.py +14 -0
- flash/client/config.py +70 -0
- flash/client/http.py +372 -0
- flash/client/runtime_secrets.py +69 -0
- flash/client/specs.py +20 -0
- flash/cost/__init__.py +16 -0
- flash/cost/analytical.py +175 -0
- flash/cost/facts.py +114 -0
- flash/cost/spec.py +113 -0
- flash/cost/types.py +158 -0
- flash/engine/__init__.py +6 -0
- flash/engine/accounting.py +36 -0
- flash/engine/chalk_kernels.py +116 -0
- flash/engine/multiturn_rollout.py +780 -0
- flash/engine/recipe.py +86 -0
- flash/engine/vram.py +603 -0
- flash/engine/worker/__init__.py +2916 -0
- flash/engine/worker/__main__.py +4 -0
- flash/engine/worker/kernel_warmup.py +400 -0
- flash/engine/worker/lora.py +796 -0
- flash/engine/worker/packing.py +366 -0
- flash/engine/worker/perf.py +1048 -0
- flash/envs/__init__.py +10 -0
- flash/envs/adapter/__init__.py +883 -0
- flash/envs/adapter/rubric.py +222 -0
- flash/envs/base.py +52 -0
- flash/envs/registry.py +62 -0
- flash/mcp/__init__.py +1 -0
- flash/mcp/server.py +85 -0
- flash/providers/__init__.py +59 -0
- flash/providers/_auth.py +24 -0
- flash/providers/_http.py +230 -0
- flash/providers/_instance.py +416 -0
- flash/providers/_instance_bootstrap.py +517 -0
- flash/providers/_poll.py +311 -0
- flash/providers/allocator.py +193 -0
- flash/providers/base.py +431 -0
- flash/providers/hyperstack/__init__.py +127 -0
- flash/providers/hyperstack/api.py +522 -0
- flash/providers/hyperstack/auth.py +17 -0
- flash/providers/hyperstack/gpus.py +29 -0
- flash/providers/hyperstack/jobs/__init__.py +632 -0
- flash/providers/hyperstack/jobs/builders.py +122 -0
- flash/providers/hyperstack/preflight.py +23 -0
- flash/providers/hyperstack/pricing.py +26 -0
- flash/providers/hyperstack/train.py +25 -0
- flash/providers/lambdalabs/__init__.py +139 -0
- flash/providers/lambdalabs/api.py +261 -0
- flash/providers/lambdalabs/auth.py +18 -0
- flash/providers/lambdalabs/gpus.py +29 -0
- flash/providers/lambdalabs/jobs/__init__.py +724 -0
- flash/providers/lambdalabs/jobs/builders.py +118 -0
- flash/providers/lambdalabs/preflight.py +27 -0
- flash/providers/lambdalabs/pricing.py +51 -0
- flash/providers/lambdalabs/train.py +27 -0
- flash/providers/preflight.py +55 -0
- flash/providers/realized.py +80 -0
- flash/providers/runpod/__init__.py +130 -0
- flash/providers/runpod/api.py +186 -0
- flash/providers/runpod/auth.py +37 -0
- flash/providers/runpod/cost.py +57 -0
- flash/providers/runpod/gpus.py +46 -0
- flash/providers/runpod/jobs.py +956 -0
- flash/providers/runpod/keys.py +139 -0
- flash/providers/runpod/preflight.py +30 -0
- flash/providers/runpod/preload.py +915 -0
- flash/providers/runpod/pricing.py +18 -0
- flash/providers/runpod/slots.py +79 -0
- flash/providers/runpod/train/__init__.py +150 -0
- flash/providers/runpod/train/deps.py +395 -0
- flash/providers/runpod/train/endpoints.py +820 -0
- flash/py.typed +0 -0
- flash/runner/__init__.py +686 -0
- flash/runner/checkpoints.py +82 -0
- flash/runner/deploy.py +422 -0
- flash/runner/lifecycle.py +672 -0
- flash/schema/__init__.py +375 -0
- flash/schema/fields.py +331 -0
- flash/serve/__init__.py +1 -0
- flash/serve/deploy.py +326 -0
- flash/serve/pricing.py +60 -0
- flash/server/__init__.py +1 -0
- flash/server/__main__.py +20 -0
- flash/server/app.py +961 -0
- flash/server/auth.py +263 -0
- flash/server/billing.py +124 -0
- flash/server/checkpoints.py +110 -0
- flash/server/db.py +160 -0
- flash/server/environment_registry.py +102 -0
- flash/server/envs.py +360 -0
- flash/server/reconcile.py +163 -0
- flash/server/run_registry.py +150 -0
- flash/spec.py +333 -0
- freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
- freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
- freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
- freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
- freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,366 @@
|
|
|
1
|
+
"""True token packing with a block-diagonal SDPA attention mask.
|
|
2
|
+
|
|
3
|
+
Concatenate short SFT examples into ``max_length`` blocks and feed the trainer a 4D
|
|
4
|
+
**block-diagonal causal** attention mask so packed examples never attend across their
|
|
5
|
+
boundaries. Crucially this is boundary-correct under PLAIN SDPA — it needs neither
|
|
6
|
+
``flash_attn`` (no prebuilt wheel for torch 2.10 / no sm120 kernel) nor ``flex_attention``
|
|
7
|
+
(unsupported on the Qwen3.5/3.6 arch). It is exactly what lets packing run on flash's DEFAULT
|
|
8
|
+
RTX 5090 (sm120), where the FA2/FA3 varlen path the worker otherwise relies on is unavailable,
|
|
9
|
+
and on any arch whose flash-attn build did not land.
|
|
10
|
+
|
|
11
|
+
Why packing is a win: instruction targets are far shorter than ``max_seq_len``, so an unpacked
|
|
12
|
+
batch spends most of its FLOPs on padding. Concatenating examples into full blocks removes that
|
|
13
|
+
waste (PR #174 measured 4.4-10.7x on the FA2 path; the SDPA-mask path keeps the same packing win
|
|
14
|
+
minus the block-sparse-attention speedup FA2 varlen gives, so ~1.5-2x in practice). The dense
|
|
15
|
+
[T,T] mask is O(T^2) memory, but attention is a small fraction of total FLOPs for these models,
|
|
16
|
+
so the masked-attention overhead is dwarfed by the pad-removal win.
|
|
17
|
+
|
|
18
|
+
GATING — pure full-attention only. A 4D mask isolates examples only in layers that READ the
|
|
19
|
+
attention mask. Hybrid GatedDeltaNet models (Qwen3.5/3.6) interleave linear-attention layers
|
|
20
|
+
whose recurrence + short causal conv1d carry state ACROSS example boundaries regardless of any
|
|
21
|
+
mask — their boundaries reset only via the ``fla`` kernel's ``cu_seq_lens_q/k`` and
|
|
22
|
+
``causal_conv1d``'s ``seq_idx``. So a pure full-attention arch (``model_is_pure_attention``) packs with the 4D mask
|
|
23
|
+
alone, while a GDN hybrid ALSO needs the varlen path: ``BlockDiagonalCollator(emit_varlen=True)``
|
|
24
|
+
emits ``cu_seq_lens_q/k`` + ``seq_idx``, gated on both kernels being importable + arch-correct
|
|
25
|
+
(``gdn_packing_available`` + ``model_is_gdn_hybrid``). Without those kernels the hybrid tier stays
|
|
26
|
+
unpacked.
|
|
27
|
+
|
|
28
|
+
This is a leaf module: torch is imported lazily inside the collator so it stays CPU-importable
|
|
29
|
+
(the arch probe needs only ``transformers.AutoConfig``). ``flash.engine.worker`` re-exports the
|
|
30
|
+
public names.
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
from __future__ import annotations
|
|
34
|
+
|
|
35
|
+
from dataclasses import dataclass
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _text_config(cfg):
    """Return the decoder/LM sub-config of ``cfg``.

    Multimodal checkpoints (e.g. Qwen3.5-VL) keep the language-model dimensions
    under a nested ``text_config`` attribute. When that attribute exists and is
    truthy it is returned, so layer-type probes inspect the real decoder;
    otherwise ``cfg`` itself is handed back unchanged.
    """
    nested = getattr(cfg, "text_config", None)
    if nested:
        return nested
    return cfg
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def model_is_pure_attention(model_id: str) -> bool:
    """Report whether every decoder layer of ``model_id`` is full softmax attention.

    Only then does a 4D block-diagonal mask fully isolate packed examples under
    SDPA. The probe reads the model config only (no weights are loaded here) and
    answers ``False`` whenever anything goes wrong, so the caller never packs a
    model it could not verify.

    The decision is made in this order:

    1. If the (text) config exposes a non-empty ``layer_types``, the answer is
       ``True`` only when every entry equals ``"full_attention"``. This rejects
       both ``"linear_attention"`` (GatedDeltaNet hybrids) and
       ``"sliding_attention"`` layers.
    2. Otherwise, any truthy linear-attention dimension
       (``linear_num_key_heads`` / ``linear_key_head_dim`` /
       ``linear_conv_kernel_dim``) marks a hybrid and yields ``False``.
    3. Otherwise, a configured ``sliding_window`` yields ``False`` unless the
       config sets ``use_sliding_window`` to a falsy value. A missing
       ``use_sliding_window`` flag is treated as "window active".

    Args:
        model_id: Identifier handed to ``AutoConfig.from_pretrained``.

    Returns:
        ``True`` if the model is safe to pack with the 4D mask alone, else
        ``False`` (including on any exception, which is printed).
    """
    try:
        from transformers import AutoConfig

        cfg = _text_config(AutoConfig.from_pretrained(model_id, trust_remote_code=True))

        per_layer = getattr(cfg, "layer_types", None)
        if per_layer:
            # Explicit per-layer typing is authoritative: one non-full layer disqualifies.
            for kind in per_layer:
                if kind != "full_attention":
                    return False
            return True

        # No per-layer types: a hybrid arch may omit layer_types yet still declare
        # its linear-attention (DeltaNet) block through these dimensions.
        linear_dims = ("linear_num_key_heads", "linear_key_head_dim", "linear_conv_kernel_dim")
        for dim_name in linear_dims:
            if getattr(cfg, dim_name, None):
                return False

        # Globally sliding-window model: it builds its own local-attention mask, which a
        # pre-built full block-diagonal mask would bypass. Exclude it when a window is
        # configured AND active; a config without the use_sliding_window flag is assumed
        # to have its window active.
        window = getattr(cfg, "sliding_window", None)
        if not window:
            return True
        return not getattr(cfg, "use_sliding_window", True)
    except Exception as e:  # network/parse/arch failure -> do NOT pack (boundary-safe default)
        print(f"[pack] pure-attention probe failed for {model_id!r} (treating as NOT pure): {e}")
        return False
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def model_is_gdn_hybrid(model_id: str) -> bool:
    """Report whether ``model_id`` is a GatedDeltaNet hybrid (e.g. Qwen3.5/3.6).

    A hybrid interleaves ``"linear_attention"`` layers with full attention, so it
    needs the varlen path (cu_seqlens + seq_idx) to pack boundary-correctly; a 4D
    mask alone cannot reset the recurrent/conv state. Sliding-window models are
    also not "pure", but they are not reported here.

    The probe reads the model config only and returns ``False`` on any exception
    (which is printed).

    Args:
        model_id: Identifier handed to ``AutoConfig.from_pretrained``.

    Returns:
        ``True`` when ``layer_types`` contains ``"linear_attention"`` or when any
        linear-attention dimension is set on the (text) config; else ``False``.
    """
    try:
        from transformers import AutoConfig

        cfg = _text_config(AutoConfig.from_pretrained(model_id, trust_remote_code=True))

        per_layer = getattr(cfg, "layer_types", None)
        if per_layer:
            for kind in per_layer:
                if kind == "linear_attention":
                    return True

        # No linear layer listed (or no layer_types at all): declared linear-attn
        # dimensions still identify a GDN hybrid.
        for dim_name in ("linear_num_key_heads", "linear_key_head_dim", "linear_conv_kernel_dim"):
            if getattr(cfg, dim_name, None):
                return True
        return False
    except Exception as e:
        print(f"[pack] gdn-hybrid probe failed for {model_id!r} (treating as NOT gdn): {e}")
        return False
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _gdn_forward_threads_reset_kwargs(model_id: str | None) -> bool:
    """Check that this model's GatedDeltaNet ``forward`` mentions both reset kwargs.

    Different GDN families live in different modeling modules, so the module is
    resolved from the model's own ``model_type`` rather than hard-coded:
    ``transformers.models.<model_type>.modeling_<model_type>``. The first class in
    that module whose name ends with ``GatedDeltaNet`` is taken, and the source
    text of its ``forward`` is searched for ``cu_seq_lens_q`` and ``seq_idx``.

    Args:
        model_id: Identifier handed to ``AutoConfig.from_pretrained``. When falsy,
            the ``qwen3_5`` architecture is probed instead.

    Returns:
        ``True`` only if both names occur in the ``forward`` source. ``False``
        when no such class exists or on any exception (import failure, config
        load failure, source unavailable, ...).
    """
    try:
        import importlib
        import inspect

        arch = "qwen3_5"
        if model_id:
            from transformers import AutoConfig

            loaded = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
            probed = getattr(loaded, "model_type", None)
            if probed:
                arch = probed

        modeling = importlib.import_module(f"transformers.models.{arch}.modeling_{arch}")
        for attr_name, candidate in vars(modeling).items():
            if isinstance(candidate, type) and attr_name.endswith("GatedDeltaNet"):
                # This is a textual check on the source, not a signature inspection.
                forward_src = inspect.getsource(candidate.forward)
                return "cu_seq_lens_q" in forward_src and "seq_idx" in forward_src
        return False
    except Exception:
        return False
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def gdn_packing_available(model_id: str | None = None) -> bool:
    """Report whether a GatedDeltaNet hybrid can be packed boundary-correctly here.

    Packing such a model needs both varlen kernels: ``flash-linear-attention``
    (resets the DeltaNet recurrence via ``cu_seq_lens_q/k``) and ``causal_conv1d``
    (resets the short causal conv via ``seq_idx``). Without both, packed examples
    would contaminate each other across boundaries, so packing must stay off.

    The checks run in this order, and the first failure yields ``False``:

    1. transformers' availability helpers report both packages as available.
    2. ``causal_conv1d`` is actually imported, so a built-but-broken wheel fails
       here rather than at model load.
    3. The model's GatedDeltaNet ``forward`` threads ``cu_seq_lens_q`` and
       ``seq_idx`` (see ``_gdn_forward_threads_reset_kwargs``); otherwise the
       collator's reset kwargs would be silently ignored.
    4. When CUDA is available, a tiny ``causal_conv1d_fn`` call is run and
       synchronized. A wheel compiled without the present device's arch imports
       fine but fails on first launch; this surfaces that now instead of
       mid-training.

    Args:
        model_id: Optional model identifier forwarded to the forward-signature
            probe in step 3.

    Returns:
        ``True`` when every check passes, ``False`` otherwise (any exception is
        swallowed and treated as "not available").
    """
    try:
        import importlib

        from transformers.utils.import_utils import (
            is_causal_conv1d_available,
            is_flash_linear_attention_available,
        )

        if not is_flash_linear_attention_available():
            return False
        if not is_causal_conv1d_available():
            return False

        # (a) a real import: raises on a built-but-broken wheel
        importlib.import_module("causal_conv1d")

        # (b) version/API gate, evaluated for the ACTUAL architecture
        if not _gdn_forward_threads_reset_kwargs(model_id):
            return False

        # (c) smoke-run the conv kernel on the live GPU, if there is one
        import torch

        if not torch.cuda.is_available():
            return True

        from causal_conv1d import causal_conv1d_fn

        probe_x = torch.zeros(1, 4, 8, device="cuda", dtype=torch.bfloat16)
        probe_w = torch.zeros(4, 3, device="cuda", dtype=torch.bfloat16)
        causal_conv1d_fn(probe_x, probe_w)
        torch.cuda.synchronize()  # force any asynchronous launch error to surface here
        return True
    except Exception:
        return False
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def pack_token_ids(sequences: list[list[int]], max_length: int) -> list[dict]:
    """Bin-pack tokenized examples into blocks of at most ``max_length`` tokens.

    Examples are never split. The strategy is first-fit-decreasing: examples are
    visited longest first (ties keep their input order) and each one goes into the
    earliest-opened block that still has room, or into a fresh block if none does.

    Empty sequences are dropped. A sequence longer than ``max_length`` is cut to
    its first ``max_length`` tokens and therefore fills a block on its own.

    Args:
        sequences: Token-id lists, one per example.
        max_length: Block capacity in tokens; must be positive.

    Returns:
        One dict per block, ``{"input_ids": [...], "seq_lengths": [l1, l2, ...]}``,
        where ``sum(seq_lengths) == len(input_ids)``. Blocks appear in the order
        they were opened.

    Raises:
        ValueError: If ``max_length`` is not positive.
    """
    if max_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length}")

    # Drop empties, clip over-long rows, then visit longest first. list.sort is
    # stable, so equal-length examples keep their original relative order.
    clipped = [seq[:max_length] for seq in sequences if seq]
    clipped.sort(key=len, reverse=True)

    # Parallel per-block state: token ids, example lengths, remaining capacity.
    block_ids: list[list[int]] = []
    block_lens: list[list[int]] = []
    block_room: list[int] = []

    for seq in clipped:
        need = len(seq)
        slot = next((j for j, room in enumerate(block_room) if room >= need), None)
        if slot is None:
            # No open block fits -> start a new one.
            block_ids.append(list(seq))
            block_lens.append([need])
            block_room.append(max_length - need)
        else:
            block_ids[slot].extend(seq)
            block_lens[slot].append(need)
            block_room[slot] -= need

    return [
        {"input_ids": ids, "seq_lengths": lens}
        for ids, lens in zip(block_ids, block_lens)
    ]
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def packing_efficiency(rows: list[dict], max_length: int) -> float:
    """Return the fraction of block capacity occupied by real tokens.

    Capacity is ``len(rows) * max_length``; real tokens are the sum of every
    row's ``seq_lengths``. A value of 1.0 means no padding. Diagnostic only.

    Args:
        rows: Packed rows, each a dict with a ``"seq_lengths"`` list.
        max_length: Block capacity in tokens.

    Returns:
        The fill ratio, or ``0.0`` when ``rows`` is empty or ``max_length`` is
        not positive.
    """
    if not rows or max_length <= 0:
        return 0.0

    filled = 0
    for row in rows:
        filled += sum(row["seq_lengths"])
    capacity = len(rows) * max_length
    return filled / capacity
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def tokenize_for_packing(texts: list[str], tokenizer, max_length: int) -> list[list[int]]:
    """Tokenize chat-templated text rows so they can be packed.

    The preparation is meant to mirror TRL's non-packed SFT prep, so a packed run
    trains on the same token sequences as the unpacked path:

    * The tokenizer's EOS token is appended to every row that does not already
      end with it, so the model still learns the final stop token. If the
      tokenizer has no EOS token, rows are left as they are.
    * The tokenizer is called with its default ``add_special_tokens`` behaviour
      (no override is passed), so a tokenizer that prepends BOS by default keeps
      doing so here.
    * Rows are truncated to ``max_length`` by the tokenizer, in a single batched
      call.

    Args:
        texts: Already chat-templated text rows.
        tokenizer: Callable tokenizer exposing ``eos_token`` and accepting
            ``(rows, truncation=..., max_length=...)``.
        max_length: Maximum number of tokens kept per row.

    Returns:
        The ``"input_ids"`` entry of the tokenizer output, one id list per row.
    """
    eos = tokenizer.eos_token or ""

    prepared = []
    for text in texts:
        if eos and text.endswith(eos):
            prepared.append(text)
        else:
            prepared.append(text + eos)

    # No add_special_tokens override: keep the tokenizer default (TRL parity).
    encoded = tokenizer(prepared, truncation=True, max_length=max_length)
    return encoded["input_ids"]
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
# Process-local cache of the lower-triangular causal matrix. The collator runs on every
# batch, and rebuilding torch.tril(torch.ones(T, T)) each time is a non-trivial CPU
# allocation at T=2048+. Only the LARGEST matrix seen so far is kept (under key "m") and
# sliced for smaller T; it is treated as read-only. Dataloader workers are separate
# processes, so each holds its own copy -- no cross-thread race.
_CAUSAL_TRIL: dict = {}


def _causal_lower_triangular(total: int, torch):
    """Return a ``[total, total]`` boolean lower-triangular (causal) matrix.

    The matrix is served from the module-level cache: it is rebuilt only when
    nothing is cached yet or the cached matrix is smaller than ``total``. The
    result is a slice of the cached tensor, so callers must not modify it in
    place.

    Args:
        total: Side length of the requested matrix.
        torch: The already-imported ``torch`` module (passed in so this module
            does not import torch at load time).

    Returns:
        A bool tensor of shape ``[total, total]`` with ``True`` on and below the
        diagonal.
    """
    tril = _CAUSAL_TRIL.get("m")
    too_small = tril is None or tril.shape[0] < total
    if too_small:
        all_true = torch.ones(total, total, dtype=torch.bool)
        tril = torch.tril(all_true)
        _CAUSAL_TRIL["m"] = tril
    return tril[:total, :total]
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
@dataclass
class BlockDiagonalCollator:
    """Turn pre-packed rows (from :func:`pack_token_ids`) into a training batch.

    The batch carries a 4D **block-diagonal causal** attention mask, so packed
    examples cannot attend across their boundaries under plain SDPA -- no
    flash-attn and no flex_attention required.

    Keys emitted per batch:

    * ``input_ids`` ``[B, T]`` -- right-padded with ``pad_token_id``.
    * ``attention_mask`` ``[B, 1, T, T]`` bool -- ``True`` means the query may
      attend the key: same example AND key position <= query position. Padding
      positions share one segment of their own, so no query row is entirely
      ``False``, and real tokens never attend padding keys.
    * ``position_ids`` ``[B, T]`` -- restart at 0 at the start of each example.
    * ``labels`` ``[B, T]`` -- a copy of ``input_ids`` in which the first token
      of every example and all padding are set to ``label_pad_token_id``, so the
      next-token pair that straddles an example boundary is never scored.

    ``pad_to_multiple_of`` rounds ``T`` up to a multiple of that value (values
    ``<= 1`` or falsy disable rounding); the added positions are padding.

    ``emit_varlen`` (for GatedDeltaNet hybrids) additionally emits
    ``cu_seq_lens_q`` / ``cu_seq_lens_k``, ``max_length_q`` / ``max_length_k`` and
    ``seq_idx`` so the linear-attention layers can reset their state at example
    boundaries; the 4D mask only covers the full-attention layers. This mode
    requires exactly one row per batch and does not round ``T`` up, so
    ``pad_to_multiple_of`` is ignored there.
    """

    # Token id written into padding positions of input_ids.
    pad_token_id: int
    # Label value for positions that must not be scored.
    label_pad_token_id: int = -100
    # Round T up to a multiple of this (ignored when emit_varlen is set).
    pad_to_multiple_of: int = 8
    # Also emit cu_seq_lens_* / max_length_* / seq_idx for the varlen kernels.
    emit_varlen: bool = False

    def __call__(self, features: list[dict]) -> dict:
        """Collate ``features`` into one batch dict of tensors.

        Args:
            features: Packed rows, each with ``"input_ids"`` and
                ``"seq_lengths"`` where ``sum(seq_lengths) == len(input_ids)``.

        Returns:
            Dict with ``input_ids``, ``attention_mask``, ``position_ids`` and
            ``labels``; plus the varlen keys when ``emit_varlen`` is set.

        Raises:
            ValueError: If ``emit_varlen`` is set and the batch does not hold
                exactly one row, or if a row breaks the length invariant.
        """
        import torch

        token_rows = [list(feat["input_ids"]) for feat in features]
        segment_rows = [list(feat["seq_lengths"]) for feat in features]
        batch_size = len(token_rows)

        if self.emit_varlen and batch_size != 1:
            raise ValueError("emit_varlen packing requires per_device_train_batch_size == 1")

        # Fail fast on a malformed row: everything below (mask, labels, cu_seqlens)
        # assumes the segment lengths add up to the row length.
        for tokens, lengths in zip(token_rows, segment_rows, strict=True):
            declared = sum(lengths)
            if declared != len(tokens):
                raise ValueError(
                    f"packed row invariant broken: sum(seq_lengths)={declared} != "
                    f"len(input_ids)={len(tokens)} (rows must come from pack_token_ids)"
                )

        longest = max(map(len, token_rows), default=0)
        multiple = self.pad_to_multiple_of
        # The varlen path never pads: cu_seqlens has to span the whole sequence.
        if self.emit_varlen or not (multiple and multiple > 1):
            total = longest
        else:
            total = ((longest + multiple - 1) // multiple) * multiple
        total = max(total, 1)

        shape = (batch_size, total)
        input_ids = torch.full(shape, self.pad_token_id, dtype=torch.long)
        position_ids = torch.zeros(shape, dtype=torch.long)
        # Segment id per token: 0..k-1 for the k examples in a row, -1 for trailing pad.
        seg = torch.full(shape, -1, dtype=torch.long)

        for row, (tokens, lengths) in enumerate(zip(token_rows, segment_rows, strict=True)):
            input_ids[row, : len(tokens)] = torch.tensor(tokens, dtype=torch.long)
            cursor = 0
            for example, length in enumerate(lengths):
                stop = cursor + length
                position_ids[row, cursor:stop] = torch.arange(length)
                seg[row, cursor:stop] = example
                cursor = stop

        # Block-diagonal causal mask:
        #   same example: seg[q] == seg[k]  (pad shares segment -1, so pad attends pad and
        #                 no row is all-False; real tokens never match -1)
        #   causal:       k <= q
        same_example = seg[:, :, None] == seg[:, None, :]  # [B, T, T]
        causal = _causal_lower_triangular(total, torch)  # cached and sliced
        attention_mask = (same_example & causal)[:, None]  # [B, 1, T, T]

        # Ignore padding and the first token of each example. position_ids == 0 marks
        # every example start (padding is 0 as well, and is ignored anyway).
        labels = input_ids.clone()
        labels[(seg < 0) | (position_ids == 0)] = self.label_pad_token_id

        batch = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "labels": labels,
        }

        if self.emit_varlen:
            # Exactly one row here (checked above). cu_seqlens holds the cumulative
            # example boundaries of that row, and seq_idx is the per-token segment id.
            lengths = segment_rows[0]
            cu = torch.zeros(len(lengths) + 1, dtype=torch.int32)
            cu[1:] = torch.cumsum(torch.tensor(lengths, dtype=torch.int32), dim=0)
            batch["cu_seq_lens_q"] = cu
            batch["cu_seq_lens_k"] = cu
            widest = int(max(lengths))
            batch["max_length_q"] = widest
            batch["max_length_k"] = widest
            batch["seq_idx"] = seg.to(torch.int32)  # [1, T]

        return batch
|