datadoom 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datadoom/__init__.py +23 -0
- datadoom/adapters/__init__.py +29 -0
- datadoom/adapters/frameworks.py +94 -0
- datadoom/adapters/loaders.py +72 -0
- datadoom/api/__init__.py +11 -0
- datadoom/api/app.py +109 -0
- datadoom/api/deps.py +30 -0
- datadoom/api/errors.py +89 -0
- datadoom/api/estimate.py +82 -0
- datadoom/api/routes/__init__.py +7 -0
- datadoom/api/routes/artifacts.py +147 -0
- datadoom/api/routes/datasets.py +180 -0
- datadoom/api/routes/meta.py +45 -0
- datadoom/api/routes/plugins.py +22 -0
- datadoom/api/routes/runs.py +144 -0
- datadoom/api/routes/specs.py +73 -0
- datadoom/api/routes/templates.py +30 -0
- datadoom/api/schemas.py +230 -0
- datadoom/api/serializers.py +143 -0
- datadoom/api/state.py +24 -0
- datadoom/api/store_helpers.py +56 -0
- datadoom/api/ws.py +72 -0
- datadoom/cli/__init__.py +1 -0
- datadoom/cli/main.py +313 -0
- datadoom/config.py +108 -0
- datadoom/engine/__init__.py +38 -0
- datadoom/engine/advice.py +289 -0
- datadoom/engine/audit.py +290 -0
- datadoom/engine/causal/__init__.py +15 -0
- datadoom/engine/causal/execute.py +116 -0
- datadoom/engine/causal/functions.py +116 -0
- datadoom/engine/causal/graph.py +54 -0
- datadoom/engine/difficulty/__init__.py +36 -0
- datadoom/engine/difficulty/calibrate.py +235 -0
- datadoom/engine/difficulty/knobs.py +171 -0
- datadoom/engine/difficulty/probes.py +181 -0
- datadoom/engine/dist/__init__.py +35 -0
- datadoom/engine/dist/base.py +46 -0
- datadoom/engine/dist/builtins.py +172 -0
- datadoom/engine/dist/compliance.py +344 -0
- datadoom/engine/dist/providers.py +117 -0
- datadoom/engine/errors.py +32 -0
- datadoom/engine/export/__init__.py +27 -0
- datadoom/engine/export/base.py +49 -0
- datadoom/engine/export/checksums.py +18 -0
- datadoom/engine/export/csv_exporter.py +34 -0
- datadoom/engine/export/json_exporter.py +67 -0
- datadoom/engine/export/metadata.py +58 -0
- datadoom/engine/export/parquet_exporter.py +45 -0
- datadoom/engine/failure/__init__.py +18 -0
- datadoom/engine/failure/apply.py +37 -0
- datadoom/engine/failure/base.py +116 -0
- datadoom/engine/failure/modes.py +442 -0
- datadoom/engine/pipeline.py +418 -0
- datadoom/engine/profile.py +327 -0
- datadoom/engine/progress.py +14 -0
- datadoom/engine/reference.py +338 -0
- datadoom/engine/reports.py +206 -0
- datadoom/engine/rng.py +79 -0
- datadoom/engine/spec/__init__.py +45 -0
- datadoom/engine/spec/hashing.py +57 -0
- datadoom/engine/spec/models.py +238 -0
- datadoom/engine/spec/validate.py +345 -0
- datadoom/engine/timeseries.py +88 -0
- datadoom/jobs/__init__.py +14 -0
- datadoom/jobs/progress.py +155 -0
- datadoom/jobs/worker.py +162 -0
- datadoom/plugin.py +35 -0
- datadoom/plugins/__init__.py +47 -0
- datadoom/plugins/contracts.py +72 -0
- datadoom/plugins/loader.py +125 -0
- datadoom/plugins/registry.py +214 -0
- datadoom/plugins/scaffold.py +434 -0
- datadoom/store/__init__.py +47 -0
- datadoom/store/artifacts.py +67 -0
- datadoom/store/db.py +104 -0
- datadoom/store/migrations/__init__.py +0 -0
- datadoom/store/migrations/env.py +53 -0
- datadoom/store/migrations/script.py.mako +24 -0
- datadoom/store/migrations/versions/0001_init.py +149 -0
- datadoom/store/migrations/versions/0002_report_mutual_information.py +23 -0
- datadoom/store/migrations/versions/0003_run_name.py +23 -0
- datadoom/store/migrations/versions/0004_report_profile.py +24 -0
- datadoom/store/models.py +170 -0
- datadoom/store/repositories.py +279 -0
- datadoom/templates/__init__.py +239 -0
- datadoom/templates/ab_test.datadoom.yaml +46 -0
- datadoom/templates/clinical_deterioration.datadoom.yaml +124 -0
- datadoom/templates/credit_default_challenge.datadoom.yaml +147 -0
- datadoom/templates/customer_churn.datadoom.yaml +60 -0
- datadoom/templates/ecommerce_orders.datadoom.yaml +46 -0
- datadoom/templates/fraud_detection.datadoom.yaml +57 -0
- datadoom/templates/hospital_readmission.datadoom.yaml +61 -0
- datadoom/templates/insurance_claims.datadoom.yaml +43 -0
- datadoom/templates/iot_sensors.datadoom.yaml +44 -0
- datadoom/templates/people_directory.datadoom.yaml +56 -0
- datadoom/templates/predictive_maintenance.datadoom.yaml +107 -0
- datadoom/templates/telecom_churn_challenge.datadoom.yaml +125 -0
- datadoom/version.py +3 -0
- datadoom/webdist/assets/index-V8VAuTJG.js +445 -0
- datadoom/webdist/assets/index-doRjyG5s.css +1 -0
- datadoom/webdist/assets/inter-cyrillic-ext-wght-normal-BOeWTOD4.woff2 +0 -0
- datadoom/webdist/assets/inter-cyrillic-wght-normal-DqGufNeO.woff2 +0 -0
- datadoom/webdist/assets/inter-greek-ext-wght-normal-DlzME5K_.woff2 +0 -0
- datadoom/webdist/assets/inter-greek-wght-normal-CkhJZR-_.woff2 +0 -0
- datadoom/webdist/assets/inter-latin-ext-wght-normal-DO1Apj_S.woff2 +0 -0
- datadoom/webdist/assets/inter-latin-wght-normal-Dx4kXJAl.woff2 +0 -0
- datadoom/webdist/assets/inter-vietnamese-wght-normal-CBcvBZtf.woff2 +0 -0
- datadoom/webdist/assets/jetbrains-mono-cyrillic-wght-normal-D73BlboJ.woff2 +0 -0
- datadoom/webdist/assets/jetbrains-mono-greek-wght-normal-Bw9x6K1M.woff2 +0 -0
- datadoom/webdist/assets/jetbrains-mono-latin-ext-wght-normal-DBQx-q_a.woff2 +0 -0
- datadoom/webdist/assets/jetbrains-mono-latin-wght-normal-B9CIFXIH.woff2 +0 -0
- datadoom/webdist/assets/jetbrains-mono-vietnamese-wght-normal-Bt-aOZkq.woff2 +0 -0
- datadoom/webdist/assets/space-grotesk-latin-ext-wght-normal-D9tNdqV9.woff2 +0 -0
- datadoom/webdist/assets/space-grotesk-latin-wght-normal-BhU9QXUp.woff2 +0 -0
- datadoom/webdist/assets/space-grotesk-vietnamese-wght-normal-D0rl6rjA.woff2 +0 -0
- datadoom/webdist/index.html +15 -0
- datadoom-0.1.0.dev0.dist-info/METADATA +143 -0
- datadoom-0.1.0.dev0.dist-info/RECORD +122 -0
- datadoom-0.1.0.dev0.dist-info/WHEEL +4 -0
- datadoom-0.1.0.dev0.dist-info/entry_points.txt +2 -0
- datadoom-0.1.0.dev0.dist-info/licenses/LICENSE +202 -0
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
"""SEM execution: topological walk applying structural equations (05 §3).
|
|
2
|
+
|
|
3
|
+
Runs after base (root) features are sampled. For each derived node in
|
|
4
|
+
topological order:
|
|
5
|
+
|
|
6
|
+
v = Σ_e fn_e(parent_e) + ε_v (ε_v ~ noise[v] via RNG(noise:v))
|
|
7
|
+
|
|
8
|
+
Boolean children interpret the summed contribution as a probability and draw a
|
|
9
|
+
Bernoulli outcome from ``RNG(feature:v)`` (05 §3). Interventions ``do(X=x₀)``
|
|
10
|
+
fix a node to a constant and skip its equation; because the walk is topological,
|
|
11
|
+
descendants automatically see the intervened value.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
from typing import TYPE_CHECKING, Any
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
from ..errors import SpecValidationError
|
|
21
|
+
from ..spec.models import BooleanFeature, NumericFeature
|
|
22
|
+
from .functions import STRUCTURAL_FNS
|
|
23
|
+
from .graph import CausalDag
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING: # avoid a runtime import cycle with pipeline
|
|
26
|
+
from ..pipeline import RunContext
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def resolve_interventions(interventions: list[dict[str, Any]]) -> dict[str, float]:
    """Flatten ``[{do: {X: x0}}, ...]`` into ``{X: x0}`` (last wins).

    Args:
        interventions: Intervention entries; each may carry a ``do`` mapping
            of node name to the constant that node is fixed at. Entries with
            no ``do`` key, or with an empty/``None`` ``do`` value, contribute
            nothing.

    Returns:
        One ``{node: value}`` mapping with every value coerced through
        ``float``. When a node appears in several entries the later one wins.

    Raises:
        TypeError, ValueError: If a value cannot be converted by ``float``.
    """
    fixed: dict[str, float] = {}
    for item in interventions or ():
        # A ``do:`` key left empty (e.g. a bare YAML key) arrives as ``None``;
        # treat it like a missing key instead of crashing on ``None.items()``.
        do = item.get("do") or {}
        for node, value in do.items():
            fixed[node] = float(value)
    return fixed
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def build_dag(ctx: RunContext) -> CausalDag:
    """Construct the :class:`CausalDag` for the run's spec.

    The spec must carry a ``causal`` section. Edges are taken from it and the
    node set passed to the DAG is every feature name declared by the spec.
    """
    causal = ctx.spec.causal
    assert causal is not None
    edges = causal.edges
    feature_names = list(ctx.spec.features)
    return CausalDag(edges, feature_names)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def execute_causal(ctx: RunContext, columns: dict[str, np.ndarray]) -> CausalDag:
    """Walk the DAG in topological order and fill derived columns in-place.

    For each node, in the order ``dag.topological_order()`` yields:

    * an intervened node is overwritten with its constant and its structural
      equation is skipped;
    * a node without incoming edges is left untouched;
    * any other node becomes the sum of its incoming-edge contributions (in
      the order ``dag.in_edges`` returns them) plus optional node noise,
      post-processed by ``_finalize``.

    Returns the DAG (for the report's true graph).
    """
    spec = ctx.spec
    assert spec.causal is not None
    row_count = spec.rows
    dag = build_dag(ctx)
    fixed = resolve_interventions(spec.causal.interventions)

    for name in dag.topological_order():
        if name in fixed:
            # do(X=x0): pin the column; descendants read the pinned value.
            columns[name] = _materialize_constant(spec.features[name], fixed[name], row_count)
            continue

        incoming = dag.in_edges(name)
        if not incoming:
            # root feature — already sampled in base_generation
            continue

        total = np.zeros(row_count, dtype=float)
        for edge in incoming:
            structural_fn = STRUCTURAL_FNS[edge.fn]
            total = total + structural_fn.contribution(columns[edge.src], edge)

        noise = _draw_noise(ctx, name, row_count)
        if noise is not None:
            total = total + noise

        columns[name] = _finalize(ctx, name, spec.features[name], total)

    return dag
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _draw_noise(ctx: RunContext, node: str, n: int) -> np.ndarray | None:
    """Sample node noise ε_v from RNG(noise:v), or return ``None``.

    ``None`` is returned when the node has no noise entry, when the entry has
    no ``dist``, or when ``dist`` is the literal ``"none"``. Otherwise the
    named distribution is looked up in the built-in registry, the
    ``noise:<node>`` namespace is recorded on the context, and ``n`` values
    are sampled using the entry's ``params`` (default ``{}``).
    """
    from ..dist.builtins import REGISTRY

    assert ctx.spec.causal is not None
    noise_cfg = ctx.spec.causal.noise.get(node)
    dist_name = None if noise_cfg is None else noise_cfg.get("dist")
    if dist_name is None or dist_name == "none":
        return None

    # Look the sampler up before recording the namespace so an unknown
    # distribution name does not leave a stale namespace entry behind.
    sampler = REGISTRY[dist_name]
    ctx.used_namespaces.append(f"noise:{node}")
    generator = ctx.rng.noise(node)
    return sampler.sample(generator, n, noise_cfg.get("params", {}))
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _finalize(ctx: RunContext, node: str, feat: Any, contrib: np.ndarray) -> np.ndarray:
    """Turn a node's summed contribution into its final column.

    * Boolean feature: the contribution is clipped to ``[0, 1]``, read as a
      per-row probability, and compared against uniform draws from
      ``ctx.rng.feature(node)``; the ``feature:<node>`` namespace is recorded.
    * Numeric feature: the contribution is clipped to the feature's
      ``min``/``max`` when either is set, and rounded to ``int64`` when the
      feature's dtype is ``"int"``.

    Raises:
        SpecValidationError: For any other feature type.
    """
    if isinstance(feat, BooleanFeature):
        probability = np.clip(contrib, 0.0, 1.0)
        ctx.used_namespaces.append(f"feature:{node}")
        uniform = ctx.rng.feature(node).random(size=len(contrib))
        return uniform < probability

    if not isinstance(feat, NumericFeature):
        raise SpecValidationError(
            f"feature {node!r} is a causal target but type {feat.type!r} cannot be derived "
            "(only numeric and boolean targets are supported)",
            locator=f"features.{node}",
        )

    values = contrib
    if feat.min is not None or feat.max is not None:
        # A missing bound becomes ±inf so a single clip handles all cases.
        lower = -np.inf if feat.min is None else feat.min
        upper = np.inf if feat.max is None else feat.max
        values = np.clip(values, lower, upper)
    if feat.dtype == "int":
        values = np.rint(values).astype("int64")
    return values
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _materialize_constant(feat: Any, value: float, n: int) -> np.ndarray:
    """Build the length-``n`` constant column for an intervened node.

    Boolean features get ``bool(value)``, integer-typed numeric features get
    the rounded value as ``int64``, and everything else gets ``float(value)``.
    """
    if isinstance(feat, BooleanFeature):
        fill, dtype = bool(value), None
    elif isinstance(feat, NumericFeature) and feat.dtype == "int":
        fill, dtype = int(round(value)), "int64"
    else:
        fill, dtype = float(value), None
    # dtype=None lets numpy infer the dtype from the fill value.
    return np.full(n, fill, dtype=dtype)
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
"""Structural functions for SEM edges (05 §3, 04 §5).
|
|
2
|
+
|
|
3
|
+
Each edge carries a structural function ``fn`` that maps a parent's values to a
|
|
4
|
+
numeric *contribution*. A derived node sums the contributions of its incoming
|
|
5
|
+
edges and adds node noise (see ``execute.py``). Functions are pure and operate
|
|
6
|
+
on numpy arrays so they stay deterministic on the pinned path.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from abc import ABC, abstractmethod
|
|
12
|
+
from collections.abc import Mapping
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
from ..errors import SpecValidationError
|
|
17
|
+
from ..spec.models import CausalEdge
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _as_float(parent: np.ndarray) -> np.ndarray:
    """Return ``parent`` as a float64 array (booleans become 0.0 / 1.0)."""
    return np.asarray(parent, dtype=np.float64)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class StructuralFn(ABC):
    """Interface implemented by every edge structural function.

    Subclasses set ``name`` and implement :meth:`contribution`; they may
    override :meth:`validate` to reject edges that lack required params.
    """

    name: str
    # Optional JSON-schema fragment for the edge params (09 §6); ``None`` for
    # built-ins (the Graph inspector renders their native controls).
    param_schema: Mapping[str, object] | None = None

    @abstractmethod
    def contribution(self, parent: np.ndarray, edge: CausalEdge) -> np.ndarray:
        """Return this edge's additive contribution to the child node."""

    def validate(self, edge: CausalEdge, locator: str) -> None:
        """Check the edge carries the params this function needs.

        The base implementation accepts every edge and returns ``None``.
        """
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class Linear(StructuralFn):
    """Affine edge: ``weight * parent + bias`` (a falsy bias counts as 0)."""

    name = "linear"

    def contribution(self, parent, edge):
        """Return ``edge.weight * parent + bias`` element-wise."""
        offset = edge.bias or 0.0
        scaled = edge.weight * _as_float(parent)
        return scaled + offset

    def validate(self, edge, locator):
        """Raise ``SpecValidationError`` when the edge has no ``weight``."""
        if edge.weight is not None:
            return
        raise SpecValidationError("linear edge requires 'weight'", locator=locator)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class Logistic(StructuralFn):
    """Sigmoid edge: ``1 / (1 + exp(-(weight * parent + bias)))``."""

    name = "logistic"

    def contribution(self, parent, edge):
        """Return the logistic of ``edge.weight * parent + bias`` element-wise.

        A falsy ``edge.bias`` (``None`` or 0) is treated as 0.
        """
        bias = edge.bias or 0.0
        z = edge.weight * _as_float(parent) + bias
        # For very negative z, exp(-z) overflows to +inf and the quotient is
        # exactly 0.0 — the correct limit. Silence only that benign overflow:
        # the numbers are bit-identical, but no RuntimeWarning leaks to the
        # caller (or turns into an error under warnings-as-errors).
        with np.errstate(over="ignore"):
            return 1.0 / (1.0 + np.exp(-z))

    def validate(self, edge, locator):
        """Raise ``SpecValidationError`` when the edge has no ``weight``."""
        if edge.weight is None:
            raise SpecValidationError("logistic edge requires 'weight'", locator=locator)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class Polynomial(StructuralFn):
    """Polynomial edge: ``sum(coeffs[i] * parent**i)`` over the coeff list."""

    name = "polynomial"

    def contribution(self, parent, edge):
        """Evaluate the polynomial term by term, lowest power first.

        Terms are accumulated in coefficient order so the floating-point sum
        is a fixed function of the edge. Missing/empty coeffs yield zeros.
        """
        base = _as_float(parent)
        total = np.zeros_like(base)
        for power, coeff in enumerate(edge.coeffs or ()):
            term = coeff * (base**power)
            total = total + term
        return total

    def validate(self, edge, locator):
        """Raise ``SpecValidationError`` when ``coeffs`` is missing or empty."""
        if edge.coeffs:
            return
        raise SpecValidationError(
            "polynomial edge requires a non-empty 'coeffs' list", locator=locator
        )
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class Map(StructuralFn):
    """Lookup edge: each parent value contributes the number mapped to it.

    Parent values are matched against the mapping by their ``str()`` form.
    """

    name = "map"

    def contribution(self, parent, edge):
        """Return the mapped contribution for every parent value.

        Raises:
            SpecValidationError: At the first parent value whose string form
                has no entry in the edge's mapping.
        """
        table = edge.mapping or {}
        result = np.empty(len(parent), dtype=float)
        for position, raw in enumerate(parent):
            label = str(raw)
            if label in table:
                result[position] = table[label]
                continue
            raise SpecValidationError(
                f"map edge has no mapping for category {label!r}",
                locator=f"causal edge {edge.src}->{edge.dst}",
            )
        return result

    def validate(self, edge, locator):
        """Raise ``SpecValidationError`` when ``mapping`` is missing or empty."""
        if edge.mapping:
            return
        raise SpecValidationError(
            "map edge requires a non-empty 'mapping'", locator=locator
        )
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class Identity(StructuralFn):
    """Pass-through edge: the parent's values, coerced to float, unchanged."""

    name = "identity"

    def contribution(self, parent, edge):
        """Return the parent column as floats; ``edge`` params are unused."""
        passthrough = _as_float(parent)
        return passthrough
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
# Registry of the built-in structural functions: one shared instance per
# function, keyed by the instance's ``name``.
STRUCTURAL_FNS: dict[str, StructuralFn] = dict(
    (fn.name, fn) for fn in (Linear(), Logistic(), Polynomial(), Map(), Identity())
)
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Causal DAG construction & deterministic traversal (05 §3, 17 step 11).
|
|
2
|
+
|
|
3
|
+
Wraps a ``networkx.DiGraph`` built from the spec's causal edges. The engine is
|
|
4
|
+
pure, so graph operations are deterministic: nodes iterate in **sorted** order
|
|
5
|
+
and the topological walk is **lexicographical**, so the SEM execution order is a
|
|
6
|
+
function of the spec alone (never of dict/hash iteration order).
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import networkx as nx
|
|
12
|
+
|
|
13
|
+
from ..errors import SpecValidationError
|
|
14
|
+
from ..spec.models import CausalEdge
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class CausalDag:
    """A validated causal DAG over the spec's feature names.

    Cycle detection is defensive — ``validate_spec`` already rejects cycles, but
    constructing the graph re-checks so the engine never executes a bad walk.
    For the same reason an edge whose destination is not a declared feature is
    rejected here with a ``SpecValidationError`` rather than a bare ``KeyError``.
    """

    def __init__(self, edges: list[CausalEdge], feature_names: list[str]) -> None:
        """Build the graph from ``edges`` over the nodes in ``feature_names``.

        Raises:
            SpecValidationError: If an edge targets a name that is not in
                ``feature_names``, or if the resulting graph has a cycle.
        """
        self._graph: nx.DiGraph = nx.DiGraph()
        self._graph.add_nodes_from(sorted(feature_names))
        # Author order is preserved per destination so summation order (and thus
        # floating-point results) is a stable function of the spec.
        self._in_edges: dict[str, list[CausalEdge]] = {n: [] for n in feature_names}
        for edge in edges:
            if edge.dst not in self._in_edges:
                # Previously surfaced as an opaque KeyError from the append below.
                raise SpecValidationError(
                    f"causal edge {edge.src}->{edge.dst} targets unknown feature {edge.dst!r}",
                    locator="causal.edges",
                )
            self._graph.add_edge(edge.src, edge.dst)
            self._in_edges[edge.dst].append(edge)
        if not nx.is_directed_acyclic_graph(self._graph):
            raise SpecValidationError("causal graph is not acyclic", locator="causal.edges")

    def topological_order(self) -> list[str]:
        """Deterministic topological order (lexicographical tie-breaking)."""
        return list(nx.lexicographical_topological_sort(self._graph))

    def in_edges(self, node: str) -> list[CausalEdge]:
        """Incoming edges for ``node`` in author order (empty for roots)."""
        return self._in_edges.get(node, [])

    def parents(self, node: str) -> list[str]:
        """Sorted parent feature names of ``node``."""
        return sorted(self._graph.predecessors(node))

    def is_derived(self, node: str) -> bool:
        """True if ``node`` has at least one incoming edge (computed, not sampled)."""
        return bool(self._in_edges.get(node))

    def derived_nodes(self) -> set[str]:
        """All nodes that are causal targets (have ≥1 incoming edge)."""
        return {n for n, e in self._in_edges.items() if e}
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""Difficulty targeting: empirical calibration to a baseline-metric band (05 §5).
|
|
2
|
+
|
|
3
|
+
A dataset's difficulty is defined *operationally* by the score a standard probe
|
|
4
|
+
model achieves on it. ``calibrate_difficulty`` runs an adaptive bisection over a
|
|
5
|
+
single :class:`~datadoom.engine.difficulty.knobs.DifficultyDial` until the probe
|
|
6
|
+
metric lands in the target band, then returns the calibrated frame to ship and a
|
|
7
|
+
report of what was achieved (honestly, including misses).
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from .calibrate import (
|
|
13
|
+
TIER_BANDS,
|
|
14
|
+
DifficultyResult,
|
|
15
|
+
Target,
|
|
16
|
+
calibrate_difficulty,
|
|
17
|
+
resolve_target,
|
|
18
|
+
)
|
|
19
|
+
from .knobs import ACTIVE_KNOBS, DIAL_MAX, DifficultyDial, KnobState
|
|
20
|
+
from .probes import PROBES, ProbeModel, ProbeResult, evaluate
|
|
21
|
+
|
|
22
|
+
__all__ = [
|
|
23
|
+
"ACTIVE_KNOBS",
|
|
24
|
+
"DIAL_MAX",
|
|
25
|
+
"DifficultyDial",
|
|
26
|
+
"DifficultyResult",
|
|
27
|
+
"KnobState",
|
|
28
|
+
"PROBES",
|
|
29
|
+
"ProbeModel",
|
|
30
|
+
"ProbeResult",
|
|
31
|
+
"TIER_BANDS",
|
|
32
|
+
"Target",
|
|
33
|
+
"calibrate_difficulty",
|
|
34
|
+
"evaluate",
|
|
35
|
+
"resolve_target",
|
|
36
|
+
]
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
"""Adaptive difficulty calibration: bisection on the dial (05 §5.2, 17 step 15).
|
|
2
|
+
|
|
3
|
+
The loop measures the probe metric μ on the realized frame and turns the single
|
|
4
|
+
:class:`DifficultyDial` until μ lands in the target band ``[a, b]`` (from an
|
|
5
|
+
explicit band or a named tier). Because μ(d) is monotone non-increasing in the
|
|
6
|
+
dial, the search is a clean bisection:
|
|
7
|
+
|
|
8
|
+
generate → evaluate μ
|
|
9
|
+
μ ∈ [a, b] → success
|
|
10
|
+
μ > b → too easy → turn the dial up (more noise)
|
|
11
|
+
μ < a → too hard → turn the dial down
|
|
12
|
+
|
|
13
|
+
Honest failure (invariant #3): if even the pristine data is already below the
|
|
14
|
+
band (no easing knob exists) or the maximum dial can't push μ down into the band
|
|
15
|
+
(signal too strong for the available knobs), the loop returns the *closest*
|
|
16
|
+
achieved point and flags ``band_met = False`` with a plain-language note — never
|
|
17
|
+
a silent miss.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
from dataclasses import dataclass
|
|
23
|
+
from typing import TYPE_CHECKING, Any
|
|
24
|
+
|
|
25
|
+
import pandas as pd
|
|
26
|
+
|
|
27
|
+
from .knobs import ACTIVE_KNOBS, DIAL_MAX, DifficultyDial
|
|
28
|
+
from .probes import PROBES, ProbeResult, evaluate
|
|
29
|
+
|
|
30
|
+
if TYPE_CHECKING:
|
|
31
|
+
from ..pipeline import RunContext
|
|
32
|
+
|
|
33
|
+
# Named tiers → validated AUROC bands for binary classification (05 §5.3). These
# are calibration *targets*, kept honest by the calibration test in tests/.
# Each value is a (lower, upper) bound on the probe metric.
TIER_BANDS: dict[str, tuple[float, float]] = dict(
    beginner=(0.90, 0.99),
    intermediate=(0.80, 0.90),
    advanced=(0.72, 0.80),
    kaggle=(0.62, 0.72),
)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class Target:
    """A concrete calibration target: which metric must land in which band."""

    # Task kind, e.g. "classification".
    task: str
    # Name of the probe metric the band applies to, e.g. "auroc".
    metric: str
    # (lower, upper) bounds the achieved metric should fall between.
    band: tuple[float, float]
    # Name of the tier the band came from; None for an explicit band.
    tier: str | None = None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass
class DifficultyResult:
    """Report of one difficulty-calibration run (what was asked, what was hit)."""

    # The resolved target (tier / task / metric / band).
    target: dict[str, Any]
    # Probe metric measured on the shipped frame, and that metric's name.
    achieved_metric: float
    metric_name: str
    # Name of the probe model used.
    probe: str
    # Number of probe evaluations performed.
    iterations: int
    # True only when the achieved metric landed inside the band.
    band_met: bool
    # Final dial position and the knob levels it corresponds to.
    dial: float
    feature_noise: float
    label_flip: float
    # Knobs the spec asked for vs. the subset actually used.
    knobs_requested: list[str]
    knobs_active: list[str]
    # Reference diagnostics for the shipped frame.
    reference: dict[str, Any]
    # One entry per probe evaluation.
    trace: list[dict[str, float]]
    # Plain-language explanation when the band was missed; otherwise None.
    note: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return the report as a plain dict, keys in declaration order.

        The copy is shallow: container values are the same objects held by
        this instance.
        """
        keys = (
            "target",
            "achieved_metric",
            "metric_name",
            "probe",
            "iterations",
            "band_met",
            "dial",
            "feature_noise",
            "label_flip",
            "knobs_requested",
            "knobs_active",
            "reference",
            "trace",
            "note",
        )
        return {key: getattr(self, key) for key in keys}
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@dataclass
class _Best:
    """Remembers the closest-to-band point seen, for honest fallback reporting."""

    # Dial position, probe result and band distance of the best point so far.
    dial: float = 0.0
    result: ProbeResult | None = None
    distance: float = float("inf")

    def offer(self, dial: float, result: ProbeResult, band: tuple[float, float]) -> None:
        """Record ``(dial, result)`` if it is strictly closer to ``band``.

        Distance is 0 inside the band, otherwise the gap to the nearer bound.
        Ties keep the earlier point.
        """
        lower, upper = band
        metric = result.metric
        if lower <= metric <= upper:
            gap = 0.0
        else:
            gap = min(abs(metric - lower), abs(metric - upper))
        if gap < self.distance:
            self.dial = dial
            self.result = result
            self.distance = gap
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def resolve_target(target: str | dict[str, Any]) -> Target:
    """Turn a tier name or an explicit-band dict into a :class:`Target`.

    A string is looked up in ``TIER_BANDS`` and yields a classification/AUROC
    target tagged with that tier. A dict must carry ``band`` (two numbers);
    ``task`` and ``metric`` default to ``"classification"`` and ``"auroc"``.
    """
    if isinstance(target, str):
        # Tier names are validated upstream, so a plain lookup is enough.
        return Target(
            task="classification",
            metric="auroc",
            band=TIER_BANDS[target],
            tier=target,
        )

    raw_band = target["band"]
    lower = float(raw_band[0])
    upper = float(raw_band[1])
    task = str(target.get("task", "classification"))
    metric = str(target.get("metric", "auroc"))
    return Target(task=task, metric=metric, band=(lower, upper), tier=None)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def calibrate_difficulty(
    ctx: RunContext, base_frame: pd.DataFrame
) -> tuple[DifficultyResult, pd.DataFrame]:
    """Run the adaptive loop; return the report + the calibrated frame to ship.

    The probe metric is measured at dial 0 first. If that is already at or
    below the band's upper bound the pristine frame is shipped (band met only
    if it is also at or above the lower bound). Otherwise the maximum dial is
    probed; if the band is bracketed, the dial is bisected for at most
    ``cfg.max_iters`` steps. Whenever the band is missed, ``band_met`` is
    ``False`` and ``note`` explains why.

    Args:
        ctx: Run context; its spec must have a ``difficulty`` section.
        base_frame: The frame handed to the difficulty dial.

    Returns:
        ``(report, frame)`` — the calibration report and the frame realized
        at the final dial position.
    """
    spec = ctx.spec
    assert spec.difficulty is not None
    cfg = spec.difficulty
    target = resolve_target(cfg.target)
    band = target.band
    probe = PROBES[cfg.probe]

    # Only knobs listed in ACTIVE_KNOBS are handed to the dial; the report
    # keeps both lists so dropped knobs stay visible.
    requested = list(cfg.knobs)
    active = [k for k in requested if k in ACTIVE_KNOBS]
    dial_obj = DifficultyDial(base_frame, cfg.label, ctx.rng, active, ctx.used_namespaces)

    # Probe seeds are derived from the run's RNG so the split/estimator are
    # reproducible without touching the data namespaces.
    split_seed = int(ctx.rng.probe("split").integers(0, 2**31 - 1))
    est_seed = int(ctx.rng.probe("estimator").integers(0, 2**31 - 1))
    ctx.used_namespaces.extend(["probe:split", "probe:estimator"])

    # One {"dial", "metric"} entry per probe evaluation, in call order.
    trace: list[dict[str, float]] = []

    def mu(dial: float) -> tuple[ProbeResult, pd.DataFrame]:
        # Realize the frame at this dial, score it with the probe (same seeds
        # every time), and log the rounded point to the trace.
        frame, _ = dial_obj.realize(dial)
        result = evaluate(probe, frame, cfg.label, split_seed=split_seed, est_seed=est_seed)
        trace.append({"dial": round(dial, 6), "metric": round(result.metric, 6)})
        return result, frame

    a, b = band
    best = _Best()

    # Baseline: the untouched data. These defaults are what ships unless a
    # later branch overrides them.
    r0, frame0 = mu(0.0)
    best.offer(0.0, r0, band)
    note: str | None = None
    band_met = False
    final_dial = 0.0
    final_result = r0
    final_frame = frame0

    if r0.metric <= b:
        # Pristine data is already at or below the upper bound.
        final_dial, final_result, final_frame = 0.0, r0, frame0
        if r0.metric >= a:
            band_met = True
        else:
            note = (
                "clean data is already harder than the target band; v0.1 has no "
                "easing knob, so the pristine dataset is shipped as-is"
            )
    else:
        # Too easy → push the dial up. Probe the hard end to bracket the root.
        rmax, framemax = mu(DIAL_MAX)
        best.offer(DIAL_MAX, rmax, band)
        if rmax.metric > b:
            note = (
                "maximum difficulty is still too easy for the target band; the "
                "label is too separable for the active knobs to obscure"
            )
            final_dial, final_result, final_frame = DIAL_MAX, rmax, framemax
        elif a <= rmax.metric <= b:
            band_met = True
            final_dial, final_result, final_frame = DIAL_MAX, rmax, framemax
        else:
            # Bracketed: μ(0) > b and μ(DIAL_MAX) < a. Bisect for the band.
            lo, hi = 0.0, DIAL_MAX
            for _ in range(cfg.max_iters):
                mid = (lo + hi) / 2.0
                rmid, framemid = mu(mid)
                best.offer(mid, rmid, band)
                if a <= rmid.metric <= b:
                    band_met = True
                    final_dial, final_result, final_frame = mid, rmid, framemid
                    break
                if rmid.metric > b:
                    lo = mid  # still too easy → harder
                else:
                    hi = mid  # too hard → easier
            else:
                # Exhausted iterations without landing in the band: ship closest.
                # Only the dial and result of the best point were kept, so the
                # frame is realized again here.
                # NOTE(review): relies on realize() returning the same frame
                # for the same dial — confirm against DifficultyDial.
                final_dial = best.dial
                final_result = best.result if best.result is not None else r0
                final_frame, _ = dial_obj.realize(final_dial)
                note = (
                    f"target band not reached within max_iters={cfg.max_iters}; "
                    "shipping the closest achieved difficulty"
                )

    # mu() discards the knob state, so realize once more to read it back.
    # NOTE(review): this repeats a realize() already done for final_dial;
    # the state could be captured from the earlier call if realize() has no
    # side effects — confirm before changing.
    state = dial_obj.realize(final_dial)[1]
    reference = {
        "linear_separability": final_result.linear_separability,
        "class_balance": final_result.class_balance,
        "noise_to_signal": dial_obj.noise_to_signal(state.feature_noise),
        "probe_features": final_result.n_features,
        "rows": dial_obj.n,
    }
    result = DifficultyResult(
        target={
            "tier": target.tier,
            "task": target.task,
            "metric": target.metric,
            "band": [a, b],
        },
        achieved_metric=final_result.metric,
        metric_name=final_result.metric_name,
        probe=cfg.probe,
        iterations=len(trace),
        band_met=band_met,
        dial=final_dial,
        feature_noise=state.feature_noise,
        label_flip=state.label_flip,
        knobs_requested=requested,
        knobs_active=active,
        reference=reference,
        trace=trace,
        note=note,
    )
    return result, final_frame
|