axobench 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. axobench-0.1.0/PKG-INFO +34 -0
  2. axobench-0.1.0/README.md +19 -0
  3. axobench-0.1.0/pyproject.toml +37 -0
  4. axobench-0.1.0/setup.cfg +4 -0
  5. axobench-0.1.0/src/axobench/__init__.py +5 -0
  6. axobench-0.1.0/src/axobench/benchmark/__init__.py +103 -0
  7. axobench-0.1.0/src/axobench/benchmark/branch_adapter.py +104 -0
  8. axobench-0.1.0/src/axobench/benchmark/bundle.py +336 -0
  9. axobench-0.1.0/src/axobench/benchmark/dataset_schema.py +698 -0
  10. axobench-0.1.0/src/axobench/benchmark/diagnostic_audits.py +507 -0
  11. axobench-0.1.0/src/axobench/benchmark/diagnostic_rows.py +748 -0
  12. axobench-0.1.0/src/axobench/benchmark/mechanistic_response.py +360 -0
  13. axobench-0.1.0/src/axobench/benchmark/morphology_transfer.py +244 -0
  14. axobench-0.1.0/src/axobench/benchmark/perturbation_stability.py +117 -0
  15. axobench-0.1.0/src/axobench/benchmark/plots.py +1382 -0
  16. axobench-0.1.0/src/axobench/benchmark/profiles.py +381 -0
  17. axobench-0.1.0/src/axobench/benchmark/regime_stratified.py +417 -0
  18. axobench-0.1.0/src/axobench/benchmark/reports.py +273 -0
  19. axobench-0.1.0/src/axobench/benchmark/runner.py +288 -0
  20. axobench-0.1.0/src/axobench/benchmark/selectors.py +83 -0
  21. axobench-0.1.0/src/axobench/benchmark/state_metrics.py +321 -0
  22. axobench-0.1.0/src/axobench/benchmark/structured_intervention.py +423 -0
  23. axobench-0.1.0/src/axobench/benchmark/suite.py +725 -0
  24. axobench-0.1.0/src/axobench/benchmark/swc_utils.py +250 -0
  25. axobench-0.1.0/src/axobench/benchmark/trace_shape.py +404 -0
  26. axobench-0.1.0/src/axobench/cli.py +550 -0
  27. axobench-0.1.0/src/axobench/data.py +666 -0
  28. axobench-0.1.0/src/axobench/generation/__init__.py +15 -0
  29. axobench-0.1.0/src/axobench/generation/arbor_sim.py +1143 -0
  30. axobench-0.1.0/src/axobench/generation/assets/__init__.py +1 -0
  31. axobench-0.1.0/src/axobench/generation/assets/allen_l5_template.swc +4855 -0
  32. axobench-0.1.0/src/axobench/generation/assets/allen_l5_template_fit.json +297 -0
  33. axobench-0.1.0/src/axobench/generation/generate_coreneuron_hay_dataset.py +439 -0
  34. axobench-0.1.0/src/axobench/generation/generate_hay_neuron_dataset.py +289 -0
  35. axobench-0.1.0/src/axobench/generation/mechanisms/__init__.py +1 -0
  36. axobench-0.1.0/src/axobench/generation/mechanisms/nmda.mod +75 -0
  37. axobench-0.1.0/src/axobench/generation/prepare_hay_swc_template.py +93 -0
  38. axobench-0.1.0/src/axobench/generation/rank_allen_l5_m3_candidates.py +197 -0
  39. axobench-0.1.0/src/axobench/generation/run_coreneuron_event_dropout_pair.py +614 -0
  40. axobench-0.1.0/src/axobench/generation/run_coreneuron_hay_probe.py +248 -0
  41. axobench-0.1.0/src/axobench/generation/run_generation_throughput_gate.py +437 -0
  42. axobench-0.1.0/src/axobench/generation/run_hay_neuron_driven_probe.py +266 -0
  43. axobench-0.1.0/src/axobench/generation/run_v1_parallel_generation.py +605 -0
  44. axobench-0.1.0/src/axobench/metrics.py +468 -0
  45. axobench-0.1.0/src/axobench/neuronio_raw.py +240 -0
  46. axobench-0.1.0/src/axobench/setup_workflow.py +150 -0
  47. axobench-0.1.0/src/axobench.egg-info/PKG-INFO +34 -0
  48. axobench-0.1.0/src/axobench.egg-info/SOURCES.txt +66 -0
  49. axobench-0.1.0/src/axobench.egg-info/dependency_links.txt +1 -0
  50. axobench-0.1.0/src/axobench.egg-info/entry_points.txt +2 -0
  51. axobench-0.1.0/src/axobench.egg-info/requires.txt +11 -0
  52. axobench-0.1.0/src/axobench.egg-info/top_level.txt +1 -0
  53. axobench-0.1.0/tests/test_arbor_sim.py +111 -0
  54. axobench-0.1.0/tests/test_benchmark_bundle.py +162 -0
  55. axobench-0.1.0/tests/test_benchmark_profiles.py +56 -0
  56. axobench-0.1.0/tests/test_benchmark_reports.py +93 -0
  57. axobench-0.1.0/tests/test_benchmark_runner.py +153 -0
  58. axobench-0.1.0/tests/test_benchmark_selectors.py +44 -0
  59. axobench-0.1.0/tests/test_benchmark_suite.py +535 -0
  60. axobench-0.1.0/tests/test_cli.py +505 -0
  61. axobench-0.1.0/tests/test_data.py +195 -0
  62. axobench-0.1.0/tests/test_dataset_schema.py +85 -0
  63. axobench-0.1.0/tests/test_diagnostic_audits.py +153 -0
  64. axobench-0.1.0/tests/test_diagnostic_rows.py +249 -0
  65. axobench-0.1.0/tests/test_kaggle_download.py +127 -0
  66. axobench-0.1.0/tests/test_metrics.py +103 -0
  67. axobench-0.1.0/tests/test_neuronio_convert.py +117 -0
  68. axobench-0.1.0/tests/test_setup_workflow.py +50 -0
