freesolo-flash-dev 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111) hide show
  1. flash/__init__.py +29 -0
  2. flash/_channel.py +23 -0
  3. flash/_fileio.py +35 -0
  4. flash/_logging.py +49 -0
  5. flash/_update_check.py +266 -0
  6. flash/catalog.py +253 -0
  7. flash/cli/__init__.py +1 -0
  8. flash/cli/main/__init__.py +227 -0
  9. flash/cli/main/__main__.py +6 -0
  10. flash/cli/main/commands.py +636 -0
  11. flash/cli/main/envpush.py +317 -0
  12. flash/cli/main/render.py +599 -0
  13. flash/cli/main/training_doc.py +455 -0
  14. flash/client/__init__.py +14 -0
  15. flash/client/config.py +70 -0
  16. flash/client/http.py +372 -0
  17. flash/client/runtime_secrets.py +69 -0
  18. flash/client/specs.py +20 -0
  19. flash/cost/__init__.py +16 -0
  20. flash/cost/analytical.py +175 -0
  21. flash/cost/facts.py +114 -0
  22. flash/cost/spec.py +113 -0
  23. flash/cost/types.py +158 -0
  24. flash/engine/__init__.py +6 -0
  25. flash/engine/accounting.py +36 -0
  26. flash/engine/chalk_kernels.py +116 -0
  27. flash/engine/multiturn_rollout.py +780 -0
  28. flash/engine/recipe.py +86 -0
  29. flash/engine/vram.py +603 -0
  30. flash/engine/worker/__init__.py +2916 -0
  31. flash/engine/worker/__main__.py +4 -0
  32. flash/engine/worker/kernel_warmup.py +400 -0
  33. flash/engine/worker/lora.py +796 -0
  34. flash/engine/worker/packing.py +366 -0
  35. flash/engine/worker/perf.py +1048 -0
  36. flash/envs/__init__.py +10 -0
  37. flash/envs/adapter/__init__.py +883 -0
  38. flash/envs/adapter/rubric.py +222 -0
  39. flash/envs/base.py +52 -0
  40. flash/envs/registry.py +62 -0
  41. flash/mcp/__init__.py +1 -0
  42. flash/mcp/server.py +85 -0
  43. flash/providers/__init__.py +59 -0
  44. flash/providers/_auth.py +24 -0
  45. flash/providers/_http.py +230 -0
  46. flash/providers/_instance.py +416 -0
  47. flash/providers/_instance_bootstrap.py +517 -0
  48. flash/providers/_poll.py +311 -0
  49. flash/providers/allocator.py +193 -0
  50. flash/providers/base.py +431 -0
  51. flash/providers/hyperstack/__init__.py +127 -0
  52. flash/providers/hyperstack/api.py +522 -0
  53. flash/providers/hyperstack/auth.py +17 -0
  54. flash/providers/hyperstack/gpus.py +29 -0
  55. flash/providers/hyperstack/jobs/__init__.py +632 -0
  56. flash/providers/hyperstack/jobs/builders.py +122 -0
  57. flash/providers/hyperstack/preflight.py +23 -0
  58. flash/providers/hyperstack/pricing.py +26 -0
  59. flash/providers/hyperstack/train.py +25 -0
  60. flash/providers/lambdalabs/__init__.py +139 -0
  61. flash/providers/lambdalabs/api.py +261 -0
  62. flash/providers/lambdalabs/auth.py +18 -0
  63. flash/providers/lambdalabs/gpus.py +29 -0
  64. flash/providers/lambdalabs/jobs/__init__.py +724 -0
  65. flash/providers/lambdalabs/jobs/builders.py +118 -0
  66. flash/providers/lambdalabs/preflight.py +27 -0
  67. flash/providers/lambdalabs/pricing.py +51 -0
  68. flash/providers/lambdalabs/train.py +27 -0
  69. flash/providers/preflight.py +55 -0
  70. flash/providers/realized.py +80 -0
  71. flash/providers/runpod/__init__.py +130 -0
  72. flash/providers/runpod/api.py +186 -0
  73. flash/providers/runpod/auth.py +37 -0
  74. flash/providers/runpod/cost.py +57 -0
  75. flash/providers/runpod/gpus.py +46 -0
  76. flash/providers/runpod/jobs.py +956 -0
  77. flash/providers/runpod/keys.py +139 -0
  78. flash/providers/runpod/preflight.py +30 -0
  79. flash/providers/runpod/preload.py +915 -0
  80. flash/providers/runpod/pricing.py +18 -0
  81. flash/providers/runpod/slots.py +79 -0
  82. flash/providers/runpod/train/__init__.py +150 -0
  83. flash/providers/runpod/train/deps.py +395 -0
  84. flash/providers/runpod/train/endpoints.py +820 -0
  85. flash/py.typed +0 -0
  86. flash/runner/__init__.py +686 -0
  87. flash/runner/checkpoints.py +82 -0
  88. flash/runner/deploy.py +422 -0
  89. flash/runner/lifecycle.py +672 -0
  90. flash/schema/__init__.py +375 -0
  91. flash/schema/fields.py +331 -0
  92. flash/serve/__init__.py +1 -0
  93. flash/serve/deploy.py +326 -0
  94. flash/serve/pricing.py +60 -0
  95. flash/server/__init__.py +1 -0
  96. flash/server/__main__.py +20 -0
  97. flash/server/app.py +961 -0
  98. flash/server/auth.py +263 -0
  99. flash/server/billing.py +124 -0
  100. flash/server/checkpoints.py +110 -0
  101. flash/server/db.py +160 -0
  102. flash/server/environment_registry.py +102 -0
  103. flash/server/envs.py +360 -0
  104. flash/server/reconcile.py +163 -0
  105. flash/server/run_registry.py +150 -0
  106. flash/spec.py +333 -0
  107. freesolo_flash_dev-0.2.25.dist-info/METADATA +192 -0
  108. freesolo_flash_dev-0.2.25.dist-info/RECORD +111 -0
  109. freesolo_flash_dev-0.2.25.dist-info/WHEEL +4 -0
  110. freesolo_flash_dev-0.2.25.dist-info/entry_points.txt +3 -0
  111. freesolo_flash_dev-0.2.25.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,366 @@
