btorch 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. btorch/__init__.py +15 -0
  2. btorch/analysis/__init__.py +124 -0
  3. btorch/analysis/aggregation.py +363 -0
  4. btorch/analysis/branching.py +315 -0
  5. btorch/analysis/clustering.py +53 -0
  6. btorch/analysis/connectivity.py +251 -0
  7. btorch/analysis/dynamic_tools/README.md +81 -0
  8. btorch/analysis/dynamic_tools/__init__.py +0 -0
  9. btorch/analysis/dynamic_tools/attractor_dynamics.py +121 -0
  10. btorch/analysis/dynamic_tools/complexity.py +324 -0
  11. btorch/analysis/dynamic_tools/criticality.py +240 -0
  12. btorch/analysis/dynamic_tools/ei_balance.py +514 -0
  13. btorch/analysis/dynamic_tools/lyapunov_dynamics.py +120 -0
  14. btorch/analysis/dynamic_tools/micro_scale.py +192 -0
  15. btorch/analysis/dynamic_tools/spiking.py +1206 -0
  16. btorch/analysis/metrics.py +47 -0
  17. btorch/analysis/spiking.py +1421 -0
  18. btorch/analysis/statistics.py +954 -0
  19. btorch/analysis/tuning.py +76 -0
  20. btorch/analysis/two_compartment_fit.py +1936 -0
  21. btorch/analysis/voltage.py +56 -0
  22. btorch/config.py +43 -0
  23. btorch/connectome/__init__.py +9 -0
  24. btorch/connectome/augment.py +386 -0
  25. btorch/connectome/connection.py +1027 -0
  26. btorch/datasets/__init__.py +75 -0
  27. btorch/datasets/noise.py +1129 -0
  28. btorch/datasets/transforms.py +140 -0
  29. btorch/io/__init__.py +48 -0
  30. btorch/io/serialization.py +818 -0
  31. btorch/jit.py +22 -0
  32. btorch/models/__init__.py +43 -0
  33. btorch/models/base.py +1075 -0
  34. btorch/models/bilinear.py +78 -0
  35. btorch/models/connection_conversion.py +1204 -0
  36. btorch/models/constrain.py +14 -0
  37. btorch/models/conv.py +153 -0
  38. btorch/models/dlif.py +228 -0
  39. btorch/models/environ.py +139 -0
  40. btorch/models/functional.py +448 -0
  41. btorch/models/hex/__init__.py +16 -0
  42. btorch/models/hex/conv.py +102 -0
  43. btorch/models/hex/eye.py +376 -0
  44. btorch/models/history.py +358 -0
  45. btorch/models/init.py +150 -0
  46. btorch/models/linear.py +693 -0
  47. btorch/models/neurons/__init__.py +17 -0
  48. btorch/models/neurons/alif.py +407 -0
  49. btorch/models/neurons/glif.py +441 -0
  50. btorch/models/neurons/izhikevich.py +322 -0
  51. btorch/models/neurons/lif.py +254 -0
  52. btorch/models/neurons/mixed.py +193 -0
  53. btorch/models/neurons/two_compartment.py +407 -0
  54. btorch/models/ode.py +55 -0
  55. btorch/models/parametrize.py +262 -0
  56. btorch/models/regularizer.py +221 -0
  57. btorch/models/rnn.py +584 -0
  58. btorch/models/scale.py +295 -0
  59. btorch/models/shape.py +71 -0
  60. btorch/models/surrogate/__init__.py +26 -0
  61. btorch/models/surrogate/atan.py +117 -0
  62. btorch/models/surrogate/base.py +62 -0
  63. btorch/models/surrogate/erf.py +95 -0
  64. btorch/models/surrogate/poisson_random.py +94 -0
  65. btorch/models/surrogate/sigmoid.py +65 -0
  66. btorch/models/surrogate/superspike.py +71 -0
  67. btorch/models/surrogate/triangle.py +73 -0
  68. btorch/models/synapse.py +1075 -0
  69. btorch/py.typed +0 -0
  70. btorch/types.py +5 -0
  71. btorch/utils/__init__.py +6 -0
  72. btorch/utils/bench.py +205 -0
  73. btorch/utils/conf.py +613 -0
  74. btorch/utils/dict_utils.py +127 -0
  75. btorch/utils/file.py +160 -0
  76. btorch/utils/grad_checkpoint/__init__.py +1 -0
  77. btorch/utils/grad_checkpoint/checkpoint.py +66 -0
  78. btorch/utils/grad_checkpoint/test_checkpoint.py +143 -0
  79. btorch/utils/hdf5_utils.py +98 -0
  80. btorch/utils/hex/__init__.py +181 -0
  81. btorch/utils/hex/coords.py +152 -0
  82. btorch/utils/hex/data.py +552 -0
  83. btorch/utils/hex/distance.py +74 -0
  84. btorch/utils/hex/doubled.py +240 -0
  85. btorch/utils/hex/line.py +110 -0
  86. btorch/utils/hex/neighbor.py +121 -0
  87. btorch/utils/hex/offset.py +327 -0
  88. btorch/utils/hex/range.py +138 -0
  89. btorch/utils/hex/resolve.py +188 -0
  90. btorch/utils/hex/storage.py +399 -0
  91. btorch/utils/hex/transform.py +209 -0
  92. btorch/utils/pandas_utils.py +38 -0
  93. btorch/utils/yaml_utils.py +54 -0
  94. btorch/visualisation/__init__.py +114 -0
  95. btorch/visualisation/aggregation.py +565 -0
  96. btorch/visualisation/dynamics.py +980 -0
  97. btorch/visualisation/hex/__init__.py +42 -0
  98. btorch/visualisation/hex/animate.py +229 -0
  99. btorch/visualisation/hex/interactive.py +943 -0
  100. btorch/visualisation/hex/receptive_field.py +251 -0
  101. btorch/visualisation/hex/static.py +682 -0
  102. btorch/visualisation/network.py +50 -0
  103. btorch/visualisation/timeseries.py +2358 -0
  104. btorch/visualisation/tuning.py +142 -0
  105. btorch-0.1.0.dist-info/METADATA +416 -0
  106. btorch-0.1.0.dist-info/RECORD +108 -0
  107. btorch-0.1.0.dist-info/WHEEL +4 -0
  108. btorch-0.1.0.dist-info/licenses/LICENSE +201 -0
