pfnstudio-core 0.7.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. pfnstudio_core-0.7.0/.gitignore +68 -0
  2. pfnstudio_core-0.7.0/PKG-INFO +56 -0
  3. pfnstudio_core-0.7.0/README.md +28 -0
  4. pfnstudio_core-0.7.0/pfnstudio_core/__init__.py +68 -0
  5. pfnstudio_core-0.7.0/pfnstudio_core/axes/__init__.py +34 -0
  6. pfnstudio_core-0.7.0/pfnstudio_core/axes/base.py +111 -0
  7. pfnstudio_core-0.7.0/pfnstudio_core/axes/builtins/__init__.py +5 -0
  8. pfnstudio_core-0.7.0/pfnstudio_core/axes/builtins/monotonicity.py +25 -0
  9. pfnstudio_core-0.7.0/pfnstudio_core/axes/detector.py +260 -0
  10. pfnstudio_core-0.7.0/pfnstudio_core/axes/encoding.py +119 -0
  11. pfnstudio_core-0.7.0/pfnstudio_core/axes/honoring.py +147 -0
  12. pfnstudio_core-0.7.0/pfnstudio_core/blocks/__init__.py +3 -0
  13. pfnstudio_core-0.7.0/pfnstudio_core/blocks/heads.py +56 -0
  14. pfnstudio_core-0.7.0/pfnstudio_core/blocks/tabular.py +23 -0
  15. pfnstudio_core-0.7.0/pfnstudio_core/blocks/transformer.py +138 -0
  16. pfnstudio_core-0.7.0/pfnstudio_core/datasets.py +226 -0
  17. pfnstudio_core-0.7.0/pfnstudio_core/eval.py +49 -0
  18. pfnstudio_core-0.7.0/pfnstudio_core/loaders.py +44 -0
  19. pfnstudio_core-0.7.0/pfnstudio_core/model.py +57 -0
  20. pfnstudio_core-0.7.0/pfnstudio_core/prior.py +110 -0
  21. pfnstudio_core-0.7.0/pfnstudio_core/registry.py +98 -0
  22. pfnstudio_core-0.7.0/pfnstudio_core/run.py +63 -0
  23. pfnstudio_core-0.7.0/pfnstudio_core/scorers/__init__.py +48 -0
  24. pfnstudio_core-0.7.0/pfnstudio_core/scorers/base.py +51 -0
  25. pfnstudio_core-0.7.0/pfnstudio_core/scorers/breast_cancer_vs_logreg.py +251 -0
  26. pfnstudio_core-0.7.0/pfnstudio_core/scorers/closed_form_baseline.py +221 -0
  27. pfnstudio_core-0.7.0/pfnstudio_core/scorers/in_context_regression_ols.py +170 -0
  28. pfnstudio_core-0.7.0/pfnstudio_core/scorers/kinpfn_real_fpt_ks.py +220 -0
  29. pfnstudio_core-0.7.0/pfnstudio_core/scorers/m4_monthly_forecast.py +203 -0
  30. pfnstudio_core-0.7.0/pfnstudio_core/scorers/synthetic_classification_bce.py +243 -0
  31. pfnstudio_core-0.7.0/pfnstudio_core/scorers/synthetic_regression_mse.py +194 -0
  32. pfnstudio_core-0.7.0/pfnstudio_core/training/__init__.py +5 -0
  33. pfnstudio_core-0.7.0/pfnstudio_core/training/loop.py +734 -0
  34. pfnstudio_core-0.7.0/pfnstudio_core/training/predict.py +604 -0
  35. pfnstudio_core-0.7.0/pyproject.toml +44 -0
