btorch 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- btorch/__init__.py +15 -0
- btorch/analysis/__init__.py +124 -0
- btorch/analysis/aggregation.py +363 -0
- btorch/analysis/branching.py +315 -0
- btorch/analysis/clustering.py +53 -0
- btorch/analysis/connectivity.py +251 -0
- btorch/analysis/dynamic_tools/README.md +81 -0
- btorch/analysis/dynamic_tools/__init__.py +0 -0
- btorch/analysis/dynamic_tools/attractor_dynamics.py +121 -0
- btorch/analysis/dynamic_tools/complexity.py +324 -0
- btorch/analysis/dynamic_tools/criticality.py +240 -0
- btorch/analysis/dynamic_tools/ei_balance.py +514 -0
- btorch/analysis/dynamic_tools/lyapunov_dynamics.py +120 -0
- btorch/analysis/dynamic_tools/micro_scale.py +192 -0
- btorch/analysis/dynamic_tools/spiking.py +1206 -0
- btorch/analysis/metrics.py +47 -0
- btorch/analysis/spiking.py +1421 -0
- btorch/analysis/statistics.py +954 -0
- btorch/analysis/tuning.py +76 -0
- btorch/analysis/two_compartment_fit.py +1936 -0
- btorch/analysis/voltage.py +56 -0
- btorch/config.py +43 -0
- btorch/connectome/__init__.py +9 -0
- btorch/connectome/augment.py +386 -0
- btorch/connectome/connection.py +1027 -0
- btorch/datasets/__init__.py +75 -0
- btorch/datasets/noise.py +1129 -0
- btorch/datasets/transforms.py +140 -0
- btorch/io/__init__.py +48 -0
- btorch/io/serialization.py +818 -0
- btorch/jit.py +22 -0
- btorch/models/__init__.py +43 -0
- btorch/models/base.py +1075 -0
- btorch/models/bilinear.py +78 -0
- btorch/models/connection_conversion.py +1204 -0
- btorch/models/constrain.py +14 -0
- btorch/models/conv.py +153 -0
- btorch/models/dlif.py +228 -0
- btorch/models/environ.py +139 -0
- btorch/models/functional.py +448 -0
- btorch/models/hex/__init__.py +16 -0
- btorch/models/hex/conv.py +102 -0
- btorch/models/hex/eye.py +376 -0
- btorch/models/history.py +358 -0
- btorch/models/init.py +150 -0
- btorch/models/linear.py +693 -0
- btorch/models/neurons/__init__.py +17 -0
- btorch/models/neurons/alif.py +407 -0
- btorch/models/neurons/glif.py +441 -0
- btorch/models/neurons/izhikevich.py +322 -0
- btorch/models/neurons/lif.py +254 -0
- btorch/models/neurons/mixed.py +193 -0
- btorch/models/neurons/two_compartment.py +407 -0
- btorch/models/ode.py +55 -0
- btorch/models/parametrize.py +262 -0
- btorch/models/regularizer.py +221 -0
- btorch/models/rnn.py +584 -0
- btorch/models/scale.py +295 -0
- btorch/models/shape.py +71 -0
- btorch/models/surrogate/__init__.py +26 -0
- btorch/models/surrogate/atan.py +117 -0
- btorch/models/surrogate/base.py +62 -0
- btorch/models/surrogate/erf.py +95 -0
- btorch/models/surrogate/poisson_random.py +94 -0
- btorch/models/surrogate/sigmoid.py +65 -0
- btorch/models/surrogate/superspike.py +71 -0
- btorch/models/surrogate/triangle.py +73 -0
- btorch/models/synapse.py +1075 -0
- btorch/py.typed +0 -0
- btorch/types.py +5 -0
- btorch/utils/__init__.py +6 -0
- btorch/utils/bench.py +205 -0
- btorch/utils/conf.py +613 -0
- btorch/utils/dict_utils.py +127 -0
- btorch/utils/file.py +160 -0
- btorch/utils/grad_checkpoint/__init__.py +1 -0
- btorch/utils/grad_checkpoint/checkpoint.py +66 -0
- btorch/utils/grad_checkpoint/test_checkpoint.py +143 -0
- btorch/utils/hdf5_utils.py +98 -0
- btorch/utils/hex/__init__.py +181 -0
- btorch/utils/hex/coords.py +152 -0
- btorch/utils/hex/data.py +552 -0
- btorch/utils/hex/distance.py +74 -0
- btorch/utils/hex/doubled.py +240 -0
- btorch/utils/hex/line.py +110 -0
- btorch/utils/hex/neighbor.py +121 -0
- btorch/utils/hex/offset.py +327 -0
- btorch/utils/hex/range.py +138 -0
- btorch/utils/hex/resolve.py +188 -0
- btorch/utils/hex/storage.py +399 -0
- btorch/utils/hex/transform.py +209 -0
- btorch/utils/pandas_utils.py +38 -0
- btorch/utils/yaml_utils.py +54 -0
- btorch/visualisation/__init__.py +114 -0
- btorch/visualisation/aggregation.py +565 -0
- btorch/visualisation/dynamics.py +980 -0
- btorch/visualisation/hex/__init__.py +42 -0
- btorch/visualisation/hex/animate.py +229 -0
- btorch/visualisation/hex/interactive.py +943 -0
- btorch/visualisation/hex/receptive_field.py +251 -0
- btorch/visualisation/hex/static.py +682 -0
- btorch/visualisation/network.py +50 -0
- btorch/visualisation/timeseries.py +2358 -0
- btorch/visualisation/tuning.py +142 -0
- btorch-0.1.0.dist-info/METADATA +416 -0
- btorch-0.1.0.dist-info/RECORD +108 -0
- btorch-0.1.0.dist-info/WHEEL +4 -0
- btorch-0.1.0.dist-info/licenses/LICENSE +201 -0
btorch/analysis/__init__.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
from .aggregation import (
|
|
2
|
+
agg_by_neuron,
|
|
3
|
+
agg_by_neuropil,
|
|
4
|
+
agg_conn,
|
|
5
|
+
build_group_frame,
|
|
6
|
+
group_ecdf,
|
|
7
|
+
group_summary,
|
|
8
|
+
group_values,
|
|
9
|
+
)
|
|
10
|
+
from .branching import branching_ratio
|
|
11
|
+
from .connectivity import HopDistanceModel, compute_ie_ratio
|
|
12
|
+
from .metrics import indices_to_mask, select_on_metric
|
|
13
|
+
from .spiking import (
|
|
14
|
+
compute_raster,
|
|
15
|
+
compute_spectrum,
|
|
16
|
+
cv_temporal,
|
|
17
|
+
fano,
|
|
18
|
+
fano_population,
|
|
19
|
+
fano_sweep,
|
|
20
|
+
fano_temporal,
|
|
21
|
+
firing_rate,
|
|
22
|
+
isi_cv,
|
|
23
|
+
isi_cv_population,
|
|
24
|
+
kurtosis,
|
|
25
|
+
kurtosis_population,
|
|
26
|
+
local_variation,
|
|
27
|
+
)
|
|
28
|
+
from .statistics import (
|
|
29
|
+
StatChoice,
|
|
30
|
+
compute_log_hist,
|
|
31
|
+
describe_array,
|
|
32
|
+
use_percentiles,
|
|
33
|
+
use_stats,
|
|
34
|
+
)
|
|
35
|
+
from .two_compartment_fit import (
|
|
36
|
+
DEFAULT_TWO_COMPARTMENT_FIT_STAGES,
|
|
37
|
+
DEFAULT_TWO_COMPARTMENT_PARAM_BOUNDS,
|
|
38
|
+
AllenSweepBatch,
|
|
39
|
+
FitEvaluation,
|
|
40
|
+
TwoCompartmentFitStage,
|
|
41
|
+
choose_current_clamp_sweeps,
|
|
42
|
+
detect_spikes_from_voltage,
|
|
43
|
+
evaluate_fit_across_sweeps,
|
|
44
|
+
evaluate_two_compartment_fit,
|
|
45
|
+
exponential_filter_spike_train,
|
|
46
|
+
filter_mouse_visp_l5_pyramidal_cells,
|
|
47
|
+
fit_two_compartment_model,
|
|
48
|
+
get_cell_types_cache,
|
|
49
|
+
load_allen_sweep,
|
|
50
|
+
mask_post_spike_voltage_samples,
|
|
51
|
+
plot_two_compartment_fit,
|
|
52
|
+
query_mouse_visp_l5_pyramidal_cells,
|
|
53
|
+
resample_trace,
|
|
54
|
+
rollout_two_compartment,
|
|
55
|
+
save_fit_report,
|
|
56
|
+
spike_timing_loss,
|
|
57
|
+
spike_timing_stats,
|
|
58
|
+
two_compartment_loss,
|
|
59
|
+
)
|
|
60
|
+
from .voltage import suggest_skip_timestep, voltage_overshoot
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# Public API of the analysis subpackage, grouped by the module each name is
# imported from above. The order is preserved from the original listing.
__all__ = [
    # .aggregation
    "agg_by_neuropil",
    "agg_by_neuron",
    "agg_conn",
    "build_group_frame",
    "group_values",
    "group_summary",
    "group_ecdf",
    # .branching
    "branching_ratio",
    # .connectivity
    "HopDistanceModel",
    "compute_ie_ratio",
    # .metrics
    "indices_to_mask",
    "select_on_metric",
    # .spiking -- simplified metric API
    "isi_cv",
    "fano",
    "kurtosis",
    "local_variation",
    # .spiking -- population metrics
    "isi_cv_population",
    "fano_population",
    "kurtosis_population",
    # .spiking -- temporal variants
    "cv_temporal",
    "fano_temporal",
    # .spiking -- sweep functions
    "fano_sweep",
    # Utilities (.spiking / .statistics)
    "firing_rate",
    "compute_raster",
    "compute_log_hist",
    "compute_spectrum",
    "describe_array",
    # .two_compartment_fit
    "AllenSweepBatch",
    "DEFAULT_TWO_COMPARTMENT_FIT_STAGES",
    "DEFAULT_TWO_COMPARTMENT_PARAM_BOUNDS",
    "FitEvaluation",
    "TwoCompartmentFitStage",
    "choose_current_clamp_sweeps",
    "detect_spikes_from_voltage",
    "evaluate_fit_across_sweeps",
    "evaluate_two_compartment_fit",
    "exponential_filter_spike_train",
    "filter_mouse_visp_l5_pyramidal_cells",
    "fit_two_compartment_model",
    "get_cell_types_cache",
    "load_allen_sweep",
    "mask_post_spike_voltage_samples",
    "plot_two_compartment_fit",
    "query_mouse_visp_l5_pyramidal_cells",
    "resample_trace",
    "rollout_two_compartment",
    "save_fit_report",
    "spike_timing_loss",
    "spike_timing_stats",
    "two_compartment_loss",
    # .voltage
    "suggest_skip_timestep",
    "voltage_overshoot",
    # .statistics -- statistic selection
    "StatChoice",
    "use_stats",
    "use_percentiles",
]
|
|
@@ -0,0 +1,363 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import scipy.sparse
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from ..types import TensorLike
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def agg_by_neuron(
    y,
    neurons: pd.DataFrame,
    agg: Literal["mean", "sum", "std"] = "mean",
    neuron_type_column: str = "cell_type",
    **kwargs,
) -> dict:
    """Aggregate data by neuron type.

    Args:
        y: NumPy array or torch tensor whose last axis is indexed by the
            ``simple_id`` values found in ``neurons``.
        neurons: Neuron table with a ``simple_id`` column and the column
            named by ``neuron_type_column``.
        agg: Name of the reduction. It is looked up on ``numpy`` when ``y``
            is an ``np.ndarray`` and on ``torch`` otherwise, then applied over
            the last axis. NOTE(review): ``torch.std`` and ``numpy.std`` use
            different default degrees of freedom, so "std" differs by backend.
        neuron_type_column: Column whose values define the groups.
        **kwargs: Extra keyword arguments forwarded to ``DataFrame.groupby``.

    Returns:
        Dict mapping each non-null neuron type to ``y`` reduced over the
        neurons of that type.
    """
    backend = np if isinstance(y, np.ndarray) else torch
    reducer = getattr(backend, agg)
    by_type = neurons.groupby(neuron_type_column, dropna=True, **kwargs)
    return {
        neuron_type: reducer(y[..., members.simple_id.to_numpy()], -1)
        for neuron_type, members in by_type
    }
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def agg_by_neuropil(
|
|
30
|
+
y,
|
|
31
|
+
neurons: pd.DataFrame | None = None,
|
|
32
|
+
connections: pd.DataFrame | None = None,
|
|
33
|
+
mode: Literal["top_innervated", "all_innervated"] = "all_innervated",
|
|
34
|
+
agg: Literal["mean", "sum", "std"] = "mean",
|
|
35
|
+
use_polars: bool = False,
|
|
36
|
+
):
|
|
37
|
+
"""Aggregate activations by neuropil under a validated aggregation mode."""
|
|
38
|
+
agg_func = getattr(np, agg) if isinstance(y, np.ndarray) else getattr(torch, agg)
|
|
39
|
+
if use_polars:
|
|
40
|
+
try:
|
|
41
|
+
import polars as pl
|
|
42
|
+
except ImportError:
|
|
43
|
+
use_polars = False
|
|
44
|
+
|
|
45
|
+
if mode == "top_innervated":
|
|
46
|
+
assert neurons is not None, "neurons must be provided for top_innervated mode"
|
|
47
|
+
tmp = neurons[["group", "simple_id"]].copy()
|
|
48
|
+
tmp = tmp[tmp["simple_id"] < y.shape[-1]]
|
|
49
|
+
pre_ret: dict = {}
|
|
50
|
+
post_ret: dict = {}
|
|
51
|
+
|
|
52
|
+
if use_polars:
|
|
53
|
+
tmp["pre"] = tmp["group"].str.split(".").str[0]
|
|
54
|
+
tmp["post"] = tmp["group"].str.split(".").str[-1]
|
|
55
|
+
ptbl = pl.from_pandas(tmp)
|
|
56
|
+
|
|
57
|
+
pre_groups = ptbl.group_by("pre", maintain_order=True).agg(
|
|
58
|
+
pl.col("simple_id")
|
|
59
|
+
)
|
|
60
|
+
for row in pre_groups.iter_rows(named=True):
|
|
61
|
+
sid = np.asarray(row["simple_id"], dtype=int)
|
|
62
|
+
pre_ret[row["pre"]] = agg_func(y[..., sid], -1)
|
|
63
|
+
|
|
64
|
+
post_groups = ptbl.group_by("post", maintain_order=True).agg(
|
|
65
|
+
pl.col("simple_id")
|
|
66
|
+
)
|
|
67
|
+
for row in post_groups.iter_rows(named=True):
|
|
68
|
+
sid = np.asarray(row["simple_id"], dtype=int)
|
|
69
|
+
post_ret[row["post"]] = agg_func(y[..., sid], -1)
|
|
70
|
+
else:
|
|
71
|
+
tmp["pre"] = tmp["group"].apply(lambda x: x.split(".")[0])
|
|
72
|
+
tmp["post"] = tmp["group"].apply(lambda x: x.split(".")[-1])
|
|
73
|
+
for pre, group in tmp.groupby("pre", dropna=True):
|
|
74
|
+
pre_ret[pre] = agg_func(y[..., group.simple_id], -1)
|
|
75
|
+
for post, group in tmp.groupby("post", dropna=True):
|
|
76
|
+
post_ret[post] = agg_func(y[..., group.simple_id], -1)
|
|
77
|
+
return pre_ret, post_ret
|
|
78
|
+
if mode == "all_innervated":
|
|
79
|
+
assert (
|
|
80
|
+
connections is not None
|
|
81
|
+
), "connections must be provided for all_innervated mode"
|
|
82
|
+
tmp = connections[["pre_simple_id", "post_simple_id", "neuropil"]]
|
|
83
|
+
tmp = tmp[
|
|
84
|
+
(tmp["pre_simple_id"] < y.shape[-1]) & (tmp["post_simple_id"] < y.shape[-1])
|
|
85
|
+
]
|
|
86
|
+
pre_ret: dict = {}
|
|
87
|
+
post_ret: dict = {}
|
|
88
|
+
|
|
89
|
+
if use_polars:
|
|
90
|
+
ptbl = pl.from_pandas(tmp)
|
|
91
|
+
groups = ptbl.group_by("neuropil", maintain_order=True).agg(
|
|
92
|
+
pl.col("pre_simple_id"), pl.col("post_simple_id")
|
|
93
|
+
)
|
|
94
|
+
for row in groups.iter_rows(named=True):
|
|
95
|
+
neuropil = row["neuropil"]
|
|
96
|
+
pre_ids = np.asarray(row["pre_simple_id"], dtype=int)
|
|
97
|
+
post_ids = np.asarray(row["post_simple_id"], dtype=int)
|
|
98
|
+
pre_ret[neuropil] = agg_func(y[..., pre_ids], -1)
|
|
99
|
+
post_ret[neuropil] = agg_func(y[..., post_ids], -1)
|
|
100
|
+
else:
|
|
101
|
+
for neuropil, group in tmp.groupby("neuropil", dropna=True):
|
|
102
|
+
pre_ret[neuropil] = agg_func(y[..., group.pre_simple_id.to_numpy()], -1)
|
|
103
|
+
post_ret[neuropil] = agg_func(
|
|
104
|
+
y[..., group.post_simple_id.to_numpy()], -1
|
|
105
|
+
)
|
|
106
|
+
return pre_ret, post_ret
|
|
107
|
+
|
|
108
|
+
raise ValueError(
|
|
109
|
+
"Invalid `mode` for `agg_by_neuropil`. "
|
|
110
|
+
"Expected one of: {'top_innervated', 'all_innervated'}."
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def agg_conn(
|
|
115
|
+
y,
|
|
116
|
+
conn: pd.DataFrame,
|
|
117
|
+
conn_weight: scipy.sparse.sparray | None = None,
|
|
118
|
+
neurons: pd.DataFrame | None = None,
|
|
119
|
+
mode: Literal["neuropil", "neuron"] = "neuron",
|
|
120
|
+
neuron_type_column: str = "cell_type",
|
|
121
|
+
agg: Literal["mean", "sum", "std"] = "mean",
|
|
122
|
+
):
|
|
123
|
+
"""Aggregate connectivity weights by neuropil or neuron-type pairs."""
|
|
124
|
+
if conn_weight is not None:
|
|
125
|
+
conn_weight = conn_weight.tocoo()
|
|
126
|
+
conn = conn.merge(
|
|
127
|
+
pd.DataFrame(
|
|
128
|
+
{
|
|
129
|
+
"pre_simple_id": conn_weight.row,
|
|
130
|
+
"post_simple_id": conn_weight.col,
|
|
131
|
+
"weight": conn_weight.data,
|
|
132
|
+
}
|
|
133
|
+
),
|
|
134
|
+
how="left",
|
|
135
|
+
on=["pre_simple_id", "post_simple_id"],
|
|
136
|
+
)
|
|
137
|
+
if mode == "neuropil":
|
|
138
|
+
return conn.groupby("neuropil")["weight"].agg(agg)
|
|
139
|
+
if mode == "neuron":
|
|
140
|
+
assert neurons is not None, "neurons must be provided for neuron mode"
|
|
141
|
+
conn = conn.merge(
|
|
142
|
+
neurons[["simple_id", neuron_type_column]].rename(
|
|
143
|
+
columns={
|
|
144
|
+
"simple_id": "pre_simple_id",
|
|
145
|
+
neuron_type_column: f"pre_{neuron_type_column}",
|
|
146
|
+
}
|
|
147
|
+
),
|
|
148
|
+
how="left",
|
|
149
|
+
on="pre_simple_id",
|
|
150
|
+
)
|
|
151
|
+
conn = conn.merge(
|
|
152
|
+
neurons[["simple_id", neuron_type_column]].rename(
|
|
153
|
+
columns={
|
|
154
|
+
"simple_id": "post_simple_id",
|
|
155
|
+
neuron_type_column: f"post_{neuron_type_column}",
|
|
156
|
+
}
|
|
157
|
+
),
|
|
158
|
+
how="left",
|
|
159
|
+
on="post_simple_id",
|
|
160
|
+
)
|
|
161
|
+
return conn.groupby(
|
|
162
|
+
[f"pre_{neuron_type_column}", f"post_{neuron_type_column}"]
|
|
163
|
+
)["weight"].agg(agg)
|
|
164
|
+
|
|
165
|
+
raise ValueError(
|
|
166
|
+
"Invalid `mode` for `agg_conn`. Expected one of: {'neuropil', 'neuron'}."
|
|
167
|
+
)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def build_group_frame(
    values: TensorLike,
    neurons_df: pd.DataFrame,
    group_by: str,
    *,
    simple_id_col: str = "simple_id",
    value_name: str = "value",
    dropna: bool = True,
) -> pd.DataFrame:
    """Convert neuron-aligned values to a tidy frame for grouped analyses.

    Args:
        values: Array/tensor shaped `[N]` or `[..., N]` where the last axis is
            neuron. All leading dimensions are flattened into independent
            samples (e.g., trials, conditions, or time points).
        neurons_df: DataFrame containing at least `simple_id_col` and `group_by`.
        group_by: Column in `neurons_df` used as grouping key.
        simple_id_col: Column mapping rows in `neurons_df` to neuron index in
            `values`. Entries must be unique integers in `[0, N)`.
        value_name: Name for the output value column.
        dropna: Drop missing values in group/value columns when `True`.

    Returns:
        Long-format frame with columns `group_by` and `value_name`. Rows are
        neuron-major: all samples of the first selected neuron (in
        `neurons_df` row order), then the next neuron, and so on. A
        zero-length leading axis in `values` means zero samples. With
        `dropna=False` that yields an empty frame.

    Raises:
        ValueError: If `values` is 0-d, a required column is missing, no
            metadata or values remain after filtering, or `simple_id_col`
            is non-numeric, non-integer, duplicated or out of range.
    """
    y = _to_numpy(values)
    if y.ndim < 1:
        raise ValueError("`values` must have at least one dimension.")

    if simple_id_col not in neurons_df.columns:
        raise ValueError(f"Missing `{simple_id_col}` in `neurons_df`.")
    if group_by not in neurons_df.columns:
        raise ValueError(f"Missing `{group_by}` in `neurons_df`.")

    metadata = neurons_df.loc[:, [simple_id_col, group_by]].copy()
    if dropna:
        metadata = metadata.dropna(subset=[group_by])
    if metadata.empty:
        raise ValueError("No neuron metadata available after filtering.")

    if metadata[simple_id_col].duplicated().any():
        raise ValueError(f"`{simple_id_col}` must be unique in `neurons_df`.")

    try:
        numeric_ids = pd.to_numeric(metadata[simple_id_col], errors="raise")
        simple_ids = numeric_ids.to_numpy(dtype=np.int64)
    except (TypeError, ValueError, OverflowError) as exc:
        raise ValueError(f"`{simple_id_col}` must be numeric.") from exc
    # Casting floats to int64 truncates fractional ids (1.5 -> 1), which would
    # silently read the wrong neuron, so reject them explicitly.
    if pd.api.types.is_float_dtype(numeric_ids.dtype) and not np.array_equal(
        simple_ids, numeric_ids.to_numpy(dtype=np.float64)
    ):
        raise ValueError(f"`{simple_id_col}` must contain integer values.")

    n_neurons = y.shape[-1]
    out_of_range = (simple_ids < 0) | (simple_ids >= n_neurons)
    if np.any(out_of_range):
        bad_ids = simple_ids[out_of_range]
        raise ValueError(
            f"Found `{simple_id_col}` outside [0, {n_neurons - 1}]: {bad_ids.tolist()}"
        )

    selected = y[..., simple_ids]
    # Every leading axis is flattened into samples. np.prod(()) == 1 covers
    # 1-D input, and an empty leading axis correctly gives zero samples
    # (clamping to 1 would make the reshape below fail).
    n_samples = int(np.prod(selected.shape[:-1], dtype=np.int64))
    flattened = selected.reshape(n_samples, simple_ids.size)
    group_labels = metadata[group_by].to_numpy()

    frame = pd.DataFrame(
        {
            # Neuron-major layout: each label is repeated once per sample,
            # matching the transposed (neuron, sample) value order.
            group_by: np.repeat(group_labels, n_samples),
            value_name: flattened.T.reshape(-1),
        }
    )

    if dropna:
        frame = frame.dropna(subset=[value_name])
        if frame.empty:
            raise ValueError("No values available after filtering.")

    return frame
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def group_values(
    values: TensorLike,
    neurons_df: pd.DataFrame,
    group_by: str,
    *,
    simple_id_col: str = "simple_id",
    value_name: str = "value",
    group_order: Sequence | None = None,
    dropna: bool = True,
) -> dict[object, np.ndarray]:
    """Return grouped value arrays, keyed by group label in plotting order.

    Args:
        values, neurons_df, group_by, simple_id_col, value_name, dropna:
            Forwarded to `build_group_frame`.
        group_order: Explicit key order, validated by `_resolve_group_order`.
            When ``None``, groups follow their first appearance in the data.

    Returns:
        Dict mapping each group label to a float array of its values. When
        ``dropna=False`` keeps missing labels, a missing-label key collects
        the values whose label is missing.
    """
    frame = build_group_frame(
        values,
        neurons_df,
        group_by,
        simple_id_col=simple_id_col,
        value_name=value_name,
        dropna=dropna,
    )
    order = _resolve_group_order(frame, group_by, group_order)

    labels = frame[group_by]
    column = frame[value_name]
    label_is_na = labels.isna()

    grouped: dict[object, np.ndarray] = {}
    for group in order:
        # ``labels == NaN`` is never true, so missing labels need an explicit
        # ``isna`` mask. Otherwise their group would come back empty.
        if pd.api.types.is_scalar(group) and pd.isna(group):
            mask = label_is_na
        else:
            mask = labels == group
        grouped[group] = column[mask].to_numpy(dtype=float)
    return grouped
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def group_summary(
    values: TensorLike,
    neurons_df: pd.DataFrame,
    group_by: str,
    *,
    simple_id_col: str = "simple_id",
    value_name: str = "value",
    group_order: Sequence | None = None,
    dropna: bool = True,
) -> pd.DataFrame:
    """Compute per-group summary statistics from neuron-aligned values.

    All arguments are forwarded to `group_values`.

    Returns:
        One row per group, in the order produced by `group_values`, with
        columns `group_by`, `n`, `mean`, `std` (NumPy default, ``ddof=0``),
        `min`, `q25`, `median`, `q75` and `max`.
    """
    grouped = group_values(
        values,
        neurons_df,
        group_by,
        simple_id_col=simple_id_col,
        value_name=value_name,
        group_order=group_order,
        dropna=dropna,
    )

    def describe(label, arr: np.ndarray) -> dict:
        return {
            group_by: label,
            "n": int(arr.size),
            "mean": float(np.mean(arr)),
            "std": float(np.std(arr)),
            "min": float(np.min(arr)),
            "q25": float(np.quantile(arr, 0.25)),
            "median": float(np.median(arr)),
            "q75": float(np.quantile(arr, 0.75)),
            "max": float(np.max(arr)),
        }

    records = [describe(label, arr) for label, arr in grouped.items()]
    return pd.DataFrame(records)
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def group_ecdf(
    values: TensorLike,
    neurons_df: pd.DataFrame,
    group_by: str,
    *,
    simple_id_col: str = "simple_id",
    value_name: str = "value",
    group_order: Sequence | None = None,
    dropna: bool = True,
) -> dict[object, pd.DataFrame]:
    """Compute grouped ECDF points ready for plotting or analysis.

    All arguments are forwarded to `group_values`.

    Returns:
        Dict mapping each group label to a frame with the sorted values in
        column `value_name` and their empirical CDF ``rank / n`` (ranks
        starting at 1) in column ``"ecdf"``.
    """
    grouped = group_values(
        values,
        neurons_df,
        group_by,
        simple_id_col=simple_id_col,
        value_name=value_name,
        group_order=group_order,
        dropna=dropna,
    )

    def ecdf_frame(arr: np.ndarray) -> pd.DataFrame:
        ordered = np.sort(arr)
        count = len(ordered)
        cumulative = np.arange(1, count + 1, dtype=float) / count
        return pd.DataFrame({value_name: ordered, "ecdf": cumulative})

    return {label: ecdf_frame(arr) for label, arr in grouped.items()}
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
def _resolve_group_order(
|
|
343
|
+
frame: pd.DataFrame,
|
|
344
|
+
group_by: str,
|
|
345
|
+
group_order: Sequence | None,
|
|
346
|
+
) -> list[object]:
|
|
347
|
+
if group_order is None:
|
|
348
|
+
return list(pd.unique(frame[group_by]))
|
|
349
|
+
|
|
350
|
+
requested = list(group_order)
|
|
351
|
+
available = set(frame[group_by].tolist())
|
|
352
|
+
missing = [group for group in requested if group not in available]
|
|
353
|
+
if missing:
|
|
354
|
+
raise ValueError(f"`group_order` contains unknown groups: {missing}")
|
|
355
|
+
return requested
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def _to_numpy(values: TensorLike) -> np.ndarray:
    """Return ``values`` as a NumPy array.

    NumPy arrays are returned unchanged (no copy). Torch tensors are
    detached and moved to the CPU before conversion.

    Raises:
        TypeError: For any input that is neither an array nor a tensor.
    """
    if isinstance(values, np.ndarray):
        return values
    if not isinstance(values, torch.Tensor):
        raise TypeError("`values` must be a numpy array or torch tensor.")
    return values.detach().cpu().numpy()
|