btorch/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ """Public package entrypoint for Btorch."""
2
+
3
+ import importlib.metadata
4
+
5
+ from btorch import config, jit
6
+
7
+
8
# Resolve the installed distribution version. When the package is imported
# from a source checkout that was never installed, no distribution metadata
# exists; fall back to a placeholder instead of making the import fail.
try:
    __version__ = importlib.metadata.version(__name__)
except importlib.metadata.PackageNotFoundError:
    __version__ = "0.0.0+unknown"


# Names exported by ``from btorch import *``.
__all__ = [
    "__version__",
    "config",
    "jit",
]
btorch/analysis/__init__.py ADDED
@@ -0,0 +1,124 @@
1
+ from .aggregation import (
2
+ agg_by_neuron,
3
+ agg_by_neuropil,
4
+ agg_conn,
5
+ build_group_frame,
6
+ group_ecdf,
7
+ group_summary,
8
+ group_values,
9
+ )
10
+ from .branching import branching_ratio
11
+ from .connectivity import HopDistanceModel, compute_ie_ratio
12
+ from .metrics import indices_to_mask, select_on_metric
13
+ from .spiking import (
14
+ compute_raster,
15
+ compute_spectrum,
16
+ cv_temporal,
17
+ fano,
18
+ fano_population,
19
+ fano_sweep,
20
+ fano_temporal,
21
+ firing_rate,
22
+ isi_cv,
23
+ isi_cv_population,
24
+ kurtosis,
25
+ kurtosis_population,
26
+ local_variation,
27
+ )
28
+ from .statistics import (
29
+ StatChoice,
30
+ compute_log_hist,
31
+ describe_array,
32
+ use_percentiles,
33
+ use_stats,
34
+ )
35
+ from .two_compartment_fit import (
36
+ DEFAULT_TWO_COMPARTMENT_FIT_STAGES,
37
+ DEFAULT_TWO_COMPARTMENT_PARAM_BOUNDS,
38
+ AllenSweepBatch,
39
+ FitEvaluation,
40
+ TwoCompartmentFitStage,
41
+ choose_current_clamp_sweeps,
42
+ detect_spikes_from_voltage,
43
+ evaluate_fit_across_sweeps,
44
+ evaluate_two_compartment_fit,
45
+ exponential_filter_spike_train,
46
+ filter_mouse_visp_l5_pyramidal_cells,
47
+ fit_two_compartment_model,
48
+ get_cell_types_cache,
49
+ load_allen_sweep,
50
+ mask_post_spike_voltage_samples,
51
+ plot_two_compartment_fit,
52
+ query_mouse_visp_l5_pyramidal_cells,
53
+ resample_trace,
54
+ rollout_two_compartment,
55
+ save_fit_report,
56
+ spike_timing_loss,
57
+ spike_timing_stats,
58
+ two_compartment_loss,
59
+ )
60
+ from .voltage import suggest_skip_timestep, voltage_overshoot
61
+
62
+
63
# Public names re-exported by this package. The order is unchanged; the
# comments only mark which sibling module each run of names is imported from.
__all__ = [
    # .aggregation
    "agg_by_neuropil",
    "agg_by_neuron",
    "agg_conn",
    "build_group_frame",
    "group_values",
    "group_summary",
    "group_ecdf",
    # .branching
    "branching_ratio",
    # .connectivity
    "HopDistanceModel",
    "compute_ie_ratio",
    # .metrics
    "indices_to_mask",
    "select_on_metric",
    # .spiking: simplified API
    "isi_cv",
    "fano",
    "kurtosis",
    "local_variation",
    # .spiking: population metrics
    "isi_cv_population",
    "fano_population",
    "kurtosis_population",
    # .spiking: temporal variants
    "cv_temporal",
    "fano_temporal",
    # .spiking: sweep functions
    "fano_sweep",
    # Utilities (from .spiking and .statistics)
    "firing_rate",
    "compute_raster",
    "compute_log_hist",
    "compute_spectrum",
    "describe_array",
    # .two_compartment_fit
    "AllenSweepBatch",
    "DEFAULT_TWO_COMPARTMENT_FIT_STAGES",
    "DEFAULT_TWO_COMPARTMENT_PARAM_BOUNDS",
    "FitEvaluation",
    "TwoCompartmentFitStage",
    "choose_current_clamp_sweeps",
    "detect_spikes_from_voltage",
    "evaluate_fit_across_sweeps",
    "evaluate_two_compartment_fit",
    "exponential_filter_spike_train",
    "filter_mouse_visp_l5_pyramidal_cells",
    "fit_two_compartment_model",
    "get_cell_types_cache",
    "load_allen_sweep",
    "mask_post_spike_voltage_samples",
    "plot_two_compartment_fit",
    "query_mouse_visp_l5_pyramidal_cells",
    "resample_trace",
    "rollout_two_compartment",
    "save_fit_report",
    "spike_timing_loss",
    "spike_timing_stats",
    "two_compartment_loss",
    # .voltage
    "suggest_skip_timestep",
    "voltage_overshoot",
    # .statistics
    "StatChoice",
    "use_stats",
    "use_percentiles",
]
btorch/analysis/aggregation.py ADDED
@@ -0,0 +1,363 @@
1
+ from collections.abc import Sequence
2
+ from typing import Literal
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import scipy.sparse
7
+ import torch
8
+
9
+ from ..types import TensorLike
10
+
11
+
12
def agg_by_neuron(
    y,
    neurons: pd.DataFrame,
    agg: Literal["mean", "sum", "std"] = "mean",
    neuron_type_column: str = "cell_type",
    **kwargs,
) -> dict:
    """Aggregate data by neuron type.

    Args:
        y: NumPy array or torch tensor; its last axis is indexed with the
            ``simple_id`` values of each group.
        neurons: DataFrame with a ``simple_id`` column and the column named
            by ``neuron_type_column``.
        agg: Name of the reduction. It is looked up on ``numpy`` when ``y``
            is an ``ndarray`` and on ``torch`` otherwise.
        neuron_type_column: Column of ``neurons`` used as the grouping key.
        **kwargs: Extra keyword arguments forwarded to
            ``DataFrame.groupby``. ``dropna`` defaults to ``True`` but may be
            overridden here.

    Returns:
        Dict mapping each group label to ``y[..., ids]`` reduced over the
        last axis.

    Note:
        ``np.std`` and ``torch.std`` use different default normalisations
        (population vs. Bessel-corrected), so ``agg="std"`` is not
        numerically identical across the two backends.
    """
    agg_func = getattr(np, agg) if isinstance(y, np.ndarray) else getattr(torch, agg)
    # Keep the historical default while letting callers pass their own
    # `dropna` instead of hitting a "multiple values for keyword" TypeError.
    kwargs.setdefault("dropna", True)
    return {
        neuron_type: agg_func(y[..., group["simple_id"].to_numpy()], -1)
        for neuron_type, group in neurons.groupby(neuron_type_column, **kwargs)
    }