@@ -0,0 +1,68 @@
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ share/python-wheels/
20
+ *.egg-info/
21
+ .installed.cfg
22
+ *.egg
23
+ MANIFEST
24
+
25
+ # Virtual environments
26
+ .venv/
27
+ .deps-venv/
28
+ venv/
29
+ ENV/
30
+ env/
31
+
32
+ # Editable-install metadata
33
+ _editable_impl_*.pth
34
+
35
+ # IDE / editor
36
+ .vscode/
37
+ .idea/
38
+ *.swp
39
+ *.swo
40
+ .DS_Store
41
+
42
+ # pytest / coverage / mypy
43
+ .pytest_cache/
44
+ .mypy_cache/
45
+ .ruff_cache/
46
+ .coverage
47
+ .coverage.*
48
+ htmlcov/
49
+ .tox/
50
+ .nox/
51
+
52
+ # Jupyter
53
+ .ipynb_checkpoints/
54
+
55
+ # Build artifacts
56
+ checkpoint/
57
+ runs/*/checkpoint/
58
+ runs/*/results.json
59
+ runs/*/metrics.jsonl
60
+ *.pt
61
+ *.pth
62
+
63
+ # Dataset cache (downloaded benchmark datasets)
64
+ .priorstudio/datasets/
65
+
66
+ # Logs
67
+ logs/
68
+ *.log
@@ -0,0 +1,56 @@
1
+ Metadata-Version: 2.4
2
+ Name: pfnstudio-core
3
+ Version: 0.7.0
4
+ Summary: Core abstractions for PFN Studio — Prior, Model, Eval, Run, and the block registry.
5
+ Project-URL: Homepage, https://pfnstudio.com
6
+ Project-URL: Repository, https://github.com/profitopsai/pfnstudio
7
+ Project-URL: Issues, https://github.com/profitopsai/pfnstudio/issues
8
+ Project-URL: Documentation, https://github.com/profitopsai/pfnstudio/tree/main/docs
9
+ Author: PFN Studio contributors
10
+ License: Apache-2.0
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: License :: OSI Approved :: Apache Software License
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
15
+ Requires-Python: >=3.10
16
+ Requires-Dist: numpy>=1.26
17
+ Requires-Dist: pandas>=2.0
18
+ Requires-Dist: pydantic>=2.6
19
+ Requires-Dist: pyyaml>=6.0
20
+ Requires-Dist: scikit-learn>=1.3
21
+ Provides-Extra: dev
22
+ Requires-Dist: pytest-cov>=5.0; extra == 'dev'
23
+ Requires-Dist: pytest>=8.0; extra == 'dev'
24
+ Requires-Dist: ruff>=0.4; extra == 'dev'
25
+ Provides-Extra: torch
26
+ Requires-Dist: torch>=2.2; extra == 'torch'
27
+ Description-Content-Type: text/markdown
28
+
29
+ # pfnstudio-core
30
+
31
+ The Python contract for PFN Studio FM projects.
32
+
33
+ ```python
34
+ from pfnstudio_core import Prior, Model, Eval, Run, register_block, register_prior
35
+
36
+ @register_prior("my_prior")
37
+ class MyPrior(Prior):
38
+ def sample(self, seed: int): ...
39
+
40
+ @register_block("my_attention")
41
+ class MyAttention:
42
+ def __init__(self, d_model: int, n_heads: int): ...
43
+ ```
44
+
45
+ The CLI discovers anything registered via these decorators and validates `models/*.yaml` references against the registry.
46
+
47
+ ## Layout
48
+
49
+ - `prior.py` — `Prior` ABC and built-in prior loader
50
+ - `model.py` — `Model` config + block-composition
51
+ - `eval.py` — `Eval` config + result schema
52
+ - `run.py` — `Run` manifest + executor protocol
53
+ - `registry.py` — `@register_prior`, `@register_block`, `@register_eval` and discovery
54
+ - `loaders.py` — load YAML artifacts into typed objects
55
+ - `blocks/` — built-in architecture blocks (transformer encoder, causal attention, heads)
56
+ - `training/` — minimal in-process training loop for the `local` compute adapter
@@ -0,0 +1,28 @@
1
+ # pfnstudio-core
2
+
3
+ The Python contract for PFN Studio FM projects.
4
+
5
+ ```python
6
+ from pfnstudio_core import Prior, Model, Eval, Run, register_block, register_prior
7
+
8
+ @register_prior("my_prior")
9
+ class MyPrior(Prior):
10
+ def sample(self, seed: int): ...
11
+
12
+ @register_block("my_attention")
13
+ class MyAttention:
14
+ def __init__(self, d_model: int, n_heads: int): ...
15
+ ```
16
+
17
+ The CLI discovers anything registered via these decorators and validates `models/*.yaml` references against the registry.
18
+
19
+ ## Layout
20
+
21
+ - `prior.py` — `Prior` ABC and built-in prior loader
22
+ - `model.py` — `Model` config + block-composition
23
+ - `eval.py` — `Eval` config + result schema
24
+ - `run.py` — `Run` manifest + executor protocol
25
+ - `registry.py` — `@register_prior`, `@register_block`, `@register_eval` and discovery
26
+ - `loaders.py` — load YAML artifacts into typed objects
27
+ - `blocks/` — built-in architecture blocks (transformer encoder, causal attention, heads)
28
+ - `training/` — minimal in-process training loop for the `local` compute adapter
@@ -0,0 +1,68 @@
1
+ """PFN Studio core abstractions."""
2
+
3
+ from . import axes as _axes # registers built-in axes on import
4
+ from . import blocks as _blocks # registers built-in blocks on import
5
+ from .axes import (
6
+ UNKNOWN,
7
+ Axis,
8
+ AxisDetector,
9
+ detect,
10
+ encode_tag,
11
+ get_axis,
12
+ is_unknown,
13
+ list_axes,
14
+ register_axis,
15
+ sample_tag,
16
+ tag_dim,
17
+ train_detector,
18
+ )
19
+ from .datasets import DatasetUnavailable, RegistryDatasetLoader
20
+ from .eval import Eval, EvalSpec
21
+ from .model import Model, ModelSpec
22
+ from .prior import Prior, PriorSpec
23
+ from .registry import (
24
+ get_block,
25
+ get_eval,
26
+ get_prior,
27
+ list_blocks,
28
+ list_evals,
29
+ list_priors,
30
+ register_block,
31
+ register_eval,
32
+ register_prior,
33
+ )
34
+ from .run import Run, RunSpec
35
+
36
# Runtime version string; keep in sync with the distribution version
# declared in the package metadata (0.7.0 for this release).
__version__ = "0.7.0"
37
+
38
# Public surface of the package. Every name re-exported by the imports
# above is listed here, including the dataset helpers, so that
# ``from pfnstudio_core import *`` and static tooling agree on what is
# exported.
__all__ = [
    "UNKNOWN",
    "Axis",
    "AxisDetector",
    "DatasetUnavailable",
    "Eval",
    "EvalSpec",
    "Model",
    "ModelSpec",
    "Prior",
    "PriorSpec",
    "RegistryDatasetLoader",
    "Run",
    "RunSpec",
    "detect",
    "encode_tag",
    "get_axis",
    "get_block",
    "get_eval",
    "get_prior",
    "is_unknown",
    "list_axes",
    "list_blocks",
    "list_evals",
    "list_priors",
    "register_axis",
    "register_block",
    "register_eval",
    "register_prior",
    "sample_tag",
    "tag_dim",
    "train_detector",
]
@@ -0,0 +1,34 @@
1
+ """Axes package — public surface for promptable-prior axes.
2
+
3
+ Importing this module also imports `builtins`, which eagerly
4
+ registers the built-in axes (monotonicity, …) so they're available
5
+ without an extra import in user code."""
6
+
7
+ from . import builtins as _builtins # noqa: F401 (side-effect: registers built-ins)
8
+ from .base import (
9
+ UNKNOWN,
10
+ Axis,
11
+ AxisKind,
12
+ get_axis,
13
+ is_unknown,
14
+ list_axes,
15
+ register_axis,
16
+ )
17
+ from .detector import AxisDetector, detect, train_detector
18
+ from .encoding import encode_tag, sample_tag, tag_dim
19
+
20
+ __all__ = [
21
+ "UNKNOWN",
22
+ "Axis",
23
+ "AxisDetector",
24
+ "AxisKind",
25
+ "detect",
26
+ "encode_tag",
27
+ "get_axis",
28
+ "is_unknown",
29
+ "list_axes",
30
+ "register_axis",
31
+ "sample_tag",
32
+ "tag_dim",
33
+ "train_detector",
34
+ ]
@@ -0,0 +1,111 @@
1
+ """Axis abstraction for promptable priors.
2
+
3
+ An Axis declares a *steerable property* of a prior — a knob the trained
4
+ brain learns to honor at inference time. Examples: monotonicity (edge
5
+ signs in an SCM), lag_scale (delay distribution), feedback_allowed
6
+ (DAG vs cyclic), sparsity (mean parents per node).
7
+
8
+ The axis declares the contract (name, kind, value space, default).
9
+ Each Prior that opts into an axis implements its own application
10
+ logic — the same axis can mean different things to different priors
11
+ (e.g. monotonicity on an SCM constrains edge signs; on a regression
12
+ prior it constrains weight signs). The base class doesn't try to
13
+ unify those interpretations.
14
+
15
+ Back-compat invariants this module enforces:
16
+ - The reserved sentinel ``UNKNOWN`` means "skip this axis at sample
17
+ time". A prior given ``tag={axis: UNKNOWN}`` for every axis is
18
+ bit-identical to its pre-axis behavior. This is what guarantees
19
+ adding axes can't regress existing benchmarks.
20
+ - A prior with no declared axes ignores any ``tag=...`` argument
21
+ passed to ``sample()`` — old call sites keep working unchanged.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ from dataclasses import dataclass, field
27
+ from typing import Any, Literal
28
+
29
# The closed set of axis kinds; used as the type of ``Axis.kind``.
AxisKind = Literal["categorical", "range", "boolean"]

# Reserved sentinel — when a tag has this for an axis, the prior must
# skip the axis hook and behave as if the axis weren't declared.
# Stored as a string so it round-trips through JSON / YAML cleanly.
UNKNOWN = "__unknown__"
35
+
36
+
37
@dataclass(frozen=True)
class Axis:
    """Declaration of one steerable property ("axis") of a prior.

    An ``Axis`` only describes the contract — its name, kind, value
    space and how often it is left unspecified during sampling. How a
    value is applied is up to each prior.

    Raises:
        ValueError: from ``__post_init__`` when ``unknown_mass`` is
            outside ``[0, 1]``, a categorical axis has no values, or a
            range axis does not have exactly two values.
    """

    name: str
    kind: AxisKind
    # For categorical: the allowed labels. For range: (min, max).
    # For boolean: ignored. A list is accepted and stored as a tuple.
    values: tuple[Any, ...] = field(default_factory=tuple)
    description: str = ""
    # Fraction of training samples that get UNKNOWN for this axis.
    # Default 0.3 — preserves substantial unconditional mass so the
    # brain doesn't lose its pre-axis baseline.
    unknown_mass: float = 0.3

    def __post_init__(self) -> None:
        # Values loaded from YAML / JSON arrive as lists. Freeze them so
        # the (frozen) dataclass stays hashable and compares equal to
        # the same axis declared with a tuple.
        if isinstance(self.values, list):
            object.__setattr__(self, "values", tuple(self.values))
        if not 0.0 <= self.unknown_mass <= 1.0:
            raise ValueError(
                f"axis {self.name!r}: unknown_mass must be in [0, 1], got {self.unknown_mass}"
            )
        if self.kind == "categorical" and not self.values:
            raise ValueError(f"axis {self.name!r}: categorical axes need at least one value")
        if self.kind == "range" and len(self.values) != 2:
            raise ValueError(
                f"axis {self.name!r}: range axes need exactly 2 values (min, max), got {len(self.values)}"
            )

    def sample_value(self, rng: Any) -> Any:
        """Sample one value for this axis, honoring the unknown_mass invariant.

        Args:
            rng: a NumPy ``Generator``-style object providing
                ``random``, ``choice``, ``integers`` and ``uniform``.

        Returns:
            ``UNKNOWN`` with probability ``unknown_mass``; otherwise one
            of ``values`` (categorical, returned as the original Python
            object), a ``bool`` (boolean) or a ``float`` drawn uniformly
            between the two bounds (range).

        Raises:
            ValueError: if ``kind`` is not one of the supported kinds.
        """
        if rng.random() < self.unknown_mass:
            return UNKNOWN
        if self.kind == "categorical":
            # Draw an index rather than an element: choosing from the
            # values directly round-trips them through a NumPy array,
            # which yields NumPy scalars (not JSON-serialisable for
            # ints) and coerces mixed-type labels to strings.
            return self.values[int(rng.choice(len(self.values)))]
        if self.kind == "boolean":
            return bool(rng.integers(0, 2))
        if self.kind == "range":
            lo, hi = float(self.values[0]), float(self.values[1])
            return float(rng.uniform(lo, hi))
        raise ValueError(f"unknown axis kind: {self.kind!r}")
74
+
75
+
76
# Process-wide table mapping axis name → Axis. Built-in axes add
# themselves when their module is imported; user-defined axes are
# added through register_axis().
_AXES: dict[str, Axis] = {}


def register_axis(axis: Axis) -> Axis:
    """Add ``axis`` to the global registry and hand it back.

    Registering an axis equal to the one already stored under the same
    name is accepted (so double imports are harmless). Registering a
    different definition under an existing name raises ``ValueError``.
    """
    previous = _AXES.get(axis.name)
    clashes = previous is not None and previous != axis
    if clashes:
        raise ValueError(
            f"axis {axis.name!r} already registered with a different definition"
        )
    _AXES[axis.name] = axis
    return axis


def get_axis(name: str) -> Axis:
    """Return the axis registered as ``name``; raise ``KeyError`` if absent."""
    found = _AXES.get(name)
    if found is None:
        raise KeyError(
            f"axis {name!r} not registered. Known axes: {sorted(_AXES)}"
        )
    return found


def list_axes() -> list[str]:
    """Return the names of all registered axes in sorted order."""
    names = list(_AXES.keys())
    names.sort()
    return names
106
+
107
+
108
def is_unknown(value: Any) -> bool:
    """Report whether ``value`` equals the ``UNKNOWN`` sentinel.

    Axis application code should call this instead of comparing a tag
    value against the sentinel's string literal.
    """
    matches_sentinel = value == UNKNOWN
    return matches_sentinel
@@ -0,0 +1,5 @@
1
+ """Built-in axes. Importing this module registers them with the global
2
+ axis registry. The package ``__init__`` imports this so axes are
3
+ available the moment ``pfnstudio_core`` is imported."""
4
+
5
+ from . import monotonicity # noqa: F401 (registers on import)
@@ -0,0 +1,25 @@
1
+ """Built-in `monotonicity` axis.
2
+
3
+ Declares the contract; per-prior application lives in each Prior
4
+ subclass (an SCM's monotonicity constrains edge signs, a regression
5
+ prior's monotonicity constrains weight signs, etc.).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from ..base import Axis, register_axis
11
+
12
# Built-in categorical axis, registered in the global axis registry as
# a side effect of importing this module.
monotonicity = register_axis(
    Axis(
        name="monotonicity",
        kind="categorical",
        values=("positive", "negative", "mixed"),
        unknown_mass=0.3,
        description=(
            "Sign discipline of the prior's causal relationships. "
            "positive = X up implies Y up everywhere; "
            "negative = X up implies Y down; "
            "mixed = signs may vary. "
            "Unknown (default) leaves the prior unconstrained."
        ),
    )
)
@@ -0,0 +1,260 @@
1
+ """Auto-detector — supervised model that reads context and proposes
2
+ axis values.
3
+
4
+ The pitch in one line: instead of asking the user to fill out an empty
5
+ form of axis chips, the detector reads their data and pre-fills the
6
+ chips with its best guess (and a confidence). The user just confirms,
7
+ overrides, or adds the axes the detector can't see.
8
+
9
+ Training:
10
+ - Each batch, the prior samples a real (non-UNKNOWN) tag value and
11
+ generates data from it. The tag value is the supervised label.
12
+ - The detector reads the data (no tag) and predicts a softmax over
13
+ axis values.
14
+ - Cross-entropy loss.
15
+
16
+ Inference:
17
+ - Same encoder forward, output softmax probabilities.
18
+ - Highest-probability value becomes the proposed chip; that probability
19
+ becomes the displayed confidence.
20
+
21
+ Limitations in v1:
22
+ - Only categorical axes are detected. Range axes need a regression
23
+ head and a different loss; boolean axes need a 2-class head.
24
+ Both are straightforward additions but defer until they're actually
25
+ shipped on a prior.
26
+ - The detector is small (32-d, 2-layer) — same shape as the test
27
+ brain. Scale up when a real prior with many axes ships.
28
+ """
29
+
30
+ from __future__ import annotations
31
+
32
+ from typing import Any
33
+
34
+ import numpy as np
35
+
36
+ from .base import UNKNOWN, Axis
37
+
38
+
39
class AxisDetector:
    """A small transformer encoder with one classification head per axis.

    Only categorical axes with at least two values get a head; every
    other axis is recorded in ``skipped_axes`` so callers can tell the
    user it cannot be proposed. A detector with no scored axes is legal
    and simply predicts nothing.

    This is a plain Python object that owns torch modules (it is not an
    ``nn.Module`` itself), so it exposes the handful of module-style
    helpers callers need: ``parameters``, ``state_dict``,
    ``load_state_dict``, ``train`` and ``eval``.

    Args:
        axes: axes the detector should try to predict.
        d_model: width of the embedding / encoder.
        n_heads: attention heads per encoder layer.
        n_layers: number of encoder layers.
        dropout: dropout probability inside the encoder layers.

    Raises:
        ImportError: if torch is not installed.
    """

    def __init__(
        self,
        axes: list[Axis],
        d_model: int = 32,
        n_heads: int = 4,
        n_layers: int = 2,
        dropout: float = 0.0,
    ):
        try:
            import torch.nn as nn
        except ImportError as e:
            raise ImportError(
                "AxisDetector requires torch. "
                "Install with: pip install pfnstudio-core[torch]"
            ) from e

        # Only categorical axes get heads in v1. Skipped axes are
        # tracked so callers can warn the user "we can't propose for
        # axis X yet."
        self._scored_axes: list[Axis] = [
            a for a in axes if a.kind == "categorical" and len(a.values) >= 2
        ]
        self._skipped_axes: list[Axis] = [
            a for a in axes if a not in self._scored_axes
        ]
        self.d_model = d_model

        # Lazy embedder — the input feature dimension is inferred on
        # the first forward pass, so its parameters are uninitialised
        # until then.
        self.embedder = nn.LazyLinear(d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        # One head per scored axis. Output dim = number of values
        # (no UNKNOWN class — the detector proposes a real value;
        # low confidence across all classes is how it expresses
        # "I'm not sure", not a separate class).
        self.heads = nn.ModuleDict(
            {a.name: nn.Linear(d_model, len(a.values)) for a in self._scored_axes}
        )

    @property
    def scored_axes(self) -> list[Axis]:
        """Axes the detector has heads for (categorical, ≥2 values)."""
        return self._scored_axes

    @property
    def skipped_axes(self) -> list[Axis]:
        """Axes the detector can't predict (boolean, range, single-value)."""
        return self._skipped_axes

    def parameters(self):
        """Yield the parameters of the embedder, encoder and heads, in that order."""
        yield from self.embedder.parameters()
        yield from self.encoder.parameters()
        yield from self.heads.parameters()

    def state_dict(self) -> dict[str, Any]:
        """Return the three sub-module state dicts keyed by component name."""
        return {
            "embedder": self.embedder.state_dict(),
            "encoder": self.encoder.state_dict(),
            "heads": self.heads.state_dict(),
        }

    def load_state_dict(self, state: dict[str, Any]) -> None:
        """Restore weights from a mapping produced by ``state_dict()``.

        Raises:
            KeyError: if one of the ``embedder`` / ``encoder`` /
                ``heads`` entries is missing.
        """
        self.embedder.load_state_dict(state["embedder"])
        self.encoder.load_state_dict(state["encoder"])
        self.heads.load_state_dict(state["heads"])

    def train(self, mode: bool = True) -> AxisDetector:
        """Put every sub-module in training (or, with ``mode=False``, eval) mode."""
        for module in (self.embedder, self.encoder, self.heads):
            module.train(mode)
        return self

    def eval(self) -> AxisDetector:
        """Put every sub-module in eval mode (disables encoder dropout)."""
        return self.train(False)

    def __call__(self, x: Any) -> dict[str, Any]:
        """Forward pass. Returns axis_name → logits tensor of shape
        (B, num_values_for_axis)."""
        emb = self.embedder(x)  # (B, N, d_model)
        encoded = self.encoder(emb)  # (B, N, d_model)
        # Mean-pool the context tokens to one vector per batch element.
        # (Could swap for an attention pool later; mean avoids extra
        # parameters.)
        pooled = encoded.mean(dim=1)  # (B, d_model)
        return {name: head(pooled) for name, head in self.heads.items()}
124
+
125
+
126
def _value_to_index(axis: Axis, value: Any) -> int:
    """Return the position of ``value`` within ``axis.values``.

    The detector is trained to recover a concrete value, so the
    ``UNKNOWN`` sentinel is rejected as a label.

    Raises:
        ValueError: if ``value`` is ``UNKNOWN`` or is not one of the
            axis's values.
    """
    is_sentinel = value == UNKNOWN
    if is_sentinel:
        raise ValueError(
            f"axis {axis.name!r}: cannot train detector on UNKNOWN label"
        )
    candidates = list(axis.values)
    return candidates.index(value)
136
+
137
+
138
def train_detector(
    *,
    detector: AxisDetector,
    prior: Any,
    steps: int = 500,
    batch_size: int = 16,
    lr: float = 1e-3,
    seed: int = 0,
    on_step: Any = None,
) -> dict[str, Any]:
    """Train the detector on labeled batches sampled from the prior.

    Each step picks one non-UNKNOWN value per scored axis, asks the
    prior for a batch generated under that tag, and trains the detector
    to recover the tag from the data. The loss is cross-entropy, summed
    across axes.

    Args:
        detector: the detector to train (updated in place).
        prior: object exposing ``sample_batch(batch_size=, seed=, tag=)``
            that returns a sequence of dicts, each with a NumPy array
            under ``"X"``; the arrays are stacked into one batch tensor.
        steps: number of optimisation steps.
        batch_size: samples requested from the prior per step.
        lr: AdamW learning rate.
        seed: seeds the tag sampler and offsets the per-step prior seed.
        on_step: optional callback invoked as ``on_step(step, loss)``
            after every optimisation step.

    Returns:
        ``{"status": "skipped", "reason": ...}`` when torch is missing
        or the detector has no scored axes. Otherwise
        ``{"status": "ok", "steps", "final_loss", "last_batch_accuracy"}``
        where ``last_batch_accuracy`` maps axis name → accuracy on the
        *last* batch (a cheap sanity check that training did something).
    """
    try:
        import torch
        import torch.nn.functional as F
    except ImportError:
        return {"status": "skipped", "reason": "torch not installed"}

    scored = detector.scored_axes
    if not scored:
        return {"status": "skipped", "reason": "detector has no scored axes"}

    rng = np.random.default_rng(seed)
    # The optimizer is created after the first forward pass: the
    # detector's lazy embedder only materialises its parameters then,
    # and PyTorch recommends a dry run before attaching an optimizer to
    # lazy modules.
    optim = None

    last_acc: dict[str, float] = {}
    final_loss = 0.0
    for step in range(steps):
        # Sample one tag value per axis (random, non-UNKNOWN — that's
        # the supervised label the detector learns to recover).
        tag = {a.name: rng.choice(list(a.values)) for a in scored}
        # Resolve each label to its class index once per step; it is
        # needed for both the loss and the final-batch accuracy.
        label_index = {a.name: _value_to_index(a, tag[a.name]) for a in scored}
        batch = prior.sample_batch(
            batch_size=batch_size, seed=seed + step * batch_size, tag=tag
        )
        X = torch.stack([torch.from_numpy(b["X"]).float() for b in batch])

        logits = detector(X)  # dict axis → (B, num_values)
        if optim is None:
            optim = torch.optim.AdamW(list(detector.parameters()), lr=lr)

        loss = torch.zeros(1)
        for axis in scored:
            # Every sample in the batch was generated under the same
            # tag, so the whole batch shares one label per axis.
            labels = torch.full(
                (X.shape[0],), label_index[axis.name], dtype=torch.long
            )
            loss = loss + F.cross_entropy(logits[axis.name], labels)

        optim.zero_grad()
        loss.backward()
        optim.step()

        final_loss = float(loss.item())
        if on_step is not None:
            on_step(step, final_loss)

        # Final-batch accuracy (per axis), measured on the logits
        # computed before this step's weight update.
        if step == steps - 1:
            with torch.no_grad():
                for axis in scored:
                    preds = logits[axis.name].argmax(dim=-1)
                    hits = (preds == label_index[axis.name]).float()
                    last_acc[axis.name] = float(hits.mean().item())

    return {
        "status": "ok",
        "steps": steps,
        "final_loss": final_loss,
        "last_batch_accuracy": last_acc,
    }
211
+
212
+
213
def detect(
    *,
    detector: AxisDetector,
    context: Any,
) -> dict[str, dict[str, Any]]:
    """Run the detector on a context and return a tag-shaped proposal
    with per-axis confidence.

    Args:
        detector: a detector exposing ``scored_axes`` and callable on a
            batch tensor.
        context: a torch tensor or anything ``np.asarray`` accepts. A
            2-D input is treated as a single context and gets a leading
            batch dimension.

    Returns:
        ``{axis_name: {value, confidence, probs}}`` where:
        - ``value`` is the predicted axis value (argmax)
        - ``confidence`` is the softmax probability of that value
        - ``probs`` is the full {value: probability} dict, for callers
          that want to render uncertainty (e.g. show "60% positive,
          35% mixed, 5% negative" instead of just "positive 60%").
        An empty dict is returned when torch is not installed or the
        detector has no scored axes.
    """
    try:
        import torch
        import torch.nn.functional as F
    except ImportError:
        return {}

    if not detector.scored_axes:
        return {}

    if not torch.is_tensor(context):
        context = torch.from_numpy(np.asarray(context)).float()
    if context.dim() == 2:
        context = context.unsqueeze(0)  # add batch dim

    # Inference must not apply dropout. The detector's torch
    # sub-modules start in training mode, so switch them to eval for
    # the forward pass and restore their previous mode afterwards.
    submodules = [
        module
        for module in (
            getattr(detector, attr, None)
            for attr in ("embedder", "encoder", "heads")
        )
        if isinstance(module, torch.nn.Module)
    ]
    previous_modes = [module.training for module in submodules]
    for module in submodules:
        module.eval()
    try:
        with torch.no_grad():
            logits = detector(context)
    finally:
        for module, was_training in zip(submodules, previous_modes):
            module.train(was_training)

    out: dict[str, dict[str, Any]] = {}
    for axis in detector.scored_axes:
        probs = F.softmax(logits[axis.name], dim=-1)
        # Mean across batch in case the caller passed multiple contexts —
        # the detector returns one proposal per axis, not per-sample.
        mean_probs = probs.mean(dim=0)
        best_idx = int(mean_probs.argmax().item())
        out[axis.name] = {
            "value": axis.values[best_idx],
            "confidence": float(mean_probs[best_idx].item()),
            "probs": {
                value: float(p)
                for value, p in zip(axis.values, mean_probs.tolist())
            },
        }
    return out