@@ -0,0 +1,34 @@
1
+ Metadata-Version: 2.4
2
+ Name: axobench
3
+ Version: 0.1.0
4
+ Summary: Single-neuron surrogate benchmark and dataset generation workbench.
5
+ Requires-Python: >=3.10
6
+ Description-Content-Type: text/markdown
7
+ Requires-Dist: matplotlib>=3.8
8
+ Requires-Dist: numpy>=1.24
9
+ Requires-Dist: requests>=2.32
10
+ Provides-Extra: dev
11
+ Requires-Dist: pytest>=8; extra == "dev"
12
+ Provides-Extra: sim
13
+ Requires-Dist: arbor==0.11.0; extra == "sim"
14
+ Provides-Extra: plots
15
+
16
+ # AxoBench
17
+
18
+ AxoBench is the extracted benchmark and dataset-generation workbench for
19
+ single-neuron surrogate evaluation. It owns dataset generation, data/schema
20
+ utilities, diagnostic suites, and benchmark reporting.
21
+
22
+ The package keeps simulator-backed generation code under
23
+ `src/axobench/generation/`: Arbor, NEURON/Hay, CoreNEURON/Hay, SWC-template
24
+ preparation, morphology-candidate ranking, throughput probes, and the v1
25
+ parallel dataset generator. `scripts/` is reserved for thin operational entry
26
+ points such as launch scripts and data-download helpers.
27
+
28
+ Model implementations, training loops, checkpoint formats, and model-specific
29
+ experiments stay in `bnn_sim`. AxoBench evaluates caller-supplied prediction
30
+ functions or stored prediction/diagnostic artifacts; it does not ship BranchELM,
31
+ Mamba, RNN, or training code.
32
+
33
+ The active exploration arc lives in `docs/benchmark/`, which is ignored by Git
34
+ while the benchmark audit is still exploratory.
@@ -0,0 +1,19 @@
1
+ # AxoBench
2
+
3
+ AxoBench is the extracted benchmark and dataset-generation workbench for
4
+ single-neuron surrogate evaluation. It owns dataset generation, data/schema
5
+ utilities, diagnostic suites, and benchmark reporting.
6
+
7
+ The package keeps simulator-backed generation code under
8
+ `src/axobench/generation/`: Arbor, NEURON/Hay, CoreNEURON/Hay, SWC-template
9
+ preparation, morphology-candidate ranking, throughput probes, and the v1
10
+ parallel dataset generator. `scripts/` is reserved for thin operational entry
11
+ points such as launch scripts and data-download helpers.
12
+
13
+ Model implementations, training loops, checkpoint formats, and model-specific
14
+ experiments stay in `bnn_sim`. AxoBench evaluates caller-supplied prediction
15
+ functions or stored prediction/diagnostic artifacts; it does not ship BranchELM,
16
+ Mamba, RNN, or training code.
17
+
18
+ The active exploration arc lives in `docs/benchmark/`, which is ignored by Git
19
+ while the benchmark audit is still exploratory.
@@ -0,0 +1,37 @@
1
+ [build-system]
2
+ requires = ["setuptools>=68"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "axobench"
7
+ version = "0.1.0"
8
+ description = "Single-neuron surrogate benchmark and dataset generation workbench."
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ dependencies = [
12
+ "matplotlib>=3.8",
13
+ "numpy>=1.24",
14
+ "requests>=2.32",
15
+ ]
16
+
17
+ [project.optional-dependencies]
18
+ dev = [
19
+ "pytest>=8",
20
+ ]
21
+ sim = [
22
+ "arbor==0.11.0",
23
+ ]
24
+ plots = []
25
+
26
+ [project.scripts]
27
+ axobench = "axobench.cli:main"
28
+
29
+ [tool.setuptools.packages.find]
30
+ where = ["src"]
31
+
32
+ [tool.setuptools.package-data]
33
+ "axobench.generation" = ["assets/*.swc", "assets/*.json", "mechanisms/*.mod"]
34
+
35
+ [tool.pytest.ini_options]
36
+ testpaths = ["tests"]
37
+ pythonpath = ["src"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,5 @@
1
+ """Dataset-generation and diagnostic utilities for single-neuron benchmarks."""
2
+
3
+ from __future__ import annotations
4
+
5
+ __all__: list[str] = []
@@ -0,0 +1,103 @@
1
+ """AxoBench neuron surrogate benchmark utilities.
2
+
3
+ The package exports the public benchmark helpers lazily so lightweight
4
+ dataframe/TCN evaluation can run in environments that intentionally omit
5
+ PyTorch.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Any
11
+
12
+
13
+ _EXPORTS: dict[str, tuple[str, str]] = {
14
+ "compute_regime_stratified_metrics": ("axobench.benchmark.regime_stratified", "compute_regime_stratified_metrics"),
15
+ "compute_state_conditioned_metrics": ("axobench.benchmark.state_metrics", "compute_state_conditioned_metrics"),
16
+ "build_neuron_state_masks": ("axobench.benchmark.state_metrics", "build_neuron_state_masks"),
17
+ "compute_trace_shape_metrics": ("axobench.benchmark.trace_shape", "compute_trace_shape_metrics"),
18
+ "compute_perturbation_stability": ("axobench.benchmark.perturbation_stability", "compute_perturbation_stability"),
19
+ "compute_paired_event_intervention_metrics": (
20
+ "axobench.benchmark.structured_intervention",
21
+ "compute_paired_event_intervention_metrics",
22
+ ),
23
+ "make_event_dropout_inputs": ("axobench.benchmark.structured_intervention", "make_event_dropout_inputs"),
24
+ "make_selective_event_dropout_inputs": (
25
+ "axobench.benchmark.structured_intervention",
26
+ "make_selective_event_dropout_inputs",
27
+ ),
28
+ "make_site_silence_inputs": ("axobench.benchmark.structured_intervention", "make_site_silence_inputs"),
29
+ "make_structured_intervention_inputs": (
30
+ "axobench.benchmark.structured_intervention",
31
+ "make_structured_intervention_inputs",
32
+ ),
33
+ "make_temporal_jitter_inputs": ("axobench.benchmark.structured_intervention", "make_temporal_jitter_inputs"),
34
+ "MorphologyTransferEvaluator": ("axobench.benchmark.morphology_transfer", "MorphologyTransferEvaluator"),
35
+ "load_available_morphologies": ("axobench.benchmark.morphology_transfer", "load_available_morphologies"),
36
+ "adapt_branch_count": ("axobench.benchmark.branch_adapter", "adapt_branch_count"),
37
+ "evaluate_with_branch_adaptation": ("axobench.benchmark.branch_adapter", "evaluate_with_branch_adaptation"),
38
+ "BenchmarkProfile": ("axobench.benchmark.profiles", "BenchmarkProfile"),
39
+ "BenchmarkCostModel": ("axobench.benchmark.profiles", "BenchmarkCostModel"),
40
+ "LOCAL_CORENEURON_HAY_COST_MODEL": ("axobench.benchmark.profiles", "LOCAL_CORENEURON_HAY_COST_MODEL"),
41
+ "build_axis_plan": ("axobench.benchmark.profiles", "build_axis_plan"),
42
+ "estimate_profile_cost": ("axobench.benchmark.profiles", "estimate_profile_cost"),
43
+ "get_benchmark_profile": ("axobench.benchmark.profiles", "get_benchmark_profile"),
44
+ "list_benchmark_profiles": ("axobench.benchmark.profiles", "list_benchmark_profiles"),
45
+ "DIAGNOSTIC_SUITES": ("axobench.benchmark.dataset_schema", "DIAGNOSTIC_SUITES"),
46
+ "INTERVENTION_CONDITIONS": ("axobench.benchmark.dataset_schema", "INTERVENTION_CONDITIONS"),
47
+ "ORDINARY_SPLITS": ("axobench.benchmark.dataset_schema", "ORDINARY_SPLITS"),
48
+ "V1_MORPHOLOGY_COUNT": ("axobench.benchmark.dataset_schema", "V1_MORPHOLOGY_COUNT"),
49
+ "build_diagnostic_suite_catalog": ("axobench.benchmark.dataset_schema", "build_diagnostic_suite_catalog"),
50
+ "build_v1_dataset_manifest_template": ("axobench.benchmark.dataset_schema", "build_v1_dataset_manifest_template"),
51
+ "default_diagnostic_specs": ("axobench.benchmark.dataset_schema", "default_diagnostic_specs"),
52
+ "default_target_views": ("axobench.benchmark.dataset_schema", "default_target_views"),
53
+ "provisional_v1_morphologies": ("axobench.benchmark.dataset_schema", "provisional_v1_morphologies"),
54
+ "validate_v1_manifest_template": ("axobench.benchmark.dataset_schema", "validate_v1_manifest_template"),
55
+ "DiagnosticContext": ("axobench.benchmark.diagnostic_rows", "DiagnosticContext"),
56
+ "evaluate_model_diagnostic_rows": ("axobench.benchmark.diagnostic_rows", "evaluate_model_diagnostic_rows"),
57
+ "morphology_contrast_diagnostic_rows": ("axobench.benchmark.diagnostic_rows", "morphology_contrast_diagnostic_rows"),
58
+ "morphology_routing_qc_rows": ("axobench.benchmark.diagnostic_rows", "morphology_routing_qc_rows"),
59
+ "paired_intervention_diagnostic_rows": ("axobench.benchmark.diagnostic_rows", "paired_intervention_diagnostic_rows"),
60
+ "teacher_qc_diagnostic_rows": ("axobench.benchmark.diagnostic_rows", "teacher_qc_diagnostic_rows"),
61
+ "validate_diagnostic_rows": ("axobench.benchmark.diagnostic_rows", "validate_diagnostic_rows"),
62
+ "write_diagnostic_rows": ("axobench.benchmark.diagnostic_rows", "write_diagnostic_rows"),
63
+ "DiagnosticAuditSources": ("axobench.benchmark.diagnostic_audits", "DiagnosticAuditSources"),
64
+ "adapter_fairness_diagnostic_rows": ("axobench.benchmark.diagnostic_audits", "adapter_fairness_diagnostic_rows"),
65
+ "diagnostic_audit_rows": ("axobench.benchmark.diagnostic_audits", "diagnostic_audit_rows"),
66
+ "diagnostic_audit_rows_from_sources": ("axobench.benchmark.diagnostic_audits", "diagnostic_audit_rows_from_sources"),
67
+ "load_artifact_rows": ("axobench.benchmark.diagnostic_audits", "load_artifact_rows"),
68
+ "metric_decoupling_diagnostic_rows": ("axobench.benchmark.diagnostic_audits", "metric_decoupling_diagnostic_rows"),
69
+ "protocol_coverage_diagnostic_rows": ("axobench.benchmark.diagnostic_audits", "protocol_coverage_diagnostic_rows"),
70
+ "robustness_directionality_diagnostic_rows": (
71
+ "axobench.benchmark.diagnostic_audits",
72
+ "robustness_directionality_diagnostic_rows",
73
+ ),
74
+ "target_view_audit_diagnostic_rows": ("axobench.benchmark.diagnostic_audits", "target_view_audit_diagnostic_rows"),
75
+ "DatasetSelector": ("axobench.benchmark.selectors", "DatasetSelector"),
76
+ "V1DatasetLayout": ("axobench.benchmark.selectors", "V1DatasetLayout"),
77
+ "load_intervention_pair_npz": ("axobench.benchmark.runner", "load_intervention_pair_npz"),
78
+ "load_selector_batch": ("axobench.benchmark.runner", "load_selector_batch"),
79
+ "run_model_diagnostics_on_selector": ("axobench.benchmark.runner", "run_model_diagnostics_on_selector"),
80
+ "run_paired_intervention_on_npz": ("axobench.benchmark.runner", "run_paired_intervention_on_npz"),
81
+ "run_paired_intervention_on_selector": ("axobench.benchmark.runner", "run_paired_intervention_on_selector"),
82
+ "run_teacher_qc_on_selector": ("axobench.benchmark.runner", "run_teacher_qc_on_selector"),
83
+ "summarize_diagnostic_rows": ("axobench.benchmark.reports", "summarize_diagnostic_rows"),
84
+ "core_metric_comparison_rows": ("axobench.benchmark.reports", "core_metric_comparison_rows"),
85
+ "write_diagnostic_summary": ("axobench.benchmark.reports", "write_diagnostic_summary"),
86
+ "write_core_metric_comparison_table": ("axobench.benchmark.reports", "write_core_metric_comparison_table"),
87
+ "run_v1_benchmark_bundle": ("axobench.benchmark.bundle", "run_v1_benchmark_bundle"),
88
+ "infer_pair_metadata": ("axobench.benchmark.bundle", "infer_pair_metadata"),
89
+ "BenchmarkSuite": ("axobench.benchmark.suite", "BenchmarkSuite"),
90
+ }
91
+
92
+ __all__ = sorted(_EXPORTS)
93
+
94
+
95
def __getattr__(name: str) -> Any:
    """Resolve a lazily exported benchmark helper on first attribute access.

    The name is looked up in ``_EXPORTS``; the module it points at is imported
    only at that moment, and the resolved object is stored in this module's
    globals so that later lookups no longer go through this hook.

    Args:
        name: Attribute requested on this package.

    Returns:
        The object found under the mapped attribute of the mapped module.

    Raises:
        AttributeError: If ``name`` is not one of the keys of ``_EXPORTS``.
    """
    from importlib import import_module

    target = _EXPORTS.get(name)
    if target is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    module_name, attribute = target
    resolved = getattr(import_module(module_name), attribute)
    # Cache on the module so the import machinery is only hit once per name.
    globals()[name] = resolved
    return resolved
@@ -0,0 +1,104 @@
1
+ """Branch count adaptation for morphology transfer (M3).
2
+
3
+ Handles mismatch between training morphology (e.g., 45 branches)
4
+ and test morphologies (e.g., 24-40 branches).
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import numpy as np
10
+
11
+ from axobench.metrics import binary_auc
12
+
13
+
14
def adapt_branch_count(
    routing_matrix: np.ndarray,
    target_n_branches: int,
    strategy: str = "pad",
) -> np.ndarray:
    """Adapt a routing matrix to a target branch count.

    Args:
        routing_matrix: Source routing matrix of shape (n_branches, input_dim).
        target_n_branches: Desired number of branches.
        strategy: Adaptation strategy:
            - "pad": Append zero-weight branches when the source has fewer
              branches than the target; drop trailing branches when it has
              more.
            - "truncate": Accepted as a synonym of "pad" (same resize
              behaviour in both directions). Earlier versions documented
              this name but rejected it with ``ValueError``.
            - "interpolate": Fill every target branch from the source branch
              at the proportional index, scaled by
              ``min(1, n_source / target_n_branches)``.

    Returns:
        Adapted routing matrix of shape (target_n_branches, input_dim). When
        the source already has ``target_n_branches`` rows the input array
        itself is returned (no copy).

    Raises:
        ValueError: If ``strategy`` is not one of the names listed above.
    """
    n_source, input_dim = routing_matrix.shape

    if n_source == target_n_branches:
        return routing_matrix

    if strategy in ("pad", "truncate"):
        if n_source < target_n_branches:
            # Missing branches are filled with zero-weight rows.
            padding = np.zeros((target_n_branches - n_source, input_dim))
            return np.vstack([routing_matrix, padding])
        # Excess branches are dropped from the end.
        return routing_matrix[:target_n_branches]

    if strategy == "interpolate":
        result = np.zeros((target_n_branches, input_dim))
        if target_n_branches:
            # The scale factor is the same for every target branch, so it is
            # computed once; the guard keeps the empty-target case free of a
            # division by zero.
            weight = min(1.0, n_source / target_n_branches)
            # Map each target branch index into the source branch index space.
            source_idx = np.arange(target_n_branches) * n_source // target_n_branches
            result[:] = routing_matrix[source_idx] * weight
        return result

    raise ValueError(f"Unknown strategy: {strategy}")
59
+
60
+
61
def evaluate_with_branch_adaptation(
    model_fn: callable,
    inputs: np.ndarray,
    targets: np.ndarray,
    source_routing: np.ndarray,
    target_routing: np.ndarray,
    strategy: str = "pad",
) -> dict:
    """Score a model on a test morphology whose branch count may differ.

    Args:
        model_fn: Model prediction function; called once as ``model_fn(inputs)``.
        inputs: Input array (batch, time, input_dim).
        targets: Target array (batch, time, 2).
        source_routing: Routing matrix used during training.
        target_routing: Routing matrix for test morphology.
        strategy: Adaptation strategy forwarded to ``adapt_branch_count``.

    Returns:
        Dictionary with the source and target branch counts, the strategy
        name, ``rmse_mv`` and ``auc``.
    """
    source_branches = source_routing.shape[0]
    target_branches = target_routing.shape[0]

    # The adapted routing is computed (which also validates ``strategy``) but
    # is not applied to the model yet: the model is assumed to handle routing
    # internally, and rewiring it is left for later.
    _adapted_routing = adapt_branch_count(target_routing, source_branches, strategy=strategy)

    predictions = model_fn(inputs)

    # Channel 1 feeds the RMSE; the 0.1 divisor presumably rescales the error
    # to millivolts (suggested by the key name) - TODO confirm.
    channel1_residual = predictions[:, :, 1] - targets[:, :, 1]
    rmse_mv = np.sqrt(np.mean(channel1_residual ** 2)) / 0.1

    # Channel 0 feeds the AUC; an undefined (NaN) AUC is reported as 0.5.
    auc = binary_auc(predictions[:, :, 0], targets[:, :, 0])
    if np.isnan(auc):
        auc = 0.5

    summary = {
        "n_source_branches": source_branches,
        "n_target_branches": target_branches,
        "adaptation_strategy": strategy,
        "rmse_mv": float(rmse_mv),
        "auc": float(auc),
    }
    return summary
@@ -0,0 +1,336 @@
1
+ """Promoted v1 benchmark bundle runner."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Callable, Iterable
6
+ import json
7
+ from pathlib import Path
8
+ import re
9
+ from typing import Any, Literal
10
+
11
+ import numpy as np
12
+
13
+ from axobench.benchmark.dataset_schema import INTERVENTION_CONDITIONS
14
+ from axobench.benchmark.diagnostic_rows import (
15
+ DiagnosticContext,
16
+ morphology_contrast_diagnostic_rows,
17
+ morphology_routing_qc_rows,
18
+ write_diagnostic_rows,
19
+ )
20
+ from axobench.benchmark.reports import write_diagnostic_summary
21
+ from axobench.benchmark.runner import (
22
+ resolve_intervention_pair_paths,
23
+ run_model_diagnostics_on_selector,
24
+ run_paired_intervention_on_npz,
25
+ run_paired_intervention_on_npzs,
26
+ run_paired_intervention_on_selector,
27
+ run_teacher_qc_on_selector,
28
+ )
29
+ from axobench.benchmark.selectors import V1DatasetLayout
30
+
31
+
32
+ DEFAULT_MODEL_SUITES = (
33
+ "trace-core",
34
+ "state-dynamics",
35
+ "spike-behavior",
36
+ "feature-fidelity",
37
+ "paper-summary",
38
+ )
39
+ DEFAULT_INTERVENTION_SUITES = ("teacher-qc", "intervention-response", "paper-summary")
40
+ SuiteProfile = Literal["teacher-qc", "paper-summary", "full"]
41
+
42
+
43
def run_v1_benchmark_bundle(
    *,
    dataset_root: str | Path,
    output_dir: str | Path,
    base_context: DiagnosticContext | None = None,
    model_fn: Callable[[np.ndarray], np.ndarray] | None = None,
    model_id: str | None = None,
    intervention_pair_paths: Iterable[str | Path] = (),
    suite: SuiteProfile = "paper-summary",
    max_val_samples: int | None = None,
    max_intervention_samples: int | None = None,
    cache_shards: int = 1,
    write_summaries: bool = True,
) -> dict[str, Any]:
    """Run the promoted v1 validation plus intervention diagnostic bundle.

    The bundle writes, in order: teacher QC rows on the default validation
    selector, morphology routing QC rows, model diagnostic rows (only when a
    ``model_fn`` is given and the suite profile enables model suites), paired
    intervention rows, morphology contrast rows (only when any are produced),
    and finally ``manifest.json``.

    Args:
        dataset_root: Root directory handed to ``V1DatasetLayout``.
        output_dir: Directory that receives all artifacts; created if missing.
        base_context: Context whose fields are reused for every stage; a
            default ``DiagnosticContext()`` is used when omitted.
        model_fn: Optional prediction function. Without it, intervention rows
            are attributed to ``"teacher"`` and no model validation is run.
        model_id: Label for model rows. Falls back to ``base_context.model_id``
            when empty, for validation and intervention rows alike.
        intervention_pair_paths: Explicit pair files. When empty, the
            interventions reported by the dataset layout are used instead.
        suite: Suite profile name understood by ``_suite_config``.
        max_val_samples: ``max_samples`` for the validation runs.
        max_intervention_samples: ``max_samples`` for the intervention runs.
        cache_shards: Forwarded to the validation runners.
        write_summaries: Whether a Markdown summary is written per artifact.

    Returns:
        The manifest dictionary that is also written to ``manifest.json``.

    Raises:
        ValueError: If ``suite`` is not a known profile.
    """
    layout = V1DatasetLayout(dataset_root)
    output = Path(output_dir)
    output.mkdir(parents=True, exist_ok=True)
    base = base_context or DiagnosticContext()
    suite_config = _suite_config(suite)
    # Resolve the label once so validation and intervention rows of the same
    # model always carry the same model_id, even when ``model_id`` is omitted.
    resolved_model_id = model_id or base.model_id
    artifacts: list[dict[str, Any]] = []
    intervention_row_sets: list[dict[str, Any]] = []

    def run_grouped(paths: Iterable[Path]) -> None:
        # Run one paired-intervention evaluation per (intervention, morphology)
        # group inferred from the pair-file paths.
        for intervention_name, morphology_id, grouped_paths in _group_pair_paths_by_morphology(
            paths,
            default_morphology=base.morphology_id,
        ):
            rows_path, rows = _run_grouped_intervention(
                grouped_paths,
                output=output,
                base=base,
                suites=suite_config["intervention"],
                model_fn=model_fn,
                model_id=resolved_model_id,
                intervention=intervention_name,
                morphology_id=morphology_id,
                max_samples=max_intervention_samples,
            )
            intervention_row_sets.extend(rows)
            _record_artifact(rows, rows_path, artifacts, write_summaries=write_summaries)

    # --- Teacher QC on the default validation selector -----------------------
    teacher_context = _replace_context(
        base,
        suites=suite_config["teacher"],
        model_id="teacher",
        split="val",
        condition="val",
        intervention=None,
    )
    teacher_path = output / "teacher_qc_val.jsonl"
    teacher_rows = run_teacher_qc_on_selector(
        layout.default_validation(),
        context=teacher_context,
        max_samples=max_val_samples,
        output=teacher_path,
        cache_shards=cache_shards,
    )
    _record_artifact(teacher_rows, teacher_path, artifacts, write_summaries=write_summaries)

    # --- Morphology routing QC -----------------------------------------------
    routing_path = output / "morphology_routing_qc.jsonl"
    routing_rows = morphology_routing_qc_rows(
        context=_replace_context(base, suites=suite_config["routing"], model_id="teacher")
    )
    write_diagnostic_rows(routing_rows, routing_path)
    _record_artifact(routing_rows, routing_path, artifacts, write_summaries=write_summaries)

    # --- Model diagnostics on validation (skipped by teacher-only profiles) ---
    if model_fn is not None and suite_config["model"] is not None:
        model_context = _replace_context(
            base,
            suites=suite_config["model"],
            model_id=resolved_model_id,
            split="val",
            condition="val",
            intervention=None,
        )
        model_path = output / f"model_val_{_slug(resolved_model_id)}.jsonl"
        model_rows = run_model_diagnostics_on_selector(
            model_fn,
            layout.default_validation(),
            context=model_context,
            max_samples=max_val_samples,
            output=model_path,
            cache_shards=cache_shards,
        )
        _record_artifact(model_rows, model_path, artifacts, write_summaries=write_summaries)

    # --- Paired interventions --------------------------------------------------
    pair_paths = [Path(path) for path in intervention_pair_paths]
    if pair_paths:
        run_grouped(pair_paths)
    else:
        for intervention in layout.available_interventions():
            selector = layout.intervention(intervention)
            selector_pair_paths = resolve_intervention_pair_paths(selector)
            if len(selector_pair_paths) != 1:
                run_grouped(selector_pair_paths)
                continue
            # A selector backed by exactly one pair file is evaluated through
            # the selector itself and labelled with the base morphology.
            intervention_model_id = resolved_model_id if model_fn is not None else "teacher"
            context = _replace_context(
                base,
                suites=suite_config["intervention"],
                model_id=intervention_model_id,
                split="",
                condition=f"interventions/{intervention}",
                intervention=intervention,
            )
            filename = f"intervention_{_slug(base.morphology_id)}_{_slug(intervention)}"
            rows_path = output / f"{filename}.jsonl"
            rows = run_paired_intervention_on_selector(
                selector,
                model_fn=model_fn,
                context=context,
                max_samples=max_intervention_samples,
                output=rows_path,
            )
            intervention_row_sets.extend(rows)
            _record_artifact(rows, rows_path, artifacts, write_summaries=write_summaries)

    # --- Morphology contrast over all intervention rows ----------------------
    contrast_rows = morphology_contrast_diagnostic_rows(
        intervention_row_sets,
        context=_replace_context(base, suites=suite_config["morphology_contrast"], morphology_id="all_v1"),
    )
    if contrast_rows:
        contrast_path = output / "morphology_contrast_interventions.jsonl"
        write_diagnostic_rows(contrast_rows, contrast_path)
        _record_artifact(contrast_rows, contrast_path, artifacts, write_summaries=write_summaries)

    # The manifest records the arguments as given (``model_id`` is the raw
    # argument, not the resolved label) plus every artifact written above.
    manifest = {
        "dataset_root": str(dataset_root),
        "output_dir": str(output),
        "model_id": model_id,
        "suite": suite,
        "suite_config": {key: list(value) if value is not None else None for key, value in suite_config.items()},
        "max_val_samples": max_val_samples,
        "max_intervention_samples": max_intervention_samples,
        "write_summaries": write_summaries,
        "artifacts": artifacts,
    }
    (output / "manifest.json").write_text(json.dumps(manifest, indent=2, sort_keys=True) + "\n", encoding="utf-8")
    return manifest
208
+
209
+
210
+ def _suite_config(suite: SuiteProfile) -> dict[str, tuple[str, ...] | None]:
211
+ if suite == "teacher-qc":
212
+ return {
213
+ "teacher": ("teacher-qc",),
214
+ "routing": ("teacher-qc",),
215
+ "model": None,
216
+ "intervention": ("teacher-qc",),
217
+ "morphology_contrast": ("morphology-contrast",),
218
+ }
219
+ if suite == "paper-summary":
220
+ return {
221
+ "teacher": ("teacher-qc", "paper-summary"),
222
+ "routing": ("teacher-qc", "paper-summary"),
223
+ "model": DEFAULT_MODEL_SUITES,
224
+ "intervention": DEFAULT_INTERVENTION_SUITES,
225
+ "morphology_contrast": ("morphology-contrast", "paper-summary"),
226
+ }
227
+ if suite == "full":
228
+ return {
229
+ "teacher": ("teacher-qc", "paper-summary", "full"),
230
+ "routing": ("teacher-qc", "paper-summary", "full"),
231
+ "model": (*DEFAULT_MODEL_SUITES, "full"),
232
+ "intervention": ("teacher-qc", "intervention-response", "paper-summary", "full"),
233
+ "morphology_contrast": ("morphology-contrast", "paper-summary", "full"),
234
+ }
235
+ raise ValueError(f"unknown v1 benchmark suite profile: {suite}")
236
+
237
+
238
def infer_pair_metadata(path: str | Path, *, default_morphology: str = "unknown") -> tuple[str, str]:
    """Infer intervention and morphology labels from a pilot pair-file path.

    The intervention is taken from the path component that directly follows an
    ``interventions`` directory. If there is none, the path text is searched
    for the names in ``INTERVENTION_CONDITIONS``, and ``"event_dropout"`` is
    used when nothing matches. The morphology is the first ``specimen_<digits>``
    occurrence in the path text.

    Args:
        path: Pair-file path to inspect.
        default_morphology: Morphology label used when the path contains no
            ``specimen_<digits>`` token.

    Returns:
        ``(intervention, morphology_id)``.
    """
    path = Path(path)
    parts = path.parts
    text = str(path)

    intervention = None
    if "interventions" in parts:
        index = parts.index("interventions")
        if index + 1 < len(parts):
            intervention = parts[index + 1]

    if intervention is None:
        # Prefer the longest matching name: with plain first-match, a name
        # that is contained in a longer one (e.g. "event_dropout" inside
        # "selective_event_dropout") could win depending on iteration order.
        # ``max`` keeps the first candidate on ties, so non-overlapping names
        # resolve exactly as before.
        candidates = [name for name in INTERVENTION_CONDITIONS if name in text]
        intervention = max(candidates, key=len) if candidates else "event_dropout"

    match = re.search(r"specimen_(\d+)", text)
    morphology_id = f"specimen_{match.group(1)}" if match else default_morphology
    return intervention, morphology_id
253
+
254
+
255
def _group_pair_paths_by_morphology(
    pair_paths: Iterable[Path],
    *,
    default_morphology: str,
) -> list[tuple[str, str, list[Path]]]:
    """Bucket pair files by the (intervention, morphology) inferred from each path.

    Returns:
        One ``(intervention, morphology_id, paths)`` entry per distinct label
        pair, ordered by the label pair, with each entry's paths sorted.
    """
    buckets: dict[tuple[str, str], list[Path]] = {}
    for candidate in pair_paths:
        label = infer_pair_metadata(candidate, default_morphology=default_morphology)
        buckets.setdefault(label, []).append(candidate)

    ordered: list[tuple[str, str, list[Path]]] = []
    for label in sorted(buckets):
        intervention, morphology_id = label
        ordered.append((intervention, morphology_id, sorted(buckets[label])))
    return ordered
268
+
269
+
270
def _run_grouped_intervention(
    pair_paths: list[Path],
    *,
    output: Path,
    base: DiagnosticContext,
    suites: tuple[str, ...],
    model_fn: Callable[[np.ndarray], np.ndarray] | None,
    model_id: str | None,
    intervention: str,
    morphology_id: str,
    max_samples: int | None,
) -> tuple[Path, list[dict[str, Any]]]:
    """Run paired-intervention diagnostics for one (intervention, morphology) group.

    Args:
        pair_paths: Pair files belonging to the group.
        output: Directory that receives the rows file.
        base: Context whose remaining fields are reused for the run.
        suites: Suite tags stored on the run context.
        model_fn: Optional prediction function; rows are attributed to
            ``"teacher"`` when it is ``None``.
        model_id: Label for model rows. Falls back to ``base.model_id`` when
            empty, matching how the bundle labels model validation rows.
        intervention: Intervention name for the context and the file name.
        morphology_id: Morphology label for the context and the file name.
        max_samples: Forwarded to the runner as ``max_samples``.

    Returns:
        ``(rows_path, rows)``: the rows file path handed to the runner and the
        rows it returned.
    """
    if model_fn is None:
        intervention_model_id = "teacher"
    else:
        # A model run without an explicit id must not be labelled ``None``.
        intervention_model_id = model_id or base.model_id

    context = _replace_context(
        base,
        suites=suites,
        model_id=intervention_model_id,
        morphology_id=morphology_id,
        split="",
        condition=f"interventions/{intervention}",
        intervention=intervention,
    )
    rows_path = output / f"intervention_{_slug(morphology_id)}_{_slug(intervention)}.jsonl"

    # A single file goes through the single-file runner; anything else goes
    # through the multi-file runner with the whole list.
    if len(pair_paths) == 1:
        runner, source = run_paired_intervention_on_npz, pair_paths[0]
    else:
        runner, source = run_paired_intervention_on_npzs, pair_paths
    rows = runner(
        source,
        model_fn=model_fn,
        context=context,
        max_samples=max_samples,
        output=rows_path,
    )
    return rows_path, rows
310
+
311
+
312
+ def _record_artifact(
313
+ rows: list[dict[str, Any]],
314
+ rows_path: Path,
315
+ artifacts: list[dict[str, Any]],
316
+ *,
317
+ write_summaries: bool,
318
+ ) -> None:
319
+ artifact = {
320
+ "rows": len(rows),
321
+ "rows_path": str(rows_path),
322
+ }
323
+ if write_summaries:
324
+ summary_path = rows_path.with_suffix(".md")
325
+ write_diagnostic_summary(rows, summary_path)
326
+ artifact["summary_path"] = str(summary_path)
327
+ artifacts.append(artifact)
328
+
329
+
330
def _replace_context(context: DiagnosticContext, **updates: Any) -> DiagnosticContext:
    """Return a new ``DiagnosticContext`` built from ``context`` with ``updates`` applied.

    The instance attributes of ``context`` are used as constructor keyword
    arguments; entries in ``updates`` take precedence over them.
    """
    fields = dict(context.__dict__)
    fields.update(updates)
    return DiagnosticContext(**fields)
333
+
334
+
335
+ def _slug(value: str) -> str:
336
+ return re.sub(r"[^A-Za-z0-9_.-]+", "_", value).strip("_") or "unknown"