pfnstudio-core 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pfnstudio_core/__init__.py +68 -0
- pfnstudio_core/axes/__init__.py +34 -0
- pfnstudio_core/axes/base.py +111 -0
- pfnstudio_core/axes/builtins/__init__.py +5 -0
- pfnstudio_core/axes/builtins/monotonicity.py +25 -0
- pfnstudio_core/axes/detector.py +260 -0
- pfnstudio_core/axes/encoding.py +119 -0
- pfnstudio_core/axes/honoring.py +147 -0
- pfnstudio_core/blocks/__init__.py +3 -0
- pfnstudio_core/blocks/heads.py +56 -0
- pfnstudio_core/blocks/tabular.py +23 -0
- pfnstudio_core/blocks/transformer.py +138 -0
- pfnstudio_core/datasets.py +226 -0
- pfnstudio_core/eval.py +49 -0
- pfnstudio_core/loaders.py +44 -0
- pfnstudio_core/model.py +57 -0
- pfnstudio_core/prior.py +110 -0
- pfnstudio_core/registry.py +98 -0
- pfnstudio_core/run.py +63 -0
- pfnstudio_core/scorers/__init__.py +48 -0
- pfnstudio_core/scorers/base.py +51 -0
- pfnstudio_core/scorers/breast_cancer_vs_logreg.py +251 -0
- pfnstudio_core/scorers/closed_form_baseline.py +221 -0
- pfnstudio_core/scorers/in_context_regression_ols.py +170 -0
- pfnstudio_core/scorers/kinpfn_real_fpt_ks.py +220 -0
- pfnstudio_core/scorers/m4_monthly_forecast.py +203 -0
- pfnstudio_core/scorers/synthetic_classification_bce.py +243 -0
- pfnstudio_core/scorers/synthetic_regression_mse.py +194 -0
- pfnstudio_core/training/__init__.py +5 -0
- pfnstudio_core/training/loop.py +734 -0
- pfnstudio_core/training/predict.py +604 -0
- pfnstudio_core-0.7.0.dist-info/METADATA +56 -0
- pfnstudio_core-0.7.0.dist-info/RECORD +34 -0
- pfnstudio_core-0.7.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""PFN Studio core abstractions."""
|
|
2
|
+
|
|
3
|
+
from . import axes as _axes # registers built-in axes on import
|
|
4
|
+
from . import blocks as _blocks # registers built-in blocks on import
|
|
5
|
+
from .axes import (
|
|
6
|
+
UNKNOWN,
|
|
7
|
+
Axis,
|
|
8
|
+
AxisDetector,
|
|
9
|
+
detect,
|
|
10
|
+
encode_tag,
|
|
11
|
+
get_axis,
|
|
12
|
+
is_unknown,
|
|
13
|
+
list_axes,
|
|
14
|
+
register_axis,
|
|
15
|
+
sample_tag,
|
|
16
|
+
tag_dim,
|
|
17
|
+
train_detector,
|
|
18
|
+
)
|
|
19
|
+
from .datasets import DatasetUnavailable, RegistryDatasetLoader
|
|
20
|
+
from .eval import Eval, EvalSpec
|
|
21
|
+
from .model import Model, ModelSpec
|
|
22
|
+
from .prior import Prior, PriorSpec
|
|
23
|
+
from .registry import (
|
|
24
|
+
get_block,
|
|
25
|
+
get_eval,
|
|
26
|
+
get_prior,
|
|
27
|
+
list_blocks,
|
|
28
|
+
list_evals,
|
|
29
|
+
list_priors,
|
|
30
|
+
register_block,
|
|
31
|
+
register_eval,
|
|
32
|
+
register_prior,
|
|
33
|
+
)
|
|
34
|
+
from .run import Run, RunSpec
|
|
35
|
+
|
|
36
|
+
__version__ = "0.4.0"
|
|
37
|
+
|
|
38
|
+
__all__ = [
|
|
39
|
+
"UNKNOWN",
|
|
40
|
+
"Axis",
|
|
41
|
+
"AxisDetector",
|
|
42
|
+
"Eval",
|
|
43
|
+
"EvalSpec",
|
|
44
|
+
"Model",
|
|
45
|
+
"ModelSpec",
|
|
46
|
+
"Prior",
|
|
47
|
+
"PriorSpec",
|
|
48
|
+
"Run",
|
|
49
|
+
"RunSpec",
|
|
50
|
+
"detect",
|
|
51
|
+
"encode_tag",
|
|
52
|
+
"get_axis",
|
|
53
|
+
"get_block",
|
|
54
|
+
"get_eval",
|
|
55
|
+
"get_prior",
|
|
56
|
+
"is_unknown",
|
|
57
|
+
"list_axes",
|
|
58
|
+
"list_blocks",
|
|
59
|
+
"list_evals",
|
|
60
|
+
"list_priors",
|
|
61
|
+
"register_axis",
|
|
62
|
+
"register_block",
|
|
63
|
+
"register_eval",
|
|
64
|
+
"register_prior",
|
|
65
|
+
"sample_tag",
|
|
66
|
+
"tag_dim",
|
|
67
|
+
"train_detector",
|
|
68
|
+
]
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Axes package — public surface for promptable-prior axes.
|
|
2
|
+
|
|
3
|
+
Importing this module also imports `builtins`, which eagerly
|
|
4
|
+
registers the built-in axes (monotonicity, …) so they're available
|
|
5
|
+
without an extra import in user code."""
|
|
6
|
+
|
|
7
|
+
from . import builtins as _builtins # noqa: F401 (side-effect: registers built-ins)
|
|
8
|
+
from .base import (
|
|
9
|
+
UNKNOWN,
|
|
10
|
+
Axis,
|
|
11
|
+
AxisKind,
|
|
12
|
+
get_axis,
|
|
13
|
+
is_unknown,
|
|
14
|
+
list_axes,
|
|
15
|
+
register_axis,
|
|
16
|
+
)
|
|
17
|
+
from .detector import AxisDetector, detect, train_detector
|
|
18
|
+
from .encoding import encode_tag, sample_tag, tag_dim
|
|
19
|
+
|
|
20
|
+
__all__ = [
|
|
21
|
+
"UNKNOWN",
|
|
22
|
+
"Axis",
|
|
23
|
+
"AxisDetector",
|
|
24
|
+
"AxisKind",
|
|
25
|
+
"detect",
|
|
26
|
+
"encode_tag",
|
|
27
|
+
"get_axis",
|
|
28
|
+
"is_unknown",
|
|
29
|
+
"list_axes",
|
|
30
|
+
"register_axis",
|
|
31
|
+
"sample_tag",
|
|
32
|
+
"tag_dim",
|
|
33
|
+
"train_detector",
|
|
34
|
+
]
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
"""Axis abstraction for promptable priors.
|
|
2
|
+
|
|
3
|
+
An Axis declares a *steerable property* of a prior — a knob the trained
|
|
4
|
+
brain learns to honor at inference time. Examples: monotonicity (edge
|
|
5
|
+
signs in an SCM), lag_scale (delay distribution), feedback_allowed
|
|
6
|
+
(DAG vs cyclic), sparsity (mean parents per node).
|
|
7
|
+
|
|
8
|
+
The axis declares the contract (name, kind, value space, default).
|
|
9
|
+
Each Prior that opts into an axis implements its own application
|
|
10
|
+
logic — the same axis can mean different things to different priors
|
|
11
|
+
(e.g. monotonicity on an SCM constrains edge signs; on a regression
|
|
12
|
+
prior it constrains weight signs). The base class doesn't try to
|
|
13
|
+
unify those interpretations.
|
|
14
|
+
|
|
15
|
+
Back-compat invariants this module enforces:
|
|
16
|
+
- The reserved sentinel ``UNKNOWN`` means "skip this axis at sample
|
|
17
|
+
time". A prior given ``tag={axis: UNKNOWN}`` for every axis is
|
|
18
|
+
bit-identical to its pre-axis behavior. This is what guarantees
|
|
19
|
+
adding axes can't regress existing benchmarks.
|
|
20
|
+
- A prior with no declared axes ignores any ``tag=...`` argument
|
|
21
|
+
passed to ``sample()`` — old call sites keep working unchanged.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
from dataclasses import dataclass, field
|
|
27
|
+
from typing import Any, Literal
|
|
28
|
+
|
|
29
|
+
AxisKind = Literal["categorical", "range", "boolean"]
|
|
30
|
+
|
|
31
|
+
# Reserved sentinel — when a tag has this for an axis, the prior must
|
|
32
|
+
# skip the axis hook and behave as if the axis weren't declared.
|
|
33
|
+
# Stored as a string so it round-trips through JSON / YAML cleanly.
|
|
34
|
+
UNKNOWN = "__unknown__"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass(frozen=True)
class Axis:
    """Declarative contract for one steerable property of a prior.

    An axis only describes the knob (name, kind, value space, share of
    UNKNOWN samples); how a value is applied is left to each prior.

    Raises:
        ValueError: if ``unknown_mass`` is outside ``[0, 1]``, a
            categorical axis has no values, or a range axis does not
            have exactly two values.
    """

    name: str
    kind: AxisKind
    # For categorical: the label values. For range: (min, max).
    # For boolean: ignored.
    values: tuple[Any, ...] = field(default_factory=tuple)
    description: str = ""
    # Probability with which ``sample_value`` returns UNKNOWN instead of
    # a concrete value (default 0.3).
    unknown_mass: float = 0.3

    def __post_init__(self) -> None:
        # A list (e.g. from a parsed config) would make this frozen
        # dataclass unhashable and unequal to the same axis declared
        # with a tuple, so normalise it. The instance is frozen, hence
        # object.__setattr__. Other containers are left untouched.
        if isinstance(self.values, list):
            object.__setattr__(self, "values", tuple(self.values))

        if not 0.0 <= self.unknown_mass <= 1.0:
            raise ValueError(
                f"axis {self.name!r}: unknown_mass must be in [0, 1], got {self.unknown_mass}"
            )
        if self.kind == "categorical" and not self.values:
            raise ValueError(f"axis {self.name!r}: categorical axes need at least one value")
        if self.kind == "range" and len(self.values) != 2:
            raise ValueError(
                f"axis {self.name!r}: range axes need exactly 2 values (min, max), got {len(self.values)}"
            )

    def sample_value(self, rng: Any) -> Any:
        """Draw one value for this axis, or UNKNOWN.

        ``rng`` must provide ``random()``, ``choice()``, ``integers()``
        and ``uniform()`` — presumably a ``numpy.random.Generator``;
        confirm against callers.

        Returns UNKNOWN with probability ``unknown_mass``; otherwise one
        of ``values`` (categorical), a ``bool`` (boolean) or a ``float``
        drawn between the two bounds (range).

        Raises:
            ValueError: if ``kind`` is not one of the three known kinds.
        """
        # The UNKNOWN draw always happens first, whatever the kind.
        if rng.random() < self.unknown_mass:
            return UNKNOWN
        if self.kind == "categorical":
            return rng.choice(list(self.values))
        if self.kind == "boolean":
            return bool(rng.integers(0, 2))
        if self.kind == "range":
            lo, hi = float(self.values[0]), float(self.values[1])
            return float(rng.uniform(lo, hi))
        raise ValueError(f"unknown axis kind: {self.kind!r}")
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
# Module-level registry. Resolves axis names → Axis instances at
|
|
77
|
+
# loader time. Built-in axes register themselves on import; custom
|
|
78
|
+
# axes register by calling register_axis(axis).
|
|
79
|
+
_AXES: dict[str, Axis] = {}
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def register_axis(axis: Axis) -> Axis:
    """Add ``axis`` to the module-level registry and return it.

    Registering an axis equal to the one already stored under the same
    name is harmless (handy when a module is imported twice). A
    different definition under an existing name raises ``ValueError``.
    """
    previous = _AXES.get(axis.name)
    clashes = previous is not None and previous != axis
    if clashes:
        raise ValueError(
            f"axis {axis.name!r} already registered with a different definition"
        )
    _AXES[axis.name] = axis
    return axis
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def get_axis(name: str) -> Axis:
    """Return the axis registered under ``name``.

    Raises:
        KeyError: if no axis has that name; the message lists the
            registered names in sorted order.
    """
    found = _AXES.get(name)
    if found is not None:
        return found
    raise KeyError(
        f"axis {name!r} not registered. Known axes: {sorted(_AXES)}"
    )
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def list_axes() -> list[str]:
    """Return the names of all registered axes, sorted."""
    names = list(_AXES)
    names.sort()
    return names
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def is_unknown(value: Any) -> bool:
    """Tell whether ``value`` equals the UNKNOWN sentinel.

    Prefer this helper in axis application code over comparing against
    the sentinel's string literal directly.
    """
    matches_sentinel = value == UNKNOWN
    return matches_sentinel
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""Built-in `monotonicity` axis.
|
|
2
|
+
|
|
3
|
+
Declares the contract; per-prior application lives in each Prior
|
|
4
|
+
subclass (an SCM's monotonicity constrains edge signs, a regression
|
|
5
|
+
prior's monotonicity constrains weight signs, etc.).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from ..base import Axis, register_axis
|
|
11
|
+
|
|
12
|
+
# Built-in categorical axis, registered at import time. Only the contract
# is declared here; each prior decides how to apply a value.
monotonicity = register_axis(
    Axis(
        name="monotonicity",
        kind="categorical",
        values=("positive", "negative", "mixed"),
        description=(
            "Sign discipline of the prior's causal relationships. "
            "positive = X up implies Y up everywhere; "
            "negative = X up implies Y down; "
            "mixed = signs may vary. "
            "Unknown (default) leaves the prior unconstrained."
        ),
        unknown_mass=0.3,
    )
)
|
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
"""Auto-detector — supervised model that reads context and proposes
|
|
2
|
+
axis values.
|
|
3
|
+
|
|
4
|
+
The pitch in one line: instead of asking the user to fill out an empty
|
|
5
|
+
form of axis chips, the detector reads their data and pre-fills the
|
|
6
|
+
chips with its best guess (and a confidence). The user just confirms,
|
|
7
|
+
overrides, or adds the axes the detector can't see.
|
|
8
|
+
|
|
9
|
+
Training:
|
|
10
|
+
- Each batch, the prior samples a real (non-UNKNOWN) tag value and
|
|
11
|
+
generates data from it. The tag value is the supervised label.
|
|
12
|
+
- The detector reads the data (no tag) and predicts a softmax over
|
|
13
|
+
axis values.
|
|
14
|
+
- Cross-entropy loss.
|
|
15
|
+
|
|
16
|
+
Inference:
|
|
17
|
+
- Same encoder forward, output softmax probabilities.
|
|
18
|
+
- Highest-probability value becomes the proposed chip; that probability
|
|
19
|
+
becomes the displayed confidence.
|
|
20
|
+
|
|
21
|
+
Limitations in v1:
|
|
22
|
+
- Only categorical axes are detected. Range axes need a regression
|
|
23
|
+
head and a different loss; boolean axes need a 2-class head.
|
|
24
|
+
Both are straightforward additions, but are deferred until they're actually
|
|
25
|
+
shipped on a prior.
|
|
26
|
+
- The detector is small (32-d, 2-layer) — same shape as the test
|
|
27
|
+
brain. Scale up when a real prior with many axes ships.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
from __future__ import annotations
|
|
31
|
+
|
|
32
|
+
from typing import Any
|
|
33
|
+
|
|
34
|
+
import numpy as np
|
|
35
|
+
|
|
36
|
+
from .base import UNKNOWN, Axis
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class AxisDetector:
    """Small transformer encoder with one classification head per axis.

    Only categorical axes with at least two values get a head. Every
    other axis (boolean, range, single-value categorical) is recorded
    as skipped. A detector with no heads is legal and predicts nothing.
    """

    def __init__(
        self,
        axes: list[Axis],
        d_model: int = 32,
        n_heads: int = 4,
        n_layers: int = 2,
        dropout: float = 0.0,
    ):
        """Build the embedder, encoder and per-axis heads.

        Raises:
            ImportError: if torch is not installed.
        """
        try:
            import torch.nn as nn
        except ImportError as e:
            raise ImportError(
                "AxisDetector requires torch. Install with: pip install pfnstudio-core[torch]"
            ) from e

        # Split the axes into those that get a head and those that do
        # not, so callers can report which axes cannot be proposed yet.
        scored: list[Axis] = []
        skipped: list[Axis] = []
        for candidate in axes:
            can_score = candidate.kind == "categorical" and len(candidate.values) >= 2
            (scored if can_score else skipped).append(candidate)
        self._scored_axes = scored
        self._skipped_axes = skipped
        self.d_model = d_model

        # LazyLinear infers the input feature size on the first forward.
        self.embedder = nn.LazyLinear(d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

        # One linear head per scored axis, one output per declared value.
        # There is deliberately no UNKNOWN class: uncertainty shows up
        # as low probability across all values.
        heads = nn.ModuleDict()
        for axis in scored:
            heads[axis.name] = nn.Linear(d_model, len(axis.values))
        self.heads = heads

    @property
    def scored_axes(self) -> list[Axis]:
        """Axes that have a head (categorical with two or more values)."""
        return self._scored_axes

    @property
    def skipped_axes(self) -> list[Axis]:
        """Axes without a head (boolean, range, single-value)."""
        return self._skipped_axes

    def parameters(self):
        """Yield the parameters of the embedder, encoder and heads, in that order."""
        for module in (self.embedder, self.encoder, self.heads):
            yield from module.parameters()

    def state_dict(self) -> dict[str, Any]:
        """Return the state dicts of the three sub-modules, keyed by attribute name."""
        return {
            part: getattr(self, part).state_dict()
            for part in ("embedder", "encoder", "heads")
        }

    def __call__(self, x: Any) -> dict[str, Any]:
        """Run a forward pass and return ``axis_name -> logits``.

        The encoder output is averaged over dim 1 before the heads, so
        each head yields one row of logits per batch element.
        """
        tokens = self.encoder(self.embedder(x))
        # Mean pooling keeps the detector free of extra parameters.
        summary = tokens.mean(dim=1)
        logits = {}
        for axis_name, head in self.heads.items():
            logits[axis_name] = head(summary)
        return logits
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _value_to_index(axis: Axis, value: Any) -> int:
    """Return the position of ``value`` within ``axis.values``.

    Raises:
        ValueError: if ``value`` is UNKNOWN (the detector must learn a
            concrete value, so UNKNOWN cannot be a training label), or
            if ``value`` is not one of the axis's declared values.
    """
    if value == UNKNOWN:
        raise ValueError(
            f"axis {axis.name!r}: cannot train detector on UNKNOWN label"
        )
    try:
        return list(axis.values).index(value)
    except ValueError:
        # list.index only says "x is not in list"; name the axis and the
        # allowed values so the failure can be diagnosed.
        raise ValueError(
            f"axis {axis.name!r}: value {value!r} not in {axis.values}"
        ) from None
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def train_detector(
    *,
    detector: AxisDetector,
    prior: Any,
    steps: int = 500,
    batch_size: int = 16,
    lr: float = 1e-3,
    seed: int = 0,
    on_step: Any = None,
) -> dict[str, Any]:
    """Fit the detector on labelled batches drawn from ``prior``.

    Each step draws one concrete (never UNKNOWN) value per scored axis,
    asks ``prior.sample_batch`` for a batch generated under that tag,
    and minimises the cross-entropy between the detector's logits and
    the drawn values, summed over axes. ``on_step``, when given, is
    called as ``on_step(step, loss)`` after every optimiser step.

    Returns ``{"status": "skipped", "reason": ...}`` when torch is
    missing or the detector has no scored axes. Otherwise returns
    ``{"status": "ok", "steps", "final_loss", "last_batch_accuracy"}``,
    where ``last_batch_accuracy`` maps each axis name to the accuracy
    on the final batch — a cheap check that training did something.
    """
    try:
        import torch
        import torch.nn.functional as F
    except ImportError:
        return {"status": "skipped", "reason": "torch not installed"}

    axes = detector.scored_axes
    if not axes:
        return {"status": "skipped", "reason": "detector has no scored axes"}

    rng = np.random.default_rng(seed)
    optim = torch.optim.AdamW(list(detector.parameters()), lr=lr)

    accuracy: dict[str, float] = {}
    latest_loss = 0.0
    last_step = steps - 1
    for step in range(steps):
        # The drawn tag doubles as the supervised label for this batch.
        tag = {}
        for axis in axes:
            tag[axis.name] = rng.choice(list(axis.values))

        samples = prior.sample_batch(
            batch_size=batch_size, seed=seed + step * batch_size, tag=tag
        )
        X = torch.stack([torch.from_numpy(s["X"]).float() for s in samples])

        logits = detector(X)  # axis name -> logits
        n_rows = X.shape[0]
        loss = torch.zeros(1)
        for axis in axes:
            # Every row of the batch shares the one tag, hence one label.
            target = torch.full(
                (n_rows,), _value_to_index(axis, tag[axis.name]), dtype=torch.long
            )
            loss = loss + F.cross_entropy(logits[axis.name], target)

        optim.zero_grad()
        loss.backward()
        optim.step()

        latest_loss = float(loss.item())
        if on_step is not None:
            on_step(step, latest_loss)

        if step != last_step:
            continue

        # Per-axis accuracy on the final batch, from the logits computed
        # before this step's parameter update.
        with torch.no_grad():
            for axis in axes:
                predicted = logits[axis.name].argmax(dim=-1)
                expected = _value_to_index(axis, tag[axis.name])
                accuracy[axis.name] = float((predicted == expected).float().mean().item())

    return {
        "status": "ok",
        "steps": steps,
        "final_loss": latest_loss,
        "last_batch_accuracy": accuracy,
    }
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def detect(
    *,
    detector: AxisDetector,
    context: Any,
) -> dict[str, dict[str, Any]]:
    """Run the detector on ``context`` and propose one value per axis.

    ``context`` may be a torch tensor or anything ``numpy.asarray``
    accepts; a 2-D input gets a leading batch dimension. When several
    contexts are batched, their probabilities are averaged so a single
    proposal is returned per axis.

    The forward pass runs with the detector's torch modules in eval
    mode (so dropout is inactive and the proposal is deterministic);
    the modes they had before are restored afterwards.

    Returns ``{axis_name: {"value", "confidence", "probs"}}`` where
    ``value`` is the most probable axis value, ``confidence`` is its
    probability, and ``probs`` maps every axis value to its
    probability. Returns ``{}`` if torch is unavailable or the detector
    has no scored axes.
    """
    try:
        import torch
        import torch.nn.functional as F
    except ImportError:
        return {}

    if not detector.scored_axes:
        return {}

    if not torch.is_tensor(context):
        context = torch.from_numpy(np.asarray(context)).float()
    if context.dim() == 2:
        context = context.unsqueeze(0)  # add batch dim

    # torch modules start in training mode and nothing else switches
    # them, so do it here. Remember each sub-module's mode so inference
    # leaves no lasting side effect on a detector that is still training.
    tracked = []
    for attr in ("embedder", "encoder", "heads"):
        root = getattr(detector, attr, None)
        if isinstance(root, torch.nn.Module):
            tracked.extend((sub, sub.training) for sub in root.modules())
    for sub, _ in tracked:
        sub.eval()
    try:
        with torch.no_grad():
            logits = detector(context)
    finally:
        for sub, was_training in tracked:
            sub.training = was_training

    out: dict[str, dict[str, Any]] = {}
    for axis in detector.scored_axes:
        # Average over the batch: one proposal per axis, not per sample.
        mean_probs = F.softmax(logits[axis.name], dim=-1).mean(dim=0)
        best_idx = int(mean_probs.argmax().item())
        as_floats = mean_probs.tolist()
        out[axis.name] = {
            "value": axis.values[best_idx],
            "confidence": as_floats[best_idx],
            "probs": dict(zip(axis.values, as_floats)),
        }
    return out
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
"""Tag encoding — fixed-length vector representation of an axis tag.
|
|
2
|
+
|
|
3
|
+
The encoding rules:
|
|
4
|
+
|
|
5
|
+
- **Categorical axes** → one-hot over the declared values, plus one
|
|
6
|
+
extra slot for the UNKNOWN sentinel (always last). So a categorical
|
|
7
|
+
axis with K values contributes K+1 dimensions.
|
|
8
|
+
- **Boolean axes** → 2 dims (false, true) plus 1 for UNKNOWN. Same
|
|
9
|
+
shape as a 2-value categorical.
|
|
10
|
+
- **Range axes** → 2 dims: ``[is_unknown_flag, normalised_value]``.
|
|
11
|
+
When tagged with a real value v in [min, max], normalised_value is
|
|
12
|
+
``(v - min) / (max - min)`` clipped to [0, 1] and is_unknown_flag
|
|
13
|
+
is 0. When UNKNOWN, normalised_value is 0 and is_unknown_flag is 1.
|
|
14
|
+
|
|
15
|
+
The encoding is deterministic and stateless — the same axes list and
|
|
16
|
+
the same tag always produce the same vector. ``tag_dim(axes)`` returns
|
|
17
|
+
the total length so model constructors can size their tag embedder
|
|
18
|
+
once at init time without poking at sample tags.
|
|
19
|
+
|
|
20
|
+
Back-compat invariant: ``encode_tag({}, axes)`` and ``encode_tag(None,
|
|
21
|
+
axes)`` are byte-identical to ``encode_tag({a: UNKNOWN for a in axes},
|
|
22
|
+
axes)``. Used by callers who need a "no constraints" tag vector
|
|
23
|
+
without enumerating all axes.
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from __future__ import annotations
|
|
27
|
+
|
|
28
|
+
from typing import Any
|
|
29
|
+
|
|
30
|
+
import numpy as np
|
|
31
|
+
|
|
32
|
+
from .base import UNKNOWN, Axis, is_unknown
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _axis_dim(axis: Axis) -> int:
|
|
36
|
+
if axis.kind == "categorical":
|
|
37
|
+
return len(axis.values) + 1 # values + UNKNOWN slot
|
|
38
|
+
if axis.kind == "boolean":
|
|
39
|
+
return 3 # false, true, UNKNOWN
|
|
40
|
+
if axis.kind == "range":
|
|
41
|
+
return 2 # is_unknown_flag, normalised_value
|
|
42
|
+
raise ValueError(f"unknown axis kind: {axis.kind!r}")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def tag_dim(axes: list[Axis]) -> int:
    """Return the length of the vector ``encode_tag`` builds for ``axes``.

    This is the sum of the per-axis widths, taken in the given order.
    """
    total = 0
    for axis in axes:
        total += _axis_dim(axis)
    return total
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _encode_categorical(axis: Axis, value: Any) -> np.ndarray:
    """One-hot encode ``value`` over the axis values, with UNKNOWN last.

    ``None`` and the UNKNOWN sentinel both light the final slot.

    Raises:
        ValueError: if ``value`` is neither unknown nor a declared value.
    """
    width = _axis_dim(axis)
    vec = np.zeros(width, dtype=np.float32)
    if value is not None and not is_unknown(value):
        if value not in axis.values:
            raise ValueError(
                f"axis {axis.name!r}: value {value!r} not in {axis.values}"
            )
        vec[axis.values.index(value)] = 1.0
    else:
        vec[-1] = 1.0
    return vec
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _encode_boolean(axis: Axis, value: Any) -> np.ndarray:
|
|
68
|
+
out = np.zeros(3, dtype=np.float32)
|
|
69
|
+
if value is None or is_unknown(value):
|
|
70
|
+
out[2] = 1.0
|
|
71
|
+
return out
|
|
72
|
+
out[int(bool(value))] = 1.0
|
|
73
|
+
return out
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _encode_range(axis: Axis, value: Any) -> np.ndarray:
|
|
77
|
+
out = np.zeros(2, dtype=np.float32)
|
|
78
|
+
if value is None or is_unknown(value):
|
|
79
|
+
out[0] = 1.0 # is_unknown_flag
|
|
80
|
+
return out
|
|
81
|
+
lo, hi = float(axis.values[0]), float(axis.values[1])
|
|
82
|
+
span = hi - lo
|
|
83
|
+
if span == 0:
|
|
84
|
+
out[1] = 0.0
|
|
85
|
+
else:
|
|
86
|
+
normalised = (float(value) - lo) / span
|
|
87
|
+
out[1] = float(np.clip(normalised, 0.0, 1.0))
|
|
88
|
+
return out
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def encode_tag(tag: dict[str, Any] | None, axes: list[Axis]) -> np.ndarray:
    """Turn a tag dict into one flat float32 vector covering ``axes``.

    Axes are encoded in the order given and the pieces concatenated, so
    the length equals ``tag_dim(axes)``. An axis missing from ``tag``
    (or a ``tag`` of ``None``) is encoded as UNKNOWN.

    Raises:
        ValueError: if an axis has an unrecognised kind.
    """
    provided = tag or {}
    pieces: list[np.ndarray] = []
    for axis in axes:
        kind = axis.kind
        value = provided.get(axis.name, UNKNOWN)
        if kind == "range":
            encode = _encode_range
        elif kind == "boolean":
            encode = _encode_boolean
        elif kind == "categorical":
            encode = _encode_categorical
        else:
            raise ValueError(f"unknown axis kind: {axis.kind!r}")
        pieces.append(encode(axis, value))

    if pieces:
        return np.concatenate(pieces, axis=0)
    # No axes at all: return an empty vector of the right dtype.
    return np.zeros(0, dtype=np.float32)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def sample_tag(axes: list[Axis], rng: np.random.Generator) -> dict[str, Any]:
    """Draw one value per axis and return them as ``axis_name -> value``.

    Each value comes from that axis's ``sample_value(rng)``, so it may
    be the UNKNOWN sentinel according to the axis's ``unknown_mass``.
    Used by the training loop to draw a fresh tag per batch.
    """
    drawn: dict[str, Any] = {}
    for axis in axes:
        drawn[axis.name] = axis.sample_value(rng)
    return drawn
|