1
+ """True token packing with a block-diagonal SDPA attention mask.
2
+
3
+ Concatenate short SFT examples into ``max_length`` blocks and feed the trainer a 4D
4
+ **block-diagonal causal** attention mask so packed examples never attend across their
5
+ boundaries. Crucially this is boundary-correct under PLAIN SDPA — it needs neither
6
+ ``flash_attn`` (no prebuilt wheel for torch 2.10 / no sm120 kernel) nor ``flex_attention``
7
+ (unsupported on the Qwen3.5/3.6 arch). It is exactly what lets packing run on flash's DEFAULT
8
+ RTX 5090 (sm120), where the FA2/FA3 varlen path the worker otherwise relies on is unavailable,
9
+ and on any arch whose flash-attn build did not land.
10
+
11
+ Why packing is a win: instruction targets are far shorter than ``max_seq_len``, so an unpacked
12
+ batch spends most of its FLOPs on padding. Concatenating examples into full blocks removes that
13
+ waste (PR #174 measured 4.4-10.7x on the FA2 path; the SDPA-mask path keeps the same packing win
14
+ minus the block-sparse-attention speedup FA2 varlen gives, so ~1.5-2x in practice). The dense
15
+ [T,T] mask is O(T^2) memory, but attention is a small fraction of total FLOPs for these models,
16
+ so the masked-attention overhead is dwarfed by the pad-removal win.
17
+
18
+ GATING — pure full-attention only. A 4D mask isolates examples only in layers that READ the
19
+ attention mask. Hybrid GatedDeltaNet models (Qwen3.5/3.6) interleave linear-attention layers
20
+ whose recurrence + short causal conv1d carry state ACROSS example boundaries regardless of any
21
+ mask — their boundaries reset only via the ``fla`` kernel's ``cu_seq_lens_q/k`` and
22
+ ``causal_conv1d``'s ``seq_idx``. So a pure full-attention arch (``model_is_pure_attention``) packs with the 4D mask
23
+ alone, while a GDN hybrid ALSO needs the varlen path: ``BlockDiagonalCollator(emit_varlen=True)``
24
+ emits ``cu_seq_lens_q/k`` + ``seq_idx``, gated on both kernels being importable + arch-correct
25
+ (``gdn_packing_available`` + ``model_is_gdn_hybrid``). Without those kernels the hybrid tier stays
26
+ unpacked.
27
+
28
+ This is a leaf module: torch is imported lazily inside the collator so it stays CPU-importable
29
+ (the arch probe needs only ``transformers.AutoConfig``). ``flash.engine.worker`` re-exports the
30
+ public names.
31
+ """
32
+
33
+ from __future__ import annotations
34
+
35
+ from dataclasses import dataclass
36
+
37
+
38
+ def _text_config(cfg):
39
+ """The decoder/LM sub-config. Multimodal checkpoints (Qwen3.5-VL) keep the LM dims under
40
+ ``text_config``; read it when present so the layer-type probe sees the real decoder."""
41
+ return getattr(cfg, "text_config", None) or cfg
42
+
43
+
44
def model_is_pure_attention(model_id: str) -> bool:
    """Report whether every decoder layer of ``model_id`` is full softmax attention.

    Only then does a 4D block-diagonal mask fully isolate packed examples under SDPA. The
    probe loads the config only (via ``transformers.AutoConfig``) and answers ``False`` on
    any failure, so the caller falls back to the boundary-safe unpacked path.

    Decision order:
      1. If the (text) config lists per-layer ``layer_types``, the answer is ``True`` only
         when every entry equals ``"full_attention"``. This rejects ``"linear_attention"``
         (GatedDeltaNet hybrids, whose recurrence/conv state a mask cannot reset) and
         ``"sliding_attention"`` (a pre-built mask would bypass the model's own window).
      2. Without ``layer_types``, any truthy linear-attention dimension
         (``linear_num_key_heads`` / ``linear_key_head_dim`` / ``linear_conv_kernel_dim``)
         rejects the model.
      3. Otherwise a configured ``sliding_window`` rejects the model unless the config
         sets ``use_sliding_window`` to a falsy value; a config without that flag is
         treated as having its window active.

    Args:
        model_id: Identifier passed to ``AutoConfig.from_pretrained``.

    Returns:
        ``True`` when packing with the 4D mask alone is considered safe, else ``False``.
    """
    linear_attn_dims = ("linear_num_key_heads", "linear_key_head_dim", "linear_conv_kernel_dim")
    try:
        from transformers import AutoConfig

        cfg = _text_config(AutoConfig.from_pretrained(model_id, trust_remote_code=True))

        per_layer = getattr(cfg, "layer_types", None)
        if per_layer:
            # Explicit per-layer typing is authoritative: anything but full attention fails.
            return all(kind == "full_attention" for kind in per_layer)

        # A hybrid arch may omit layer_types but still declares its linear-attention dims.
        if any(getattr(cfg, name, None) for name in linear_attn_dims):
            return False

        # Globally sliding-window configs build their own local mask; a full block-diagonal
        # mask would replace it with global attention. The window only counts when active.
        window = getattr(cfg, "sliding_window", None)
        if not window:
            return True
        return not getattr(cfg, "use_sliding_window", True)
    except Exception as e:  # network/parse/arch failure -> do NOT pack (boundary-safe default)
        print(f"[pack] pure-attention probe failed for {model_id!r} (treating as NOT pure): {e}")
        return False