27
+
28
+
29
+ def agg_by_neuropil(
30
+ y,
31
+ neurons: pd.DataFrame | None = None,
32
+ connections: pd.DataFrame | None = None,
33
+ mode: Literal["top_innervated", "all_innervated"] = "all_innervated",
34
+ agg: Literal["mean", "sum", "std"] = "mean",
35
+ use_polars: bool = False,
36
+ ):
37
+ """Aggregate activations by neuropil under a validated aggregation mode."""
38
+ agg_func = getattr(np, agg) if isinstance(y, np.ndarray) else getattr(torch, agg)
39
+ if use_polars:
40
+ try:
41
+ import polars as pl
42
+ except ImportError:
43
+ use_polars = False
44
+
45
+ if mode == "top_innervated":
46
+ assert neurons is not None, "neurons must be provided for top_innervated mode"
47
+ tmp = neurons[["group", "simple_id"]].copy()
48
+ tmp = tmp[tmp["simple_id"] < y.shape[-1]]
49
+ pre_ret: dict = {}
50
+ post_ret: dict = {}
51
+
52
+ if use_polars:
53
+ tmp["pre"] = tmp["group"].str.split(".").str[0]
54
+ tmp["post"] = tmp["group"].str.split(".").str[-1]
55
+ ptbl = pl.from_pandas(tmp)
56
+
57
+ pre_groups = ptbl.group_by("pre", maintain_order=True).agg(
58
+ pl.col("simple_id")
59
+ )
60
+ for row in pre_groups.iter_rows(named=True):
61
+ sid = np.asarray(row["simple_id"], dtype=int)
62
+ pre_ret[row["pre"]] = agg_func(y[..., sid], -1)
63
+
64
+ post_groups = ptbl.group_by("post", maintain_order=True).agg(
65
+ pl.col("simple_id")
66
+ )
67
+ for row in post_groups.iter_rows(named=True):
68
+ sid = np.asarray(row["simple_id"], dtype=int)
69
+ post_ret[row["post"]] = agg_func(y[..., sid], -1)
70
+ else:
71
+ tmp["pre"] = tmp["group"].apply(lambda x: x.split(".")[0])
72
+ tmp["post"] = tmp["group"].apply(lambda x: x.split(".")[-1])
73
+ for pre, group in tmp.groupby("pre", dropna=True):
74
+ pre_ret[pre] = agg_func(y[..., group.simple_id], -1)
75
+ for post, group in tmp.groupby("post", dropna=True):
76
+ post_ret[post] = agg_func(y[..., group.simple_id], -1)
77
+ return pre_ret, post_ret
78
+ if mode == "all_innervated":
79
+ assert (
80
+ connections is not None
81
+ ), "connections must be provided for all_innervated mode"
82
+ tmp = connections[["pre_simple_id", "post_simple_id", "neuropil"]]
83
+ tmp = tmp[
84
+ (tmp["pre_simple_id"] < y.shape[-1]) & (tmp["post_simple_id"] < y.shape[-1])
85
+ ]
86
+ pre_ret: dict = {}
87
+ post_ret: dict = {}
88
+
89
+ if use_polars:
90
+ ptbl = pl.from_pandas(tmp)
91
+ groups = ptbl.group_by("neuropil", maintain_order=True).agg(
92
+ pl.col("pre_simple_id"), pl.col("post_simple_id")
93
+ )
94
+ for row in groups.iter_rows(named=True):
95
+ neuropil = row["neuropil"]
96
+ pre_ids = np.asarray(row["pre_simple_id"], dtype=int)
97
+ post_ids = np.asarray(row["post_simple_id"], dtype=int)
98
+ pre_ret[neuropil] = agg_func(y[..., pre_ids], -1)
99
+ post_ret[neuropil] = agg_func(y[..., post_ids], -1)
100
+ else:
101
+ for neuropil, group in tmp.groupby("neuropil", dropna=True):
102
+ pre_ret[neuropil] = agg_func(y[..., group.pre_simple_id.to_numpy()], -1)
103
+ post_ret[neuropil] = agg_func(
104
+ y[..., group.post_simple_id.to_numpy()], -1
105
+ )
106
+ return pre_ret, post_ret
107
+
108
+ raise ValueError(
109
+ "Invalid `mode` for `agg_by_neuropil`. "
110
+ "Expected one of: {'top_innervated', 'all_innervated'}."
111
+ )
112
+
113
+
114
+ def agg_conn(
115
+ y,
116
+ conn: pd.DataFrame,
117
+ conn_weight: scipy.sparse.sparray | None = None,
118
+ neurons: pd.DataFrame | None = None,
119
+ mode: Literal["neuropil", "neuron"] = "neuron",
120
+ neuron_type_column: str = "cell_type",
121
+ agg: Literal["mean", "sum", "std"] = "mean",
122
+ ):
123
+ """Aggregate connectivity weights by neuropil or neuron-type pairs."""
124
+ if conn_weight is not None:
125
+ conn_weight = conn_weight.tocoo()
126
+ conn = conn.merge(
127
+ pd.DataFrame(
128
+ {
129
+ "pre_simple_id": conn_weight.row,
130
+ "post_simple_id": conn_weight.col,
131
+ "weight": conn_weight.data,
132
+ }
133
+ ),
134
+ how="left",
135
+ on=["pre_simple_id", "post_simple_id"],
136
+ )
137
+ if mode == "neuropil":
138
+ return conn.groupby("neuropil")["weight"].agg(agg)
139
+ if mode == "neuron":
140
+ assert neurons is not None, "neurons must be provided for neuron mode"
141
+ conn = conn.merge(
142
+ neurons[["simple_id", neuron_type_column]].rename(
143
+ columns={
144
+ "simple_id": "pre_simple_id",
145
+ neuron_type_column: f"pre_{neuron_type_column}",
146
+ }
147
+ ),
148
+ how="left",
149
+ on="pre_simple_id",
150
+ )
151
+ conn = conn.merge(
152
+ neurons[["simple_id", neuron_type_column]].rename(
153
+ columns={
154
+ "simple_id": "post_simple_id",
155
+ neuron_type_column: f"post_{neuron_type_column}",
156
+ }
157
+ ),
158
+ how="left",
159
+ on="post_simple_id",
160
+ )
161
+ return conn.groupby(
162
+ [f"pre_{neuron_type_column}", f"post_{neuron_type_column}"]
163
+ )["weight"].agg(agg)
164
+
165
+ raise ValueError(
166
+ "Invalid `mode` for `agg_conn`. Expected one of: {'neuropil', 'neuron'}."
167
+ )
168
+
169
+
170
def build_group_frame(
    values: TensorLike,
    neurons_df: pd.DataFrame,
    group_by: str,
    *,
    simple_id_col: str = "simple_id",
    value_name: str = "value",
    dropna: bool = True,
) -> pd.DataFrame:
    """Convert neuron-aligned values to a tidy frame for grouped analyses.

    Args:
        values: Array/tensor shaped `[N]` or `[..., N]` where the last axis is
            neuron. All leading dimensions are flattened into independent
            samples (e.g., trials, conditions, or time points).
        neurons_df: DataFrame containing at least `simple_id_col` and `group_by`.
        group_by: Column in `neurons_df` used as grouping key.
        simple_id_col: Column mapping rows in `neurons_df` to neuron index in
            `values`.
        value_name: Name for the output value column.
        dropna: Drop missing values in group/value columns when `True`.

    Returns:
        Long-format DataFrame with columns `group_by` and `value_name`. Rows
        are neuron-major: all samples of the first selected neuron, then all
        samples of the next one, and so on.

    Raises:
        ValueError: If `values` is 0-d, a required column is missing,
            `simple_id_col` has duplicates, is not numeric or is out of
            range, or (with `dropna=True`) nothing is left after filtering.
    """
    y = _to_numpy(values)
    if y.ndim < 1:
        raise ValueError("`values` must have at least one dimension.")

    for column in (simple_id_col, group_by):
        if column not in neurons_df.columns:
            raise ValueError(f"Missing `{column}` in `neurons_df`.")

    metadata = neurons_df.loc[:, [simple_id_col, group_by]].copy()
    if dropna:
        metadata = metadata.dropna(subset=[group_by])
        if metadata.empty:
            raise ValueError("No neuron metadata available after filtering.")

    if metadata[simple_id_col].duplicated().any():
        raise ValueError(f"`{simple_id_col}` must be unique in `neurons_df`.")

    try:
        simple_ids = pd.to_numeric(metadata[simple_id_col], errors="raise").to_numpy(
            dtype=np.int64
        )
    except (TypeError, ValueError, OverflowError) as exc:
        raise ValueError(f"`{simple_id_col}` must be numeric.") from exc

    n_neurons = y.shape[-1]
    out_of_range = (simple_ids < 0) | (simple_ids >= n_neurons)
    if np.any(out_of_range):
        bad_ids = simple_ids[out_of_range]
        raise ValueError(
            f"Found `{simple_id_col}` outside [0, {n_neurons - 1}]: {bad_ids.tolist()}"
        )

    selected = y[..., simple_ids]
    # Product of the leading dims: 1 for a 1-D input (empty product) and 0
    # when any leading dim is empty. Do not clamp it to 1: a zero-size array
    # cannot be reshaped to `(1, N)`, whereas `(0, N)` yields an empty frame.
    n_samples = int(np.prod(selected.shape[:-1], dtype=np.int64))

    flattened = selected.reshape(n_samples, len(simple_ids))
    group_labels = metadata[group_by].to_numpy()

    # `flattened.T` is (neuron, sample); flattening it row-major lines up
    # with each label being repeated `n_samples` times.
    frame = pd.DataFrame(
        {
            group_by: np.repeat(group_labels, n_samples),
            value_name: flattened.T.reshape(-1),
        }
    )

    if dropna:
        frame = frame.dropna(subset=[value_name])
        if frame.empty:
            raise ValueError("No values available after filtering.")

    return frame
