axobench 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- axobench-0.1.0/PKG-INFO +34 -0
- axobench-0.1.0/README.md +19 -0
- axobench-0.1.0/pyproject.toml +37 -0
- axobench-0.1.0/setup.cfg +4 -0
- axobench-0.1.0/src/axobench/__init__.py +5 -0
- axobench-0.1.0/src/axobench/benchmark/__init__.py +103 -0
- axobench-0.1.0/src/axobench/benchmark/branch_adapter.py +104 -0
- axobench-0.1.0/src/axobench/benchmark/bundle.py +336 -0
- axobench-0.1.0/src/axobench/benchmark/dataset_schema.py +698 -0
- axobench-0.1.0/src/axobench/benchmark/diagnostic_audits.py +507 -0
- axobench-0.1.0/src/axobench/benchmark/diagnostic_rows.py +748 -0
- axobench-0.1.0/src/axobench/benchmark/mechanistic_response.py +360 -0
- axobench-0.1.0/src/axobench/benchmark/morphology_transfer.py +244 -0
- axobench-0.1.0/src/axobench/benchmark/perturbation_stability.py +117 -0
- axobench-0.1.0/src/axobench/benchmark/plots.py +1382 -0
- axobench-0.1.0/src/axobench/benchmark/profiles.py +381 -0
- axobench-0.1.0/src/axobench/benchmark/regime_stratified.py +417 -0
- axobench-0.1.0/src/axobench/benchmark/reports.py +273 -0
- axobench-0.1.0/src/axobench/benchmark/runner.py +288 -0
- axobench-0.1.0/src/axobench/benchmark/selectors.py +83 -0
- axobench-0.1.0/src/axobench/benchmark/state_metrics.py +321 -0
- axobench-0.1.0/src/axobench/benchmark/structured_intervention.py +423 -0
- axobench-0.1.0/src/axobench/benchmark/suite.py +725 -0
- axobench-0.1.0/src/axobench/benchmark/swc_utils.py +250 -0
- axobench-0.1.0/src/axobench/benchmark/trace_shape.py +404 -0
- axobench-0.1.0/src/axobench/cli.py +550 -0
- axobench-0.1.0/src/axobench/data.py +666 -0
- axobench-0.1.0/src/axobench/generation/__init__.py +15 -0
- axobench-0.1.0/src/axobench/generation/arbor_sim.py +1143 -0
- axobench-0.1.0/src/axobench/generation/assets/__init__.py +1 -0
- axobench-0.1.0/src/axobench/generation/assets/allen_l5_template.swc +4855 -0
- axobench-0.1.0/src/axobench/generation/assets/allen_l5_template_fit.json +297 -0
- axobench-0.1.0/src/axobench/generation/generate_coreneuron_hay_dataset.py +439 -0
- axobench-0.1.0/src/axobench/generation/generate_hay_neuron_dataset.py +289 -0
- axobench-0.1.0/src/axobench/generation/mechanisms/__init__.py +1 -0
- axobench-0.1.0/src/axobench/generation/mechanisms/nmda.mod +75 -0
- axobench-0.1.0/src/axobench/generation/prepare_hay_swc_template.py +93 -0
- axobench-0.1.0/src/axobench/generation/rank_allen_l5_m3_candidates.py +197 -0
- axobench-0.1.0/src/axobench/generation/run_coreneuron_event_dropout_pair.py +614 -0
- axobench-0.1.0/src/axobench/generation/run_coreneuron_hay_probe.py +248 -0
- axobench-0.1.0/src/axobench/generation/run_generation_throughput_gate.py +437 -0
- axobench-0.1.0/src/axobench/generation/run_hay_neuron_driven_probe.py +266 -0
- axobench-0.1.0/src/axobench/generation/run_v1_parallel_generation.py +605 -0
- axobench-0.1.0/src/axobench/metrics.py +468 -0
- axobench-0.1.0/src/axobench/neuronio_raw.py +240 -0
- axobench-0.1.0/src/axobench/setup_workflow.py +150 -0
- axobench-0.1.0/src/axobench.egg-info/PKG-INFO +34 -0
- axobench-0.1.0/src/axobench.egg-info/SOURCES.txt +66 -0
- axobench-0.1.0/src/axobench.egg-info/dependency_links.txt +1 -0
- axobench-0.1.0/src/axobench.egg-info/entry_points.txt +2 -0
- axobench-0.1.0/src/axobench.egg-info/requires.txt +11 -0
- axobench-0.1.0/src/axobench.egg-info/top_level.txt +1 -0
- axobench-0.1.0/tests/test_arbor_sim.py +111 -0
- axobench-0.1.0/tests/test_benchmark_bundle.py +162 -0
- axobench-0.1.0/tests/test_benchmark_profiles.py +56 -0
- axobench-0.1.0/tests/test_benchmark_reports.py +93 -0
- axobench-0.1.0/tests/test_benchmark_runner.py +153 -0
- axobench-0.1.0/tests/test_benchmark_selectors.py +44 -0
- axobench-0.1.0/tests/test_benchmark_suite.py +535 -0
- axobench-0.1.0/tests/test_cli.py +505 -0
- axobench-0.1.0/tests/test_data.py +195 -0
- axobench-0.1.0/tests/test_dataset_schema.py +85 -0
- axobench-0.1.0/tests/test_diagnostic_audits.py +153 -0
- axobench-0.1.0/tests/test_diagnostic_rows.py +249 -0
- axobench-0.1.0/tests/test_kaggle_download.py +127 -0
- axobench-0.1.0/tests/test_metrics.py +103 -0
- axobench-0.1.0/tests/test_neuronio_convert.py +117 -0
- axobench-0.1.0/tests/test_setup_workflow.py +50 -0
axobench-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: axobench
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Single-neuron surrogate benchmark and dataset generation workbench.
|
|
5
|
+
Requires-Python: >=3.10
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
Requires-Dist: matplotlib>=3.8
|
|
8
|
+
Requires-Dist: numpy>=1.24
|
|
9
|
+
Requires-Dist: requests>=2.32
|
|
10
|
+
Provides-Extra: dev
|
|
11
|
+
Requires-Dist: pytest>=8; extra == "dev"
|
|
12
|
+
Provides-Extra: sim
|
|
13
|
+
Requires-Dist: arbor==0.11.0; extra == "sim"
|
|
14
|
+
Provides-Extra: plots
|
|
15
|
+
|
|
16
|
+
# AxoBench
|
|
17
|
+
|
|
18
|
+
AxoBench is the extracted benchmark and dataset-generation workbench for
|
|
19
|
+
single-neuron surrogate evaluation. It owns dataset generation, data/schema
|
|
20
|
+
utilities, diagnostic suites, and benchmark reporting.
|
|
21
|
+
|
|
22
|
+
The package keeps simulator-backed generation code under
|
|
23
|
+
`src/axobench/generation/`: Arbor, NEURON/Hay, CoreNEURON/Hay, SWC-template
|
|
24
|
+
preparation, morphology-candidate ranking, throughput probes, and the v1
|
|
25
|
+
parallel dataset generator. `scripts/` is reserved for thin operational entry
|
|
26
|
+
points such as launch scripts and data-download helpers.
|
|
27
|
+
|
|
28
|
+
Model implementations, training loops, checkpoint formats, and model-specific
|
|
29
|
+
experiments stay in `bnn_sim`. AxoBench evaluates caller-supplied prediction
|
|
30
|
+
functions or stored prediction/diagnostic artifacts; it does not ship BranchELM,
|
|
31
|
+
Mamba, RNN, or training code.
|
|
32
|
+
|
|
33
|
+
The active exploration arc lives in `docs/benchmark/`, which is ignored by Git
|
|
34
|
+
while the benchmark audit is still exploratory.
|
axobench-0.1.0/README.md
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# AxoBench
|
|
2
|
+
|
|
3
|
+
AxoBench is the extracted benchmark and dataset-generation workbench for
|
|
4
|
+
single-neuron surrogate evaluation. It owns dataset generation, data/schema
|
|
5
|
+
utilities, diagnostic suites, and benchmark reporting.
|
|
6
|
+
|
|
7
|
+
The package keeps simulator-backed generation code under
|
|
8
|
+
`src/axobench/generation/`: Arbor, NEURON/Hay, CoreNEURON/Hay, SWC-template
|
|
9
|
+
preparation, morphology-candidate ranking, throughput probes, and the v1
|
|
10
|
+
parallel dataset generator. `scripts/` is reserved for thin operational entry
|
|
11
|
+
points such as launch scripts and data-download helpers.
|
|
12
|
+
|
|
13
|
+
Model implementations, training loops, checkpoint formats, and model-specific
|
|
14
|
+
experiments stay in `bnn_sim`. AxoBench evaluates caller-supplied prediction
|
|
15
|
+
functions or stored prediction/diagnostic artifacts; it does not ship BranchELM,
|
|
16
|
+
Mamba, RNN, or training code.
|
|
17
|
+
|
|
18
|
+
The active exploration arc lives in `docs/benchmark/`, which is ignored by Git
|
|
19
|
+
while the benchmark audit is still exploratory.
|
|
axobench-0.1.0/pyproject.toml
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=68"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "axobench"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Single-neuron surrogate benchmark and dataset generation workbench."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.10"
|
|
11
|
+
dependencies = [
|
|
12
|
+
"matplotlib>=3.8",
|
|
13
|
+
"numpy>=1.24",
|
|
14
|
+
"requests>=2.32",
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
[project.optional-dependencies]
|
|
18
|
+
dev = [
|
|
19
|
+
"pytest>=8",
|
|
20
|
+
]
|
|
21
|
+
sim = [
|
|
22
|
+
"arbor==0.11.0",
|
|
23
|
+
]
|
|
24
|
+
plots = []
|
|
25
|
+
|
|
26
|
+
[project.scripts]
|
|
27
|
+
axobench = "axobench.cli:main"
|
|
28
|
+
|
|
29
|
+
[tool.setuptools.packages.find]
|
|
30
|
+
where = ["src"]
|
|
31
|
+
|
|
32
|
+
[tool.setuptools.package-data]
|
|
33
|
+
"axobench.generation" = ["assets/*.swc", "assets/*.json", "mechanisms/*.mod"]
|
|
34
|
+
|
|
35
|
+
[tool.pytest.ini_options]
|
|
36
|
+
testpaths = ["tests"]
|
|
37
|
+
pythonpath = ["src"]
|
axobench-0.1.0/setup.cfg
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
"""AxoBench neuron surrogate benchmark utilities.
|
|
2
|
+
|
|
3
|
+
The package exports the public benchmark helpers lazily so lightweight
|
|
4
|
+
dataframe/TCN evaluation can run in environments that intentionally omit
|
|
5
|
+
PyTorch.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Lazy-export table: public name -> (module path, attribute name in that module).
# The module-level __getattr__ below consults this table, so a submodule is only
# imported the first time one of its names is requested.
_EXPORTS: dict[str, tuple[str, str]] = {
    # Metric helpers.
    "compute_regime_stratified_metrics": ("axobench.benchmark.regime_stratified", "compute_regime_stratified_metrics"),
    "compute_state_conditioned_metrics": ("axobench.benchmark.state_metrics", "compute_state_conditioned_metrics"),
    "build_neuron_state_masks": ("axobench.benchmark.state_metrics", "build_neuron_state_masks"),
    "compute_trace_shape_metrics": ("axobench.benchmark.trace_shape", "compute_trace_shape_metrics"),
    "compute_perturbation_stability": ("axobench.benchmark.perturbation_stability", "compute_perturbation_stability"),
    # Structured interventions.
    "compute_paired_event_intervention_metrics": (
        "axobench.benchmark.structured_intervention",
        "compute_paired_event_intervention_metrics",
    ),
    "make_event_dropout_inputs": ("axobench.benchmark.structured_intervention", "make_event_dropout_inputs"),
    "make_selective_event_dropout_inputs": (
        "axobench.benchmark.structured_intervention",
        "make_selective_event_dropout_inputs",
    ),
    "make_site_silence_inputs": ("axobench.benchmark.structured_intervention", "make_site_silence_inputs"),
    "make_structured_intervention_inputs": (
        "axobench.benchmark.structured_intervention",
        "make_structured_intervention_inputs",
    ),
    "make_temporal_jitter_inputs": ("axobench.benchmark.structured_intervention", "make_temporal_jitter_inputs"),
    # Morphology transfer and branch adaptation.
    "MorphologyTransferEvaluator": ("axobench.benchmark.morphology_transfer", "MorphologyTransferEvaluator"),
    "load_available_morphologies": ("axobench.benchmark.morphology_transfer", "load_available_morphologies"),
    "adapt_branch_count": ("axobench.benchmark.branch_adapter", "adapt_branch_count"),
    "evaluate_with_branch_adaptation": ("axobench.benchmark.branch_adapter", "evaluate_with_branch_adaptation"),
    # Benchmark profiles.
    "BenchmarkProfile": ("axobench.benchmark.profiles", "BenchmarkProfile"),
    "BenchmarkCostModel": ("axobench.benchmark.profiles", "BenchmarkCostModel"),
    "LOCAL_CORENEURON_HAY_COST_MODEL": ("axobench.benchmark.profiles", "LOCAL_CORENEURON_HAY_COST_MODEL"),
    "build_axis_plan": ("axobench.benchmark.profiles", "build_axis_plan"),
    "estimate_profile_cost": ("axobench.benchmark.profiles", "estimate_profile_cost"),
    "get_benchmark_profile": ("axobench.benchmark.profiles", "get_benchmark_profile"),
    "list_benchmark_profiles": ("axobench.benchmark.profiles", "list_benchmark_profiles"),
    # Dataset schema.
    "DIAGNOSTIC_SUITES": ("axobench.benchmark.dataset_schema", "DIAGNOSTIC_SUITES"),
    "INTERVENTION_CONDITIONS": ("axobench.benchmark.dataset_schema", "INTERVENTION_CONDITIONS"),
    "ORDINARY_SPLITS": ("axobench.benchmark.dataset_schema", "ORDINARY_SPLITS"),
    "V1_MORPHOLOGY_COUNT": ("axobench.benchmark.dataset_schema", "V1_MORPHOLOGY_COUNT"),
    "build_diagnostic_suite_catalog": ("axobench.benchmark.dataset_schema", "build_diagnostic_suite_catalog"),
    "build_v1_dataset_manifest_template": ("axobench.benchmark.dataset_schema", "build_v1_dataset_manifest_template"),
    "default_diagnostic_specs": ("axobench.benchmark.dataset_schema", "default_diagnostic_specs"),
    "default_target_views": ("axobench.benchmark.dataset_schema", "default_target_views"),
    "provisional_v1_morphologies": ("axobench.benchmark.dataset_schema", "provisional_v1_morphologies"),
    "validate_v1_manifest_template": ("axobench.benchmark.dataset_schema", "validate_v1_manifest_template"),
    # Diagnostic rows.
    "DiagnosticContext": ("axobench.benchmark.diagnostic_rows", "DiagnosticContext"),
    "evaluate_model_diagnostic_rows": ("axobench.benchmark.diagnostic_rows", "evaluate_model_diagnostic_rows"),
    "morphology_contrast_diagnostic_rows": ("axobench.benchmark.diagnostic_rows", "morphology_contrast_diagnostic_rows"),
    "morphology_routing_qc_rows": ("axobench.benchmark.diagnostic_rows", "morphology_routing_qc_rows"),
    "paired_intervention_diagnostic_rows": ("axobench.benchmark.diagnostic_rows", "paired_intervention_diagnostic_rows"),
    "teacher_qc_diagnostic_rows": ("axobench.benchmark.diagnostic_rows", "teacher_qc_diagnostic_rows"),
    "validate_diagnostic_rows": ("axobench.benchmark.diagnostic_rows", "validate_diagnostic_rows"),
    "write_diagnostic_rows": ("axobench.benchmark.diagnostic_rows", "write_diagnostic_rows"),
    # Diagnostic audits.
    "DiagnosticAuditSources": ("axobench.benchmark.diagnostic_audits", "DiagnosticAuditSources"),
    "adapter_fairness_diagnostic_rows": ("axobench.benchmark.diagnostic_audits", "adapter_fairness_diagnostic_rows"),
    "diagnostic_audit_rows": ("axobench.benchmark.diagnostic_audits", "diagnostic_audit_rows"),
    "diagnostic_audit_rows_from_sources": ("axobench.benchmark.diagnostic_audits", "diagnostic_audit_rows_from_sources"),
    "load_artifact_rows": ("axobench.benchmark.diagnostic_audits", "load_artifact_rows"),
    "metric_decoupling_diagnostic_rows": ("axobench.benchmark.diagnostic_audits", "metric_decoupling_diagnostic_rows"),
    "protocol_coverage_diagnostic_rows": ("axobench.benchmark.diagnostic_audits", "protocol_coverage_diagnostic_rows"),
    "robustness_directionality_diagnostic_rows": (
        "axobench.benchmark.diagnostic_audits",
        "robustness_directionality_diagnostic_rows",
    ),
    "target_view_audit_diagnostic_rows": ("axobench.benchmark.diagnostic_audits", "target_view_audit_diagnostic_rows"),
    # Selectors and runners.
    "DatasetSelector": ("axobench.benchmark.selectors", "DatasetSelector"),
    "V1DatasetLayout": ("axobench.benchmark.selectors", "V1DatasetLayout"),
    "load_intervention_pair_npz": ("axobench.benchmark.runner", "load_intervention_pair_npz"),
    "load_selector_batch": ("axobench.benchmark.runner", "load_selector_batch"),
    "run_model_diagnostics_on_selector": ("axobench.benchmark.runner", "run_model_diagnostics_on_selector"),
    "run_paired_intervention_on_npz": ("axobench.benchmark.runner", "run_paired_intervention_on_npz"),
    "run_paired_intervention_on_selector": ("axobench.benchmark.runner", "run_paired_intervention_on_selector"),
    "run_teacher_qc_on_selector": ("axobench.benchmark.runner", "run_teacher_qc_on_selector"),
    # Reports, bundle, and suite.
    "summarize_diagnostic_rows": ("axobench.benchmark.reports", "summarize_diagnostic_rows"),
    "core_metric_comparison_rows": ("axobench.benchmark.reports", "core_metric_comparison_rows"),
    "write_diagnostic_summary": ("axobench.benchmark.reports", "write_diagnostic_summary"),
    "write_core_metric_comparison_table": ("axobench.benchmark.reports", "write_core_metric_comparison_table"),
    "run_v1_benchmark_bundle": ("axobench.benchmark.bundle", "run_v1_benchmark_bundle"),
    "infer_pair_metadata": ("axobench.benchmark.bundle", "infer_pair_metadata"),
    "BenchmarkSuite": ("axobench.benchmark.suite", "BenchmarkSuite"),
}

# The public API is exactly the set of lazily exported names, in sorted order.
__all__ = sorted(_EXPORTS)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def __getattr__(name: str) -> Any:
    """Resolve a lazily exported benchmark helper on first access.

    Looks ``name`` up in ``_EXPORTS``, imports the module that owns it, and
    stores the resolved object in this module's globals so that subsequent
    attribute lookups no longer go through this hook.

    Args:
        name: Attribute requested on this package.

    Returns:
        The object exported under ``name``.

    Raises:
        AttributeError: If ``name`` is not listed in ``_EXPORTS``.
    """
    import importlib

    target = _EXPORTS.get(name)
    if target is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    module_path, attribute_name = target
    owner = importlib.import_module(module_path)
    resolved = getattr(owner, attribute_name)
    # Cache on the module so the import machinery is only hit once per name.
    globals()[name] = resolved
    return resolved
|
|
axobench-0.1.0/src/axobench/benchmark/branch_adapter.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""Branch count adaptation for morphology transfer (M3).
|
|
2
|
+
|
|
3
|
+
Handles mismatch between training morphology (e.g., 45 branches)
|
|
4
|
+
and test morphologies (e.g., 24-40 branches).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from axobench.metrics import binary_auc
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def adapt_branch_count(
    routing_matrix: np.ndarray,
    target_n_branches: int,
    strategy: str = "pad",
) -> np.ndarray:
    """Adapt a routing matrix to a target branch count.

    Args:
        routing_matrix: Source routing matrix of shape (n_branches, input_dim).
        target_n_branches: Desired number of branches.
        strategy: Adaptation strategy:
            - "pad": Append zero-weight branches when the source has fewer
              branches than the target; drop trailing branches otherwise.
            - "truncate": Alias of "pad" (drops trailing branches when the
              source has more branches than the target, pads otherwise).
            - "interpolate": Each target branch copies the source branch at
              index ``target_b * n_source // target_n_branches``, scaled by
              ``min(1.0, n_source / target_n_branches)``.

    Returns:
        Adapted routing matrix of shape (target_n_branches, input_dim). When
        the source already has ``target_n_branches`` rows the input array
        itself is returned (no copy, and ``strategy`` is not inspected).

    Raises:
        ValueError: If ``strategy`` is not one of the names above.
    """
    n_source, input_dim = routing_matrix.shape

    if n_source == target_n_branches:
        return routing_matrix

    # "truncate" is documented as a strategy but used to fall through to the
    # "Unknown strategy" error; it shares the pad/truncate implementation.
    if strategy in ("pad", "truncate"):
        if n_source < target_n_branches:
            # Pad with zero-weight branches.
            padding = np.zeros((target_n_branches - n_source, input_dim))
            return np.vstack([routing_matrix, padding])
        # Truncate excess branches.
        return routing_matrix[:target_n_branches]

    if strategy == "interpolate":
        result = np.zeros((target_n_branches, input_dim))
        # Guard keeps the zero-target case free of a division by zero.
        if target_n_branches > 0:
            # The scale factor does not depend on the target branch index.
            weight = min(1.0, n_source / target_n_branches)
            # Map every target branch to its source branch in one gather.
            source_indices = (np.arange(target_n_branches) * n_source) // target_n_branches
            result[:] = routing_matrix[source_indices] * weight
        return result

    raise ValueError(f"Unknown strategy: {strategy}")
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def evaluate_with_branch_adaptation(
    model_fn: callable,
    inputs: np.ndarray,
    targets: np.ndarray,
    source_routing: np.ndarray,
    target_routing: np.ndarray,
    strategy: str = "pad",
) -> dict:
    """Score a model on a morphology whose branch count may differ.

    Args:
        model_fn: Prediction function called as ``model_fn(inputs)``.
        inputs: Input array (batch, time, input_dim).
        targets: Target array (batch, time, 2); channel 0 is scored with
            ``binary_auc`` and channel 1 with RMSE.
        source_routing: Routing matrix used during training.
        target_routing: Routing matrix for the test morphology.
        strategy: Adaptation strategy forwarded to ``adapt_branch_count``.

    Returns:
        Dictionary with the two branch counts, the strategy name, ``rmse_mv``
        and ``auc`` (a NaN AUC is reported as 0.5).
    """
    source_branch_count = source_routing.shape[0]
    target_branch_count = target_routing.shape[0]

    # The adapted matrix is not consumed yet: the model is assumed to handle
    # routing internally, and rewiring it is still to be done. The call is
    # kept so that an invalid strategy or matrix still raises here.
    adapt_branch_count(target_routing, source_branch_count, strategy=strategy)

    predictions = model_fn(inputs)

    voltage_error = predictions[:, :, 1] - targets[:, :, 1]
    # NOTE(review): the division by 0.1 presumably rescales the error to mV; confirm.
    rmse_mv = np.sqrt(np.mean(voltage_error ** 2)) / 0.1

    auc = binary_auc(predictions[:, :, 0], targets[:, :, 0])
    auc = 0.5 if np.isnan(auc) else auc

    metrics = {
        "n_source_branches": source_branch_count,
        "n_target_branches": target_branch_count,
        "adaptation_strategy": strategy,
        "rmse_mv": float(rmse_mv),
        "auc": float(auc),
    }
    return metrics
|
|
axobench-0.1.0/src/axobench/benchmark/bundle.py
ADDED
|
@@ -0,0 +1,336 @@
|
|
|
1
|
+
"""Promoted v1 benchmark bundle runner."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Callable, Iterable
|
|
6
|
+
import json
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
import re
|
|
9
|
+
from typing import Any, Literal
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from axobench.benchmark.dataset_schema import INTERVENTION_CONDITIONS
|
|
14
|
+
from axobench.benchmark.diagnostic_rows import (
|
|
15
|
+
DiagnosticContext,
|
|
16
|
+
morphology_contrast_diagnostic_rows,
|
|
17
|
+
morphology_routing_qc_rows,
|
|
18
|
+
write_diagnostic_rows,
|
|
19
|
+
)
|
|
20
|
+
from axobench.benchmark.reports import write_diagnostic_summary
|
|
21
|
+
from axobench.benchmark.runner import (
|
|
22
|
+
resolve_intervention_pair_paths,
|
|
23
|
+
run_model_diagnostics_on_selector,
|
|
24
|
+
run_paired_intervention_on_npz,
|
|
25
|
+
run_paired_intervention_on_npzs,
|
|
26
|
+
run_paired_intervention_on_selector,
|
|
27
|
+
run_teacher_qc_on_selector,
|
|
28
|
+
)
|
|
29
|
+
from axobench.benchmark.selectors import V1DatasetLayout
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Suite labels attached to model validation rows; used as the "model" entry of
# the "paper-summary" profile and extended with "full" for the "full" profile.
DEFAULT_MODEL_SUITES = (
    "trace-core",
    "state-dynamics",
    "spike-behavior",
    "feature-fidelity",
    "paper-summary",
)
# Suite labels attached to intervention rows under the "paper-summary" profile.
DEFAULT_INTERVENTION_SUITES = ("teacher-qc", "intervention-response", "paper-summary")
# Profile names accepted by run_v1_benchmark_bundle and _suite_config.
SuiteProfile = Literal["teacher-qc", "paper-summary", "full"]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def run_v1_benchmark_bundle(
    *,
    dataset_root: str | Path,
    output_dir: str | Path,
    base_context: DiagnosticContext | None = None,
    model_fn: Callable[[np.ndarray], np.ndarray] | None = None,
    model_id: str | None = None,
    intervention_pair_paths: Iterable[str | Path] = (),
    suite: SuiteProfile = "paper-summary",
    max_val_samples: int | None = None,
    max_intervention_samples: int | None = None,
    cache_shards: int = 1,
    write_summaries: bool = True,
) -> dict[str, Any]:
    """Run the promoted v1 validation plus intervention diagnostic bundle.

    Steps, in order:
      1. Teacher QC on the default validation selector.
      2. Morphology routing QC.
      3. Model diagnostics on the default validation selector (only when
         ``model_fn`` is given and the suite profile defines model suites).
      4. Paired intervention diagnostics, either for the explicit
         ``intervention_pair_paths`` or for every intervention the dataset
         layout reports.
      5. Morphology contrast rows built from all intervention rows.
      6. ``manifest.json`` listing every artifact written.

    Args:
        dataset_root: Root directory handed to ``V1DatasetLayout``.
        output_dir: Directory for all outputs; created if missing.
        base_context: Context copied (with overrides) for every step; a
            default ``DiagnosticContext()`` is used when omitted.
        model_fn: Optional prediction function. Without it the intervention
            rows are labelled as the teacher.
        model_id: Label for model rows; falls back to ``base.model_id`` when
            ``model_fn`` is given and this is empty.
        intervention_pair_paths: Explicit pair files. When empty, pair files
            are resolved from the dataset layout instead.
        suite: Suite profile name, see ``_suite_config``.
        max_val_samples: Sample cap forwarded to the validation runners.
        max_intervention_samples: Sample cap forwarded to the intervention runners.
        cache_shards: Forwarded to the validation runners.
        write_summaries: Also write a ``.md`` summary next to each rows file.

    Returns:
        The manifest dictionary that is also written to ``manifest.json``.

    Raises:
        ValueError: If ``suite`` is not a known profile.
    """
    layout = V1DatasetLayout(dataset_root)
    output = Path(output_dir)
    output.mkdir(parents=True, exist_ok=True)
    base = base_context or DiagnosticContext()
    suite_config = _suite_config(suite)
    artifacts: list[dict[str, Any]] = []
    intervention_row_sets: list[dict[str, Any]] = []

    # 1. Teacher QC on the validation split.
    teacher_context = _replace_context(
        base,
        suites=suite_config["teacher"],
        model_id="teacher",
        split="val",
        condition="val",
        intervention=None,
    )
    teacher_path = output / "teacher_qc_val.jsonl"
    teacher_rows = run_teacher_qc_on_selector(
        layout.default_validation(),
        context=teacher_context,
        max_samples=max_val_samples,
        output=teacher_path,
        cache_shards=cache_shards,
    )
    _record_artifact(teacher_rows, teacher_path, artifacts, write_summaries=write_summaries)

    # 2. Morphology routing QC; these rows are written here, not by a runner.
    routing_path = output / "morphology_routing_qc.jsonl"
    routing_rows = morphology_routing_qc_rows(
        context=_replace_context(base, suites=suite_config["routing"], model_id="teacher")
    )
    write_diagnostic_rows(routing_rows, routing_path)
    _record_artifact(
        routing_rows,
        routing_path,
        artifacts,
        write_summaries=write_summaries,
    )

    # Resolve the model label once so validation and intervention rows agree.
    # Previously the intervention rows received the raw ``model_id`` (possibly
    # None) while validation rows fell back to ``base.model_id``.
    if model_fn is not None:
        resolved_model_id = model_id or base.model_id
    else:
        resolved_model_id = model_id

    # 3. Model diagnostics on the validation split.
    if model_fn is not None and suite_config["model"] is not None:
        model_context = _replace_context(
            base,
            suites=suite_config["model"],
            model_id=resolved_model_id,
            split="val",
            condition="val",
            intervention=None,
        )
        model_path = output / f"model_val_{_slug(resolved_model_id)}.jsonl"
        model_rows = run_model_diagnostics_on_selector(
            model_fn,
            layout.default_validation(),
            context=model_context,
            max_samples=max_val_samples,
            output=model_path,
            cache_shards=cache_shards,
        )
        _record_artifact(
            model_rows,
            model_path,
            artifacts,
            write_summaries=write_summaries,
        )

    # 4. Paired intervention diagnostics.
    pair_paths = [Path(path) for path in intervention_pair_paths]
    if pair_paths:
        # Explicit pair files: group by (intervention, morphology) inferred from the paths.
        for intervention, morphology_id, grouped_paths in _group_pair_paths_by_morphology(
            pair_paths,
            default_morphology=base.morphology_id,
        ):
            rows_path, rows = _run_grouped_intervention(
                grouped_paths,
                output=output,
                base=base,
                suites=suite_config["intervention"],
                model_fn=model_fn,
                model_id=resolved_model_id,
                intervention=intervention,
                morphology_id=morphology_id,
                max_samples=max_intervention_samples,
            )
            intervention_row_sets.extend(rows)
            _record_artifact(rows, rows_path, artifacts, write_summaries=write_summaries)
    else:
        # No explicit files: walk every intervention the dataset layout offers.
        for intervention in layout.available_interventions():
            selector = layout.intervention(intervention)
            selector_pair_paths = resolve_intervention_pair_paths(selector)
            if len(selector_pair_paths) == 1:
                # A single pair file is run through the selector, labelled
                # with the base context's morphology.
                intervention_model_id = resolved_model_id if model_fn is not None else "teacher"
                context = _replace_context(
                    base,
                    suites=suite_config["intervention"],
                    model_id=intervention_model_id,
                    split="",
                    condition=f"interventions/{intervention}",
                    intervention=intervention,
                )
                filename = f"intervention_{_slug(base.morphology_id)}_{_slug(intervention)}"
                rows = run_paired_intervention_on_selector(
                    selector,
                    model_fn=model_fn,
                    context=context,
                    max_samples=max_intervention_samples,
                    output=output / f"{filename}.jsonl",
                )
                intervention_row_sets.extend(rows)
                _record_artifact(rows, output / f"{filename}.jsonl", artifacts, write_summaries=write_summaries)
                continue
            for inferred_intervention, morphology_id, grouped_paths in _group_pair_paths_by_morphology(
                selector_pair_paths,
                default_morphology=base.morphology_id,
            ):
                rows_path, rows = _run_grouped_intervention(
                    grouped_paths,
                    output=output,
                    base=base,
                    suites=suite_config["intervention"],
                    model_fn=model_fn,
                    model_id=resolved_model_id,
                    intervention=inferred_intervention,
                    morphology_id=morphology_id,
                    max_samples=max_intervention_samples,
                )
                intervention_row_sets.extend(rows)
                _record_artifact(rows, rows_path, artifacts, write_summaries=write_summaries)

    # 5. Cross-morphology contrast over every intervention row collected above.
    contrast_rows = morphology_contrast_diagnostic_rows(
        intervention_row_sets,
        context=_replace_context(base, suites=suite_config["morphology_contrast"], morphology_id="all_v1"),
    )
    if contrast_rows:
        contrast_path = output / "morphology_contrast_interventions.jsonl"
        write_diagnostic_rows(contrast_rows, contrast_path)
        _record_artifact(
            contrast_rows,
            contrast_path,
            artifacts,
            write_summaries=write_summaries,
        )

    # 6. Manifest; "model_id" records the argument exactly as it was passed.
    manifest = {
        "dataset_root": str(dataset_root),
        "output_dir": str(output),
        "model_id": model_id,
        "suite": suite,
        "suite_config": {key: list(value) if value is not None else None for key, value in suite_config.items()},
        "max_val_samples": max_val_samples,
        "max_intervention_samples": max_intervention_samples,
        "write_summaries": write_summaries,
        "artifacts": artifacts,
    }
    (output / "manifest.json").write_text(json.dumps(manifest, indent=2, sort_keys=True) + "\n", encoding="utf-8")
    return manifest
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _suite_config(suite: SuiteProfile) -> dict[str, tuple[str, ...] | None]:
|
|
211
|
+
if suite == "teacher-qc":
|
|
212
|
+
return {
|
|
213
|
+
"teacher": ("teacher-qc",),
|
|
214
|
+
"routing": ("teacher-qc",),
|
|
215
|
+
"model": None,
|
|
216
|
+
"intervention": ("teacher-qc",),
|
|
217
|
+
"morphology_contrast": ("morphology-contrast",),
|
|
218
|
+
}
|
|
219
|
+
if suite == "paper-summary":
|
|
220
|
+
return {
|
|
221
|
+
"teacher": ("teacher-qc", "paper-summary"),
|
|
222
|
+
"routing": ("teacher-qc", "paper-summary"),
|
|
223
|
+
"model": DEFAULT_MODEL_SUITES,
|
|
224
|
+
"intervention": DEFAULT_INTERVENTION_SUITES,
|
|
225
|
+
"morphology_contrast": ("morphology-contrast", "paper-summary"),
|
|
226
|
+
}
|
|
227
|
+
if suite == "full":
|
|
228
|
+
return {
|
|
229
|
+
"teacher": ("teacher-qc", "paper-summary", "full"),
|
|
230
|
+
"routing": ("teacher-qc", "paper-summary", "full"),
|
|
231
|
+
"model": (*DEFAULT_MODEL_SUITES, "full"),
|
|
232
|
+
"intervention": ("teacher-qc", "intervention-response", "paper-summary", "full"),
|
|
233
|
+
"morphology_contrast": ("morphology-contrast", "paper-summary", "full"),
|
|
234
|
+
}
|
|
235
|
+
raise ValueError(f"unknown v1 benchmark suite profile: {suite}")
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def infer_pair_metadata(path: str | Path, *, default_morphology: str = "unknown") -> tuple[str, str]:
    """Infer intervention and morphology labels from a pilot pair-file path.

    The intervention is the path component that follows an ``interventions``
    component. Failing that, it is the first name from
    ``INTERVENTION_CONDITIONS`` that occurs anywhere in the path text, and
    finally ``"event_dropout"``. The morphology is ``specimen_<digits>`` when
    that pattern occurs in the path text, otherwise ``default_morphology``.

    Returns:
        ``(intervention, morphology_id)``.
    """
    pair_path = Path(path)
    segments = pair_path.parts
    path_text = str(pair_path)

    # Position of the component right after "interventions", if there is one.
    try:
        successor = segments.index("interventions") + 1
    except ValueError:
        successor = len(segments)

    if successor < len(segments):
        intervention = segments[successor]
    else:
        intervention = None

    if intervention is None:
        intervention = next(
            (name for name in INTERVENTION_CONDITIONS if name in path_text),
            "event_dropout",
        )

    specimen = re.search(r"specimen_(\d+)", path_text)
    if specimen:
        morphology_id = f"specimen_{specimen.group(1)}"
    else:
        morphology_id = default_morphology
    return intervention, morphology_id
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def _group_pair_paths_by_morphology(
    pair_paths: Iterable[Path],
    *,
    default_morphology: str,
) -> list[tuple[str, str, list[Path]]]:
    """Bucket pair files by their inferred (intervention, morphology) labels.

    Returns:
        ``(intervention, morphology_id, paths)`` triples ordered by the label
        pair, with the paths inside each triple sorted.
    """
    buckets: dict[tuple[str, str], list[Path]] = {}
    for candidate in pair_paths:
        intervention, morphology_id = infer_pair_metadata(candidate, default_morphology=default_morphology)
        label = (intervention, morphology_id)
        if label not in buckets:
            buckets[label] = []
        buckets[label].append(candidate)

    ordered: list[tuple[str, str, list[Path]]] = []
    for label in sorted(buckets):
        intervention, morphology_id = label
        ordered.append((intervention, morphology_id, sorted(buckets[label])))
    return ordered
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def _run_grouped_intervention(
    pair_paths: list[Path],
    *,
    output: Path,
    base: DiagnosticContext,
    suites: tuple[str, ...],
    model_fn: Callable[[np.ndarray], np.ndarray] | None,
    model_id: str | None,
    intervention: str,
    morphology_id: str,
    max_samples: int | None,
) -> tuple[Path, list[dict[str, Any]]]:
    """Run the paired-intervention diagnostics for one (intervention, morphology) group.

    Rows are labelled with ``model_id`` when a ``model_fn`` is supplied and
    with ``"teacher"`` otherwise. A single pair file goes through
    ``run_paired_intervention_on_npz``; several go through
    ``run_paired_intervention_on_npzs``.

    Returns:
        ``(rows_path, rows)`` where ``rows_path`` is the JSONL path passed to
        the runner as ``output``.
    """
    row_model_id = "teacher" if model_fn is None else model_id
    context = _replace_context(
        base,
        suites=suites,
        model_id=row_model_id,
        morphology_id=morphology_id,
        split="",
        condition=f"interventions/{intervention}",
        intervention=intervention,
    )
    rows_path = output / f"intervention_{_slug(morphology_id)}_{_slug(intervention)}.jsonl"

    # Keyword arguments common to both runner variants.
    shared = {
        "model_fn": model_fn,
        "context": context,
        "max_samples": max_samples,
        "output": rows_path,
    }
    if len(pair_paths) == 1:
        return rows_path, run_paired_intervention_on_npz(pair_paths[0], **shared)
    return rows_path, run_paired_intervention_on_npzs(pair_paths, **shared)
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def _record_artifact(
    rows: list[dict[str, Any]],
    rows_path: Path,
    artifacts: list[dict[str, Any]],
    *,
    write_summaries: bool,
) -> None:
    """Append a manifest entry for one rows file to ``artifacts``.

    The entry records the row count and the rows path. When
    ``write_summaries`` is true, a summary is written next to the rows file
    with a ``.md`` suffix and its path is recorded as ``summary_path``.
    """
    entry: dict[str, Any] = {"rows": len(rows), "rows_path": str(rows_path)}
    if not write_summaries:
        artifacts.append(entry)
        return

    markdown_path = rows_path.with_suffix(".md")
    write_diagnostic_summary(rows, markdown_path)
    entry["summary_path"] = str(markdown_path)
    artifacts.append(entry)
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def _replace_context(context: DiagnosticContext, **updates: Any) -> DiagnosticContext:
    """Build a new ``DiagnosticContext`` from ``context`` with fields overridden.

    The instance attributes of ``context`` are copied and ``updates`` take
    precedence over them; ``context`` itself is left untouched.
    """
    fields = dict(vars(context))
    fields.update(updates)
    return DiagnosticContext(**fields)
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def _slug(value: str) -> str:
    """Turn ``value`` into a filename-safe token.

    Every run of characters outside ``A-Z a-z 0-9 _ . -`` becomes a single
    underscore, leading and trailing underscores are dropped, and an empty
    result is replaced by ``"unknown"``.
    """
    collapsed = re.sub(r"[^A-Za-z0-9_.-]+", "_", value)
    trimmed = collapsed.strip("_")
    if trimmed:
        return trimmed
    return "unknown"
|