pfnstudio-core 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,68 @@
1
+ """PFN Studio core abstractions."""
2
+
3
+ from . import axes as _axes # registers built-in axes on import
4
+ from . import blocks as _blocks # registers built-in blocks on import
5
+ from .axes import (
6
+ UNKNOWN,
7
+ Axis,
8
+ AxisDetector,
9
+ detect,
10
+ encode_tag,
11
+ get_axis,
12
+ is_unknown,
13
+ list_axes,
14
+ register_axis,
15
+ sample_tag,
16
+ tag_dim,
17
+ train_detector,
18
+ )
19
+ from .datasets import DatasetUnavailable, RegistryDatasetLoader
20
+ from .eval import Eval, EvalSpec
21
+ from .model import Model, ModelSpec
22
+ from .prior import Prior, PriorSpec
23
+ from .registry import (
24
+ get_block,
25
+ get_eval,
26
+ get_prior,
27
+ list_blocks,
28
+ list_evals,
29
+ list_priors,
30
+ register_block,
31
+ register_eval,
32
+ register_prior,
33
+ )
34
+ from .run import Run, RunSpec
35
+
36
# Must match the distribution version; this module ships in the 0.7.0 wheel.
__version__ = "0.7.0"
37
+
38
# Names exported by a star-import of this package. Every public name
# imported above is listed, including the dataset loader types, which are
# imported for re-export just like the others.
__all__ = [
    "UNKNOWN",
    "Axis",
    "AxisDetector",
    "DatasetUnavailable",
    "Eval",
    "EvalSpec",
    "Model",
    "ModelSpec",
    "Prior",
    "PriorSpec",
    "RegistryDatasetLoader",
    "Run",
    "RunSpec",
    "detect",
    "encode_tag",
    "get_axis",
    "get_block",
    "get_eval",
    "get_prior",
    "is_unknown",
    "list_axes",
    "list_blocks",
    "list_evals",
    "list_priors",
    "register_axis",
    "register_block",
    "register_eval",
    "register_prior",
    "sample_tag",
    "tag_dim",
    "train_detector",
]
@@ -0,0 +1,34 @@
1
+ """Axes package — public surface for promptable-prior axes.
2
+
3
+ Importing this module also imports `builtins`, which eagerly
4
+ registers the built-in axes (monotonicity, …) so they're available
5
+ without an extra import in user code."""
6
+
7
+ from . import builtins as _builtins # noqa: F401 (side-effect: registers built-ins)
8
+ from .base import (
9
+ UNKNOWN,
10
+ Axis,
11
+ AxisKind,
12
+ get_axis,
13
+ is_unknown,
14
+ list_axes,
15
+ register_axis,
16
+ )
17
+ from .detector import AxisDetector, detect, train_detector
18
+ from .encoding import encode_tag, sample_tag, tag_dim
19
+
20
# Public surface of the axes package: the names imported above from
# ``.base``, ``.detector`` and ``.encoding``.
__all__ = [
    "UNKNOWN",
    "Axis",
    "AxisDetector",
    "AxisKind",
    "detect",
    "encode_tag",
    "get_axis",
    "is_unknown",
    "list_axes",
    "register_axis",
    "sample_tag",
    "tag_dim",
    "train_detector",
]
@@ -0,0 +1,111 @@
1
+ """Axis abstraction for promptable priors.
2
+
3
+ An Axis declares a *steerable property* of a prior — a knob the trained
4
+ brain learns to honor at inference time. Examples: monotonicity (edge
5
+ signs in an SCM), lag_scale (delay distribution), feedback_allowed
6
+ (DAG vs cyclic), sparsity (mean parents per node).
7
+
8
+ The axis declares the contract (name, kind, value space, default).
9
+ Each Prior that opts into an axis implements its own application
10
+ logic — the same axis can mean different things to different priors
11
+ (e.g. monotonicity on an SCM constrains edge signs; on a regression
12
+ prior it constrains weight signs). The base class doesn't try to
13
+ unify those interpretations.
14
+
15
+ Back-compat invariants this module enforces:
16
+ - The reserved sentinel ``UNKNOWN`` means "skip this axis at sample
17
+ time". A prior given ``tag={axis: UNKNOWN}`` for every axis is
18
+ bit-identical to its pre-axis behavior. This is what guarantees
19
+ adding axes can't regress existing benchmarks.
20
+ - A prior with no declared axes ignores any ``tag=...`` argument
21
+ passed to ``sample()`` — old call sites keep working unchanged.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ from dataclasses import dataclass, field
27
+ from typing import Any, Literal
28
+
29
# The three kinds an ``Axis`` may declare; each kind has its own
# validation and sampling rule in ``Axis``.
AxisKind = Literal["categorical", "range", "boolean"]

# Reserved sentinel — when a tag has this for an axis, the prior must
# skip the axis hook and behave as if the axis weren't declared.
# Stored as a string so it round-trips through JSON / YAML cleanly.
UNKNOWN = "__unknown__"
35
+
36
+
37
@dataclass(frozen=True)
class Axis:
    """Declaration of one steerable property of a prior.

    An axis only states the contract (name, kind, value space and how
    often it is left unspecified during training); how a value is applied
    is up to each prior.

    Raises:
        ValueError: on construction, if ``unknown_mass`` is outside
            [0, 1], a categorical axis has no values, or a range axis
            does not have exactly two values.
    """

    name: str
    kind: AxisKind
    # For categorical: the label values. For range: (min, max).
    # For boolean: ignored. A list is accepted and stored as a tuple.
    values: tuple[Any, ...] = field(default_factory=tuple)
    description: str = ""
    # Fraction of training samples that get UNKNOWN for this axis.
    # Default 0.3 — preserves substantial unconditional mass so the
    # brain doesn't lose its pre-axis baseline.
    unknown_mass: float = 0.3

    def __post_init__(self) -> None:
        # A list here would make the frozen dataclass unhashable and make
        # otherwise-identical axes compare unequal (list != tuple), so
        # normalise it. object.__setattr__ is required because the
        # dataclass is frozen.
        if isinstance(self.values, list):
            object.__setattr__(self, "values", tuple(self.values))
        if not 0.0 <= self.unknown_mass <= 1.0:
            raise ValueError(
                f"axis {self.name!r}: unknown_mass must be in [0, 1], got {self.unknown_mass}"
            )
        if self.kind == "categorical" and not self.values:
            raise ValueError(f"axis {self.name!r}: categorical axes need at least one value")
        if self.kind == "range" and len(self.values) != 2:
            raise ValueError(
                f"axis {self.name!r}: range axes need exactly 2 values (min, max), got {len(self.values)}"
            )

    def sample_value(self, rng: Any) -> Any:
        """Sample one value for this axis, honoring the unknown_mass invariant.

        Args:
            rng: a NumPy-style random generator providing ``random``,
                ``choice``, ``integers`` and ``uniform``.

        Returns:
            ``UNKNOWN`` with probability ``unknown_mass``; otherwise one
            of the declared values (categorical, returned as the original
            Python object), a ``bool`` (boolean) or a ``float`` drawn
            uniformly between the two bounds (range).

        Raises:
            ValueError: if ``kind`` is not one of the supported kinds.
        """
        if rng.random() < self.unknown_mass:
            return UNKNOWN
        if self.kind == "categorical":
            # Draw an index instead of handing the values to rng.choice:
            # choice() first converts the values to an array, which turns
            # labels into NumPy scalars and stringifies mixed-type value
            # lists. choice(n) samples as if from arange(n), so the index
            # drawn for a given generator state is unchanged.
            return self.values[int(rng.choice(len(self.values)))]
        if self.kind == "boolean":
            return bool(rng.integers(0, 2))
        if self.kind == "range":
            lo, hi = float(self.values[0]), float(self.values[1])
            return float(rng.uniform(lo, hi))
        raise ValueError(f"unknown axis kind: {self.kind!r}")
74
+
75
+
76
# Process-wide registry mapping axis name → Axis instance. Built-in axes
# add themselves when their module is imported; custom axes are added by
# calling register_axis().
_AXES: dict[str, Axis] = {}


def register_axis(axis: Axis) -> Axis:
    """Add ``axis`` to the registry and return it.

    Registering an axis equal to the one already stored under the same
    name is a harmless no-op (handy when a module is imported twice).

    Raises:
        ValueError: if a *different* axis is already registered under
            ``axis.name``.
    """
    previous = _AXES.get(axis.name)
    clashes = previous is not None and previous != axis
    if clashes:
        raise ValueError(
            f"axis {axis.name!r} already registered with a different definition"
        )
    _AXES[axis.name] = axis
    return axis
93
+
94
+
95
def get_axis(name: str) -> Axis:
    """Return the axis registered under ``name``.

    Raises:
        KeyError: if no axis with that name is registered; the message
            lists the known axis names in sorted order.
    """
    if name in _AXES:
        return _AXES[name]
    known = sorted(_AXES)
    raise KeyError(f"axis {name!r} not registered. Known axes: {known}")
102
+
103
+
104
def list_axes() -> list[str]:
    """Return the names of all registered axes, sorted."""
    names = list(_AXES)
    names.sort()
    return names
106
+
107
+
108
def is_unknown(value: Any) -> bool:
    """Return whether a tag value is the ``UNKNOWN`` sentinel.

    Use this in axis application code rather than comparing to the
    string literal. The sentinel is a string, so any non-string value
    (number, array, ...) is reported as not-unknown without being
    compared: ``==`` on array-like values may produce an element-wise
    result instead of a single bool, which would break
    ``if is_unknown(value):`` checks.
    """
    return isinstance(value, str) and value == UNKNOWN
@@ -0,0 +1,5 @@
1
+ """Built-in axes. Importing this module registers them with the global
2
+ axis registry. The package ``__init__`` imports this so axes are
3
+ available the moment ``pfnstudio_core`` is imported."""
4
+
5
+ from . import monotonicity # noqa: F401 (registers on import)
@@ -0,0 +1,25 @@
1
+ """Built-in `monotonicity` axis.
2
+
3
+ Declares the contract; per-prior application lives in each Prior
4
+ subclass (an SCM's monotonicity constrains edge signs, a regression
5
+ prior's monotonicity constrains weight signs, etc.).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from ..base import Axis, register_axis
11
+
12
# Registered at import time; the module-level name keeps a handle to the
# registered Axis so callers can import it directly.
monotonicity = register_axis(
    Axis(
        unknown_mass=0.3,
        description=(
            "Sign discipline of the prior's causal relationships. positive = "
            "X up implies Y up everywhere; negative = X up implies Y down; "
            "mixed = signs may vary. Unknown (default) leaves the prior "
            "unconstrained."
        ),
        values=("positive", "negative", "mixed"),
        kind="categorical",
        name="monotonicity",
    )
)
@@ -0,0 +1,260 @@
1
+ """Auto-detector — supervised model that reads context and proposes
2
+ axis values.
3
+
4
+ The pitch in one line: instead of asking the user to fill out an empty
5
+ form of axis chips, the detector reads their data and pre-fills the
6
+ chips with its best guess (and a confidence). The user just confirms,
7
+ overrides, or adds the axes the detector can't see.
8
+
9
+ Training:
10
+ - Each batch, the prior samples a real (non-UNKNOWN) tag value and
11
+ generates data from it. The tag value is the supervised label.
12
+ - The detector reads the data (no tag) and predicts a softmax over
13
+ axis values.
14
+ - Cross-entropy loss.
15
+
16
+ Inference:
17
+ - Same encoder forward, output softmax probabilities.
18
+ - Highest-probability value becomes the proposed chip; that probability
19
+ becomes the displayed confidence.
20
+
21
+ Limitations in v1:
22
+ - Only categorical axes are detected. Range axes need a regression
23
+ head and a different loss; boolean axes need a 2-class head.
24
+ Both are straightforward additions but defer until they're actually
25
+ shipped on a prior.
26
+ - The detector is small (32-d, 2-layer) — same shape as the test
27
+ brain. Scale up when a real prior with many axes ships.
28
+ """
29
+
30
+ from __future__ import annotations
31
+
32
+ from typing import Any
33
+
34
+ import numpy as np
35
+
36
+ from .base import UNKNOWN, Axis
37
+
38
+
39
class AxisDetector:
    """Small transformer encoder with one classification head per
    categorical axis.

    Axes that are not categorical, or that declare fewer than two
    values, get no head and are reported through ``skipped_axes``. A
    detector without any scored axis is legal — it predicts nothing.
    """

    def __init__(
        self,
        axes: list[Axis],
        d_model: int = 32,
        n_heads: int = 4,
        n_layers: int = 2,
        dropout: float = 0.0,
    ):
        try:
            import torch.nn as nn
        except ImportError as e:
            raise ImportError(
                "AxisDetector requires torch. "
                "Install with: pip install pfnstudio-core[torch]"
            ) from e

        # Split the axes in one pass: v1 only builds heads for
        # categorical axes with at least two values; everything else is
        # remembered so callers can tell the user which axes cannot be
        # proposed yet.
        scored: list[Axis] = []
        skipped: list[Axis] = []
        for candidate in axes:
            gets_head = candidate.kind == "categorical" and len(candidate.values) >= 2
            (scored if gets_head else skipped).append(candidate)
        self._scored_axes: list[Axis] = scored
        self._skipped_axes: list[Axis] = skipped
        self.d_model = d_model

        # The embedder is lazy: its input feature size is inferred from
        # the first tensor passed through it.
        self.embedder = nn.LazyLinear(d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        # One linear head per scored axis, sized to the number of values.
        # There is deliberately no UNKNOWN class: uncertainty shows up as
        # low confidence across the real values.
        per_axis_heads = {}
        for axis in scored:
            per_axis_heads[axis.name] = nn.Linear(d_model, len(axis.values))
        self.heads = nn.ModuleDict(per_axis_heads)

    @property
    def scored_axes(self) -> list[Axis]:
        """Axes the detector has heads for (categorical, ≥2 values)."""
        return self._scored_axes

    @property
    def skipped_axes(self) -> list[Axis]:
        """Axes the detector can't predict (boolean, range, single-value)."""
        return self._skipped_axes

    def parameters(self):
        """Yield the parameters of the embedder, the encoder and the
        heads, in that order."""
        for module in (self.embedder, self.encoder, self.heads):
            yield from module.parameters()

    def state_dict(self) -> dict[str, Any]:
        """Return the state dicts of the three sub-modules, keyed by
        ``"embedder"``, ``"encoder"`` and ``"heads"``."""
        parts = ("embedder", "encoder", "heads")
        return {part: getattr(self, part).state_dict() for part in parts}

    def __call__(self, x: Any) -> dict[str, Any]:
        """Forward pass. Returns axis_name → logits tensor of shape
        (B, num_values_for_axis)."""
        tokens = self.encoder(self.embedder(x))
        # Mean-pool over dim 1 (the token dimension, given
        # batch_first=True) to get one vector per batch element.
        summary = tokens.mean(dim=1)
        logits = {}
        for axis_name, head in self.heads.items():
            logits[axis_name] = head(summary)
        return logits
124
+
125
+
126
def _value_to_index(axis: Axis, value: Any) -> int:
    """Return the position of ``value`` within ``axis.values``.

    UNKNOWN is rejected because the detector is meant to learn concrete
    values, so it must never be used as a training label.

    Raises:
        ValueError: if ``value`` is UNKNOWN, or is not one of the axis's
            declared values.
    """
    if not (value == UNKNOWN):
        return list(axis.values).index(value)
    raise ValueError(f"axis {axis.name!r}: cannot train detector on UNKNOWN label")
136
+
137
+
138
def train_detector(
    *,
    detector: AxisDetector,
    prior: Any,
    steps: int = 500,
    batch_size: int = 16,
    lr: float = 1e-3,
    seed: int = 0,
    on_step: Any = None,
) -> dict[str, Any]:
    """Train the detector on labeled batches sampled from the prior.

    Each step picks one non-UNKNOWN value per scored axis, asks the prior
    for a batch generated with that tag, and trains the detector to
    recover the tag from the data. The loss is cross-entropy, summed
    across axes, optimised with AdamW.

    Args:
        detector: the detector to train (updated in place).
        prior: object exposing ``sample_batch(batch_size=, seed=, tag=)``
            that returns a sequence of items, each with a NumPy array
            under ``"X"``.
        steps: number of optimisation steps.
        batch_size: batch size requested from the prior.
        lr: AdamW learning rate.
        seed: seeds tag sampling and is forwarded (offset per step) to
            the prior.
        on_step: optional callback invoked as ``on_step(step, loss)``
            after every step.

    Returns:
        ``{"status": "skipped", "reason": ...}`` when torch is not
        installed or the detector has no scored axes. Otherwise
        ``{"status": "ok", "steps", "final_loss", "last_batch_accuracy"}``
        where ``last_batch_accuracy`` maps axis name → accuracy on the
        *last* batch, measured with the logits computed before that
        step's update (a cheap sanity check that training did something).
    """
    try:
        import torch
        import torch.nn.functional as F
    except ImportError:
        return {"status": "skipped", "reason": "torch not installed"}

    scored_axes = detector.scored_axes
    if not scored_axes:
        return {"status": "skipped", "reason": "detector has no scored axes"}

    rng = np.random.default_rng(seed)
    optim = torch.optim.AdamW(list(detector.parameters()), lr=lr)

    last_acc: dict[str, float] = {}
    final_loss = 0.0
    for step in range(steps):
        # Sample one tag value per axis (random, non-UNKNOWN — that's the
        # supervised label the detector learns to recover). An index is
        # drawn rather than passing the values to rng.choice, which would
        # convert them to a NumPy array: labels would come back as NumPy
        # scalars and mixed-type value lists would be stringified, so the
        # label lookup below would fail. choice(n) samples as if from
        # arange(n), so the draw for a given seed is unchanged.
        tag = {
            a.name: a.values[int(rng.choice(len(a.values)))] for a in scored_axes
        }
        # The whole batch shares the tag, so one class index per axis
        # labels every sample; compute it once for loss and accuracy.
        label_idx = {a.name: _value_to_index(a, tag[a.name]) for a in scored_axes}

        batch = prior.sample_batch(
            batch_size=batch_size, seed=seed + step * batch_size, tag=tag
        )
        X = torch.stack([torch.from_numpy(b["X"]).float() for b in batch])

        logits = detector(X)  # dict axis → (B, num_values)
        loss = sum(
            F.cross_entropy(
                logits[a.name],
                torch.full((X.shape[0],), label_idx[a.name], dtype=torch.long),
            )
            for a in scored_axes
        )

        optim.zero_grad()
        loss.backward()
        optim.step()

        final_loss = float(loss.item())
        if on_step is not None:
            on_step(step, final_loss)

        # Final-batch accuracy (per axis) — cheap and informative.
        if step == steps - 1:
            with torch.no_grad():
                for a in scored_axes:
                    preds = logits[a.name].argmax(dim=-1)
                    hits = (preds == label_idx[a.name]).float()
                    last_acc[a.name] = float(hits.mean().item())

    return {
        "status": "ok",
        "steps": steps,
        "final_loss": final_loss,
        "last_batch_accuracy": last_acc,
    }
211
+
212
+
213
def detect(
    *,
    detector: AxisDetector,
    context: Any,
) -> dict[str, dict[str, Any]]:
    """Run the detector on a context tensor and return a tag-shaped
    proposal with per-axis confidence.

    Args:
        detector: the detector to run. Its sub-modules are switched to
            eval mode for the forward pass and put back into their
            previous mode afterwards.
        context: a tensor, or anything ``np.asarray`` accepts. A 2-D
            input is treated as a single context and given a batch
            dimension.

    Returns:
        ``{axis_name: {value, confidence, probs}}`` where:
        - ``value`` is the predicted axis value (argmax)
        - ``confidence`` is the softmax probability of that value
        - ``probs`` is the full {value: probability} dict, for callers
          that want to render uncertainty (e.g. show "60% positive,
          35% mixed, 5% negative" instead of just "positive 60%").
        An empty dict is returned when torch is not installed or the
        detector has no scored axes.
    """
    try:
        import torch
        import torch.nn.functional as F
    except ImportError:
        return {}

    if not detector.scored_axes:
        return {}

    if not torch.is_tensor(context):
        context = torch.from_numpy(np.asarray(context)).float()
    if context.dim() == 2:
        context = context.unsqueeze(0)  # add batch dim

    # Freshly built modules are in training mode, so without this a
    # detector constructed with dropout > 0 would apply dropout at
    # inference and return noisy, non-repeatable proposals. The previous
    # mode of each top-level module is restored afterwards so an ongoing
    # training run is not affected.
    modules = (detector.embedder, detector.encoder, detector.heads)
    previous_modes = [module.training for module in modules]
    for module in modules:
        module.eval()
    try:
        with torch.no_grad():
            logits = detector(context)
    finally:
        for module, was_training in zip(modules, previous_modes):
            module.train(was_training)

    out: dict[str, dict[str, Any]] = {}
    for axis in detector.scored_axes:
        probs = F.softmax(logits[axis.name], dim=-1)
        # Mean across batch in case the caller passed multiple contexts —
        # the detector returns one proposal per axis, not per-sample.
        mean_probs = probs.mean(dim=0)
        best_idx = int(mean_probs.argmax().item())
        out[axis.name] = {
            "value": axis.values[best_idx],
            "confidence": float(mean_probs[best_idx].item()),
            "probs": {
                value: float(mean_probs[i].item())
                for i, value in enumerate(axis.values)
            },
        }
    return out
@@ -0,0 +1,119 @@
1
+ """Tag encoding — fixed-length vector representation of an axis tag.
2
+
3
+ The encoding rules:
4
+
5
+ - **Categorical axes** → one-hot over the declared values, plus one
6
+ extra slot for the UNKNOWN sentinel (always last). So a categorical
7
+ axis with K values contributes K+1 dimensions.
8
+ - **Boolean axes** → 2 dims (false, true) plus 1 for UNKNOWN. Same
9
+ shape as a 2-value categorical.
10
+ - **Range axes** → 2 dims: ``[is_unknown_flag, normalised_value]``.
11
+ When tagged with a real value v in [min, max], normalised_value is
12
+ ``(v - min) / (max - min)`` clipped to [0, 1] and is_unknown_flag
13
+ is 0. When UNKNOWN, normalised_value is 0 and is_unknown_flag is 1.
14
+
15
+ The encoding is deterministic and stateless — the same axes list and
16
+ the same tag always produce the same vector. ``tag_dim(axes)`` returns
17
+ the total length so model constructors can size their tag embedder
18
+ once at init time without poking at sample tags.
19
+
20
+ Back-compat invariant: ``encode_tag({}, axes)`` and ``encode_tag(None,
21
+ axes)`` are byte-identical to ``encode_tag({a: UNKNOWN for a in axes},
22
+ axes)``. Used by callers who need a "no constraints" tag vector
23
+ without enumerating all axes.
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ from typing import Any
29
+
30
+ import numpy as np
31
+
32
+ from .base import UNKNOWN, Axis, is_unknown
33
+
34
+
35
def _axis_dim(axis: Axis) -> int:
    """Return how many slots ``axis`` occupies in an encoded tag vector.

    Raises:
        ValueError: if ``axis.kind`` is not a supported kind.
    """
    kind = axis.kind
    if kind == "range":
        return 2  # is_unknown_flag, normalised_value
    if kind == "boolean":
        return 3  # false, true, UNKNOWN
    if kind != "categorical":
        raise ValueError(f"unknown axis kind: {axis.kind!r}")
    # One slot per declared value plus the trailing UNKNOWN slot.
    return 1 + len(axis.values)
43
+
44
+
45
def tag_dim(axes: list[Axis]) -> int:
    """Return the length of the vector ``encode_tag`` produces for
    ``axes``: the sum of every axis's individual width (0 for no axes)."""
    total = 0
    for axis in axes:
        total += _axis_dim(axis)
    return total
49
+
50
+
51
def _encode_categorical(axis: Axis, value: Any) -> np.ndarray:
    """One-hot encode ``value`` over the axis values, with the final slot
    reserved for UNKNOWN (also used when ``value`` is ``None``).

    Raises:
        ValueError: if ``value`` is neither unknown nor a declared value.
    """
    encoded = np.zeros(_axis_dim(axis), dtype=np.float32)
    if value is None or is_unknown(value):
        slot = -1  # trailing UNKNOWN slot
    elif value in axis.values:
        slot = axis.values.index(value)
    else:
        raise ValueError(
            f"axis {axis.name!r}: value {value!r} not in {axis.values}"
        )
    encoded[slot] = 1.0
    return encoded
65
+
66
+
67
def _encode_boolean(axis: Axis, value: Any) -> np.ndarray:
    """Encode a boolean tag as a 3-slot one-hot: index 0 = false,
    index 1 = true, index 2 = UNKNOWN (also used for ``None``). Any other
    value is interpreted through its truthiness."""
    missing = value is None or is_unknown(value)
    slot = 2 if missing else int(bool(value))
    encoded = np.zeros(3, dtype=np.float32)
    encoded[slot] = 1.0
    return encoded
74
+
75
+
76
def _encode_range(axis: Axis, value: Any) -> np.ndarray:
    """Encode a range tag as ``[is_unknown_flag, normalised_value]``.

    Unknown (or ``None``) gives ``[1, 0]``. Otherwise the flag is 0 and
    the value is rescaled by the axis bounds and clipped to [0, 1]; a
    zero-width range always encodes the value as 0.
    """
    if value is None or is_unknown(value):
        return np.array([1.0, 0.0], dtype=np.float32)
    lo = float(axis.values[0])
    hi = float(axis.values[1])
    width = hi - lo
    # The value is only converted when the range has width, so a
    # degenerate range never touches it.
    fraction = (
        0.0
        if width == 0
        else float(np.clip((float(value) - lo) / width, 0.0, 1.0))
    )
    return np.array([0.0, fraction], dtype=np.float32)
89
+
90
+
91
def encode_tag(tag: dict[str, Any] | None, axes: list[Axis]) -> np.ndarray:
    """Encode a tag dict into a fixed-length float32 vector for the given axes.

    Axes are encoded in the order given and concatenated. An axis with no
    entry in ``tag`` (including when ``tag`` is ``None`` or empty) is
    encoded as UNKNOWN. With no axes the result is an empty vector.

    Raises:
        ValueError: if an axis has an unsupported kind, or a categorical
            value is not one of the axis's declared values.
    """
    lookup = tag or {}
    encoders = {
        "categorical": _encode_categorical,
        "boolean": _encode_boolean,
        "range": _encode_range,
    }
    pieces: list[np.ndarray] = []
    for axis in axes:
        encoder = encoders.get(axis.kind)
        if encoder is None:
            raise ValueError(f"unknown axis kind: {axis.kind!r}")
        pieces.append(encoder(axis, lookup.get(axis.name, UNKNOWN)))
    if pieces:
        return np.concatenate(pieces, axis=0)
    return np.zeros(0, dtype=np.float32)
112
+
113
+
114
def sample_tag(axes: list[Axis], rng: np.random.Generator) -> dict[str, Any]:
    """Draw one tag covering every axis in ``axes``.

    Each axis samples its own value through ``Axis.sample_value``, which
    is where ``unknown_mass`` is honored, so entries may be UNKNOWN.
    Returns axis_name → sampled value. Used by the training loop to draw
    a fresh tag per batch.
    """
    sampled: dict[str, Any] = {}
    for axis in axes:
        sampled[axis.name] = axis.sample_value(rng)
    return sampled