84
+
85
+
86
def model_is_gdn_hybrid(model_id: str) -> bool:
    """Report whether ``model_id`` is a GatedDeltaNet hybrid (linear + full attention).

    Such models need the varlen path (cu_seqlens + seq_idx) to pack boundary-correctly,
    because a 4D mask alone cannot reset their recurrent/conv state. The probe reads the
    config only and answers ``False`` on any failure.

    A model counts as a hybrid when its (text) config either lists a
    ``"linear_attention"`` entry in ``layer_types`` or declares any truthy
    linear-attention dimension (``linear_num_key_heads`` / ``linear_key_head_dim`` /
    ``linear_conv_kernel_dim``). Sliding-window models are not detected here.

    Args:
        model_id: Identifier passed to ``AutoConfig.from_pretrained``.

    Returns:
        ``True`` for a detected GDN hybrid, otherwise ``False``.
    """
    linear_attn_dims = ("linear_num_key_heads", "linear_key_head_dim", "linear_conv_kernel_dim")
    try:
        from transformers import AutoConfig

        cfg = _text_config(AutoConfig.from_pretrained(model_id, trust_remote_code=True))

        for kind in getattr(cfg, "layer_types", None) or ():
            if kind == "linear_attention":
                return True

        # No linear layer listed (or no layer_types at all): fall back to the declared dims.
        for name in linear_attn_dims:
            if getattr(cfg, name, None):
                return True
        return False
    except Exception as e:
        print(f"[pack] gdn-hybrid probe failed for {model_id!r} (treating as NOT gdn): {e}")
        return False