245
+
246
+
247
def _is_missing_label(label: object) -> bool:
    """Return True when ``label`` is a scalar missing value (None/NaN/NA)."""
    try:
        return bool(pd.isna(label))
    except (TypeError, ValueError):
        # Non-scalar labels (e.g. tuples) make `pd.isna` return an array.
        return False


def group_values(
    values: TensorLike,
    neurons_df: pd.DataFrame,
    group_by: str,
    *,
    simple_id_col: str = "simple_id",
    value_name: str = "value",
    group_order: Sequence | None = None,
    dropna: bool = True,
) -> dict[object, np.ndarray]:
    """Return grouped value arrays, keyed by group label in plotting order.

    Args:
        values: Neuron-aligned array/tensor, forwarded to `build_group_frame`.
        neurons_df: Neuron metadata, forwarded to `build_group_frame`.
        group_by: Column in `neurons_df` used as grouping key.
        simple_id_col: Column mapping metadata rows to neuron indices.
        value_name: Name of the intermediate value column.
        group_order: Explicit label order; when `None` the order comes from
            `_resolve_group_order`.
        dropna: Forwarded to `build_group_frame`.

    Returns:
        Dict mapping each label to a float array of its values. With
        `dropna=False`, a missing (NaN/None) label collects the rows whose
        label is missing instead of always coming back empty.
    """
    frame = build_group_frame(
        values,
        neurons_df,
        group_by,
        simple_id_col=simple_id_col,
        value_name=value_name,
        dropna=dropna,
    )
    order = _resolve_group_order(frame, group_by, group_order)

    labels = frame[group_by]
    column = frame[value_name]
    grouped: dict[object, np.ndarray] = {}
    for group in order:
        # `labels == NaN` is False everywhere, so the missing-label group
        # has to be selected with `isna()` instead of equality.
        mask = labels.isna() if _is_missing_label(group) else labels == group
        grouped[group] = column[mask].to_numpy(dtype=float)
    return grouped