107
+
108
+
109
+ def _gdn_forward_threads_reset_kwargs(model_id: str | None) -> bool:
110
+ """Does THIS model's GatedDeltaNet forward actually thread cu_seq_lens_q AND seq_idx? Different GDN
111
+ families live in different modeling modules (qwen3_5 -> modeling_qwen3_5.Qwen3_5GatedDeltaNet, a
112
+ future qwen3_6 -> modeling_qwen3_6.Qwen3_6GatedDeltaNet), so resolve the ACTUAL arch from the
113
+ model's config and probe ITS DeltaNet class — a hardcoded qwen3_5 probe would wrongly pass for an
114
+ arch that drops the kwargs (or whose layer hard-codes seq_idx=None on an older transformers).
115
+ Falls back to qwen3_5 when no model_id is given. Safe-False on any failure."""
116
+ try:
117
+ import importlib
118
+ import inspect
119
+
120
+ model_type = "qwen3_5"
121
+ if model_id:
122
+ from transformers import AutoConfig
123
+
124
+ cfg = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
125
+ model_type = getattr(cfg, "model_type", None) or model_type
126
+ mod = importlib.import_module(f"transformers.models.{model_type}.modeling_{model_type}")
127
+ gdn_cls = next(
128
+ (c for n, c in vars(mod).items()
129
+ if isinstance(c, type) and n.endswith("GatedDeltaNet")),
130
+ None,
131
+ )
132
+ if gdn_cls is None:
133
+ return False
134
+ fwd = inspect.getsource(gdn_cls.forward)
135
+ return ("cu_seq_lens_q" in fwd) and ("seq_idx" in fwd)
136
+ except Exception:
137
+ return False
138
+
139
+
140
def gdn_packing_available(model_id: str | None = None) -> bool:
    """Report whether a GatedDeltaNet hybrid can be packed boundary-correctly here.

    Packing such a model needs BOTH varlen kernels: ``flash-linear-attention`` (resets the
    DeltaNet recurrence via ``cu_seq_lens_q/k``) and ``causal_conv1d`` (resets the short
    causal conv via ``seq_idx``). Without both, packed examples would contaminate each
    other across boundaries, so packing must stay off.

    After the two transformers availability probes pass, three further guards run in order:
      (a) ``causal_conv1d`` is really imported, so a built-but-broken wheel fails here
          rather than at model load;
      (b) the installed DeltaNet forward for this model's arch must thread both
          ``cu_seq_lens_q`` and ``seq_idx`` (see ``_gdn_forward_threads_reset_kwargs``);
      (c) when CUDA is available, a tiny ``causal_conv1d_fn`` call runs on the GPU and is
          synchronized, so a wheel compiled without this device's arch fails now instead
          of at the first training forward.

    Args:
        model_id: Optional model identifier forwarded to the arch-specific forward probe.

    Returns:
        ``True`` only when every check passes; ``False`` when a check fails or anything
        raises (the model then trains unpacked).
    """
    try:
        from importlib import import_module

        from transformers.utils.import_utils import (
            is_causal_conv1d_available,
            is_flash_linear_attention_available,
        )

        if not is_flash_linear_attention_available() or not is_causal_conv1d_available():
            return False

        # (a) A find_spec-style availability check cannot see ABI/symbol breakage; importing can.
        import_module("causal_conv1d")

        # (b) Version/API gate against the ACTUAL architecture's DeltaNet forward.
        threads_kwargs = _gdn_forward_threads_reset_kwargs(model_id)
        if not threads_kwargs:
            return False

        # (c) Smoke-run the conv kernel on the live GPU. fla's kernels are Triton-JIT and
        # are compiled for the present arch, so only causal_conv1d needs this.
        import torch

        if torch.cuda.is_available():
            from causal_conv1d import causal_conv1d_fn

            probe_input = torch.zeros(1, 4, 8, device="cuda", dtype=torch.bfloat16)
            probe_weight = torch.zeros(4, 3, device="cuda", dtype=torch.bfloat16)
            causal_conv1d_fn(probe_input, probe_weight)
            torch.cuda.synchronize()
        return True
    except Exception:
        return False
186
+
187
+
188
def pack_token_ids(sequences: list[list[int]], max_length: int) -> list[dict]:
    """Bin-pack tokenized examples into blocks of at most ``max_length`` tokens.

    Examples are never split across blocks. The strategy is first-fit-decreasing: examples
    are visited longest first (ties keep their input order) and each goes into the earliest
    block that still has room, so short examples fill the gaps left by long ones.

    Empty sequences are dropped, and an example longer than ``max_length`` is
    right-truncated to ``max_length`` tokens before packing.

    Args:
        sequences: Token-id lists, one per example.
        max_length: Block capacity in tokens; must be positive.

    Returns:
        One dict per block, ``{"input_ids": [...], "seq_lengths": [l1, l2, ...]}``, where
        ``sum(seq_lengths) == len(input_ids)`` and ``seq_lengths`` lists the packed
        examples' lengths in placement order.

    Raises:
        ValueError: If ``max_length`` is not positive.
    """
    if max_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length}")

    # Drop empties, clip to capacity, then visit longest first (sorted() is stable).
    clipped = sorted((seq[:max_length] for seq in sequences if seq), key=len, reverse=True)

    block_ids: list[list[int]] = []
    block_lengths: list[list[int]] = []
    block_room: list[int] = []  # remaining capacity, parallel to the two lists above

    for tokens in clipped:
        size = len(tokens)
        slot = next((i for i, room in enumerate(block_room) if room >= size), None)
        if slot is None:  # no open block fits -> start a new one
            block_ids.append(list(tokens))
            block_lengths.append([size])
            block_room.append(max_length - size)
        else:
            block_ids[slot].extend(tokens)
            block_lengths[slot].append(size)
            block_room[slot] -= size

    return [
        {"input_ids": ids, "seq_lengths": lengths}
        for ids, lengths in zip(block_ids, block_lengths)
    ]
215
+
216
+
217
def packing_efficiency(rows: list[dict], max_length: int) -> float:
    """Return the fraction of block capacity occupied by real tokens (diagnostic only).

    Capacity is ``len(rows) * max_length``; real tokens are the sum of every row's
    ``seq_lengths``. A value of 1.0 means no padding at all.

    Args:
        rows: Packed rows, each carrying a ``"seq_lengths"`` list.
        max_length: Block capacity in tokens.

    Returns:
        The fill ratio, or ``0.0`` when ``rows`` is empty or ``max_length`` is not positive.
    """
    if not rows or max_length <= 0:
        return 0.0
    capacity = len(rows) * max_length
    filled = 0
    for row in rows:
        filled += sum(row["seq_lengths"])
    return filled / capacity
223
+
224
+
225
def tokenize_for_packing(texts: list[str], tokenizer, max_length: int) -> list[list[int]]:
    """Tokenize chat-templated text rows for packing, mirroring the unpacked SFT prep.

    The intent is that a packed run trains on the same token sequences as the unpacked
    path:
      * the tokenizer's EOS token is appended to every row that does not already end with
        it, so the final stop token is still taught (nothing is appended when the
        tokenizer has no EOS token);
      * the tokenizer is called with its DEFAULT ``add_special_tokens`` (no override), so
        tokenizers that prepend BOS keep doing so;
      * rows are truncated to ``max_length`` by the tokenizer, in one batched call, so a
        pathological long row never materializes a huge id list.

    Args:
        texts: Already chat-templated text rows.
        tokenizer: Callable tokenizer exposing ``eos_token`` and returning a mapping with
            an ``"input_ids"`` entry.
        max_length: Truncation cap forwarded to the tokenizer.

    Returns:
        The tokenizer's ``"input_ids"`` for the prepared rows.
    """
    eos = tokenizer.eos_token or ""
    prepared = []
    for text in texts:
        already_terminated = eos and text.endswith(eos)
        if not already_terminated:
            text = text + eos
        prepared.append(text)
    # default add_special_tokens (TRL parity)
    encoded = tokenizer(prepared, truncation=True, max_length=max_length)
    return encoded["input_ids"]
243
+
244
+
245
+ # Process-local cache of the lower-triangular causal matrix: the collator runs on every batch, and
246
+ # torch.tril(torch.ones(T, T)) is a non-trivial CPU alloc at T=2048+. Keep the LARGEST one seen and
247
+ # slice it for smaller T (it's read-only). Dataloader workers are separate processes, so each holds
248
+ # its own copy — no cross-thread race.
249
+ _CAUSAL_TRIL: dict = {}
250
+
251
+
252
+ def _causal_lower_triangular(total: int, torch):
253
+ cached = _CAUSAL_TRIL.get("m")
254
+ if cached is None or cached.shape[0] < total:
255
+ cached = torch.tril(torch.ones(total, total, dtype=torch.bool))
256
+ _CAUSAL_TRIL["m"] = cached
257
+ return cached[:total, :total]
258
+
259
+
260
+ @dataclass
261
+ class BlockDiagonalCollator:
262
+ """Collate pre-packed rows (from :func:`pack_token_ids`) into a batch whose 4D **block-diagonal
263
+ causal** attention mask keeps packed examples from attending across their boundaries under
264
+ PLAIN SDPA — no flash-attn, no flex_attention.
265
+
266
+ Emits per batch:
267
+ * ``input_ids`` ``[B, T]`` (right-padded with ``pad_token_id``)
268
+ * ``attention_mask`` ``[B, 1, T, T]`` BOOL — ``True`` = query may attend key. Block-diagonal
269
+ (same example) AND causal (key <= query). A bool mask is dtype-agnostic, so it composes
270
+ with bf16/fp16 runs without an ``-inf`` dtype mismatch. Pad tokens form their own segment
271
+ so no query row is all-False (which would NaN the softmax); pad rows never contribute to
272
+ loss (their labels are -100) and real tokens never attend pad keys.
273
+ * ``position_ids`` ``[B, T]`` reset to 0 at each example start (RoPE per example)
274
+ * ``labels`` ``[B, T]`` = ``input_ids`` for real tokens, with each example's FIRST
275
+ token set to -100 (so the cross-boundary next-token pair is never scored — matches the
276
+ unpacked trainer, whose first token is also never a target after HF's internal shift) and
277
+ pad set to -100.
278
+
279
+ ``pad_to_multiple_of`` rounds T up (tensor-core friendliness); the extra positions are pad.
280
+
281
+ ``emit_varlen`` (GatedDeltaNet hybrids, e.g. Qwen3.5/3.6): additionally emit ``cu_seq_lens_q/k``
282
+ (resets the DeltaNet recurrence per example in the fla kernel) and ``seq_idx`` (resets the causal
283
+ conv in causal_conv1d) so the LINEAR-attention layers are boundary-correct too — the 4D mask only
284
+ fixes the full-attention layers. This path requires ``per_device_train_batch_size == 1`` (one
285
+ packed block per step; cu_seqlens spans that block) and does NOT pad (cu_seqlens must cover the
286
+ whole row), so set ``pad_to_multiple_of`` irrelevant here.
287
+ """
288
+
289
+ pad_token_id: int
290
+ label_pad_token_id: int = -100
291
+ pad_to_multiple_of: int = 8
292
+ emit_varlen: bool = False
293
+
294
+ def __call__(self, features: list[dict]) -> dict:
295
+ import torch
296
+
297
+ rows = [list(f["input_ids"]) for f in features]
298
+ seglens = [list(f["seq_lengths"]) for f in features]
299
+ bsz = len(rows)
300
+ if self.emit_varlen and bsz != 1:
301
+ raise ValueError("emit_varlen packing requires per_device_train_batch_size == 1")
302
+ # Fail fast on a broken row rather than silently mis-tag tokens as pad (or vice versa): the
303
+ # whole mask/labels/cu_seqlens construction assumes sum(seq_lengths) == len(input_ids).
304
+ for ids, lens in zip(rows, seglens, strict=True):
305
+ if sum(lens) != len(ids):
306
+ raise ValueError(
307
+ f"packed row invariant broken: sum(seq_lengths)={sum(lens)} != "
308
+ f"len(input_ids)={len(ids)} (rows must come from pack_token_ids)"
309
+ )
310
+ longest = max((len(r) for r in rows), default=0)
311
+ m = self.pad_to_multiple_of
312
+ # No padding on the varlen path: cu_seqlens must cover the whole sequence (a trailing pad
313
+ # region not spanned by cu_seqlens would break the fla varlen kernel).
314
+ total = longest if self.emit_varlen else (((longest + m - 1) // m) * m if m and m > 1 else longest)
315
+ total = max(total, 1)
316
+
317
+ input_ids = torch.full((bsz, total), self.pad_token_id, dtype=torch.long)
318
+ position_ids = torch.zeros((bsz, total), dtype=torch.long)
319
+ # segment id per token: 0..k-1 for the k examples in the block, -1 for trailing pad.
320
+ seg = torch.full((bsz, total), -1, dtype=torch.long)
321
+
322
+ for b, (ids, lens) in enumerate(zip(rows, seglens, strict=True)):
323
+ n = len(ids)
324
+ input_ids[b, :n] = torch.tensor(ids, dtype=torch.long)
325
+ start = 0
326
+ for ex_idx, length in enumerate(lens):
327
+ end = start + length
328
+ position_ids[b, start:end] = torch.arange(length)
329
+ seg[b, start:end] = ex_idx
330
+ start = end
331
+
332
+ # Block-diagonal causal mask, fully vectorized:
333
+ # same-example: seg[q] == seg[k] (pad shares segment -1, so pad rows attend pad -> no
334
+ # all-False row; real tokens never attend pad because real seg != -1)
335
+ # causal: k <= q
336
+ same = seg.unsqueeze(2) == seg.unsqueeze(1) # [B, T, T]
337
+ causal = _causal_lower_triangular(total, torch) # cached + sliced (not rebuilt per batch)
338
+ attention_mask = (same & causal).unsqueeze(1) # [B, 1, T, T]
339
+
340
+ # Labels: real tokens predict their own continuation; first token of each example (and all
341
+ # pad) -> ignore. position_ids == 0 marks exactly each example's first token (pad is 0 too,
342
+ # and pad is already excluded below), so the boundary next-token pair is never scored.
343
+ labels = input_ids.clone()
344
+ labels[seg < 0] = self.label_pad_token_id
345
+ labels[position_ids == 0] = self.label_pad_token_id
346
+
347
+ batch = {
348
+ "input_ids": input_ids,
349
+ "attention_mask": attention_mask,
350
+ "position_ids": position_ids,
351
+ "labels": labels,
352
+ }
353
+ if self.emit_varlen:
354
+ # bsz == 1 (asserted above): cu_seqlens covers this one block's examples, and seq_idx is
355
+ # the per-token segment id (no pad on this path, so seg has no -1). These reach the
356
+ # linear-attention layers via model(**batch) -> the fla chunk kernel (cu_seq_lens_q) and
357
+ # causal_conv1d (seq_idx), resetting their state at each example boundary.
358
+ lens = seglens[0]
359
+ cu = torch.zeros(len(lens) + 1, dtype=torch.int32)
360
+ cu[1:] = torch.tensor(lens, dtype=torch.int32).cumsum(0)
361
+ batch["cu_seq_lens_q"] = cu
362
+ batch["cu_seq_lens_k"] = cu
363
+ batch["max_length_q"] = int(max(lens))
364
+ batch["max_length_k"] = int(max(lens))
365
+ batch["seq_idx"] = seg.to(torch.int32) # [1, T], non-negative (no pad on this path)
366
+ return batch