271
+
272
+
273
def group_summary(
    values: TensorLike,
    neurons_df: pd.DataFrame,
    group_by: str,
    *,
    simple_id_col: str = "simple_id",
    value_name: str = "value",
    group_order: Sequence | None = None,
    dropna: bool = True,
) -> pd.DataFrame:
    """Compute per-group summary statistics from neuron-aligned values.

    Args:
        values: Neuron-aligned array/tensor, forwarded to `group_values`.
        neurons_df: Neuron metadata, forwarded to `group_values`.
        group_by: Column in `neurons_df` used as grouping key.
        simple_id_col: Column mapping metadata rows to neuron indices.
        value_name: Name of the intermediate value column.
        group_order: Optional explicit row order of the result.
        dropna: Forwarded to `group_values`.

    Returns:
        DataFrame with one row per group and the columns `group_by`, `n`,
        `mean`, `std` (NumPy's default `np.std`), `min`, `q25`, `median`,
        `q75`, `max`. A group without any values gets `n == 0` and NaN
        statistics; the columns are present even when there are no groups.
    """
    grouped = group_values(
        values,
        neurons_df,
        group_by,
        simple_id_col=simple_id_col,
        value_name=value_name,
        group_order=group_order,
        dropna=dropna,
    )

    stat_names = ("mean", "std", "min", "q25", "median", "q75", "max")
    rows = []
    for group, vals in grouped.items():
        if vals.size == 0:
            # np.min / np.quantile raise on empty input; report NaN instead
            # of failing the whole summary because of one empty group.
            rows.append(
                {group_by: group, "n": 0, **dict.fromkeys(stat_names, float("nan"))}
            )
            continue
        rows.append(
            {
                group_by: group,
                "n": int(vals.size),
                "mean": float(np.mean(vals)),
                "std": float(np.std(vals)),
                "min": float(np.min(vals)),
                "q25": float(np.quantile(vals, 0.25)),
                "median": float(np.median(vals)),
                "q75": float(np.quantile(vals, 0.75)),
                "max": float(np.max(vals)),
            }
        )

    if not rows:
        # Keep the schema stable so callers can still select columns.
        return pd.DataFrame(columns=[group_by, "n", *stat_names])
    return pd.DataFrame(rows)
311
+
312
+
313
def group_ecdf(
    values: TensorLike,
    neurons_df: pd.DataFrame,
    group_by: str,
    *,
    simple_id_col: str = "simple_id",
    value_name: str = "value",
    group_order: Sequence | None = None,
    dropna: bool = True,
) -> dict[object, pd.DataFrame]:
    """Compute grouped ECDF points ready for plotting or analysis.

    All arguments are forwarded unchanged to `group_values`. For every group
    the result holds a DataFrame with the sorted values in the `value_name`
    column and the cumulative fraction `k / n` (for `k = 1..n`) in the
    `ecdf` column.
    """
    per_group = group_values(
        values,
        neurons_df,
        group_by,
        simple_id_col=simple_id_col,
        value_name=value_name,
        group_order=group_order,
        dropna=dropna,
    )

    curves: dict[object, pd.DataFrame] = {}
    for label, samples in per_group.items():
        ordered = np.sort(samples)
        count = len(ordered)
        # Rank of each sorted sample divided by the sample count.
        fraction = np.arange(1, count + 1, dtype=float) / count
        curves[label] = pd.DataFrame({value_name: ordered, "ecdf": fraction})
    return curves
340
+
341
+
342
+ def _resolve_group_order(
343
+ frame: pd.DataFrame,
344
+ group_by: str,
345
+ group_order: Sequence | None,
346
+ ) -> list[object]:
347
+ if group_order is None:
348
+ return list(pd.unique(frame[group_by]))
349
+
350
+ requested = list(group_order)
351
+ available = set(frame[group_by].tolist())
352
+ missing = [group for group in requested if group not in available]
353
+ if missing:
354
+ raise ValueError(f"`group_order` contains unknown groups: {missing}")
355
+ return requested
356
+
357
+
358
def _to_numpy(values: TensorLike) -> np.ndarray:
    """Return `values` as a NumPy array.

    NumPy arrays are passed through as-is; torch tensors are detached, moved
    to the CPU and converted. Any other type raises `TypeError`.
    """
    if isinstance(values, np.ndarray):
        return values
    if not isinstance(values, torch.Tensor):
        raise TypeError("`values` must be a numpy array or torch tensor.")
    return values.detach().cpu().numpy()