gengeneeval 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geneval/__init__.py +129 -0
- geneval/cli.py +333 -0
- geneval/config.py +141 -0
- geneval/core.py +41 -0
- geneval/data/__init__.py +23 -0
- geneval/data/gene_expression_datamodule.py +211 -0
- geneval/data/loader.py +437 -0
- geneval/evaluator.py +359 -0
- geneval/evaluators/__init__.py +4 -0
- geneval/evaluators/base_evaluator.py +178 -0
- geneval/evaluators/gene_expression_evaluator.py +218 -0
- geneval/metrics/__init__.py +65 -0
- geneval/metrics/base_metric.py +229 -0
- geneval/metrics/correlation.py +232 -0
- geneval/metrics/distances.py +516 -0
- geneval/metrics/metrics.py +134 -0
- geneval/models/__init__.py +1 -0
- geneval/models/base_model.py +53 -0
- geneval/results.py +334 -0
- geneval/testing.py +393 -0
- geneval/utils/__init__.py +1 -0
- geneval/utils/io.py +27 -0
- geneval/utils/preprocessing.py +82 -0
- geneval/visualization/__init__.py +38 -0
- geneval/visualization/plots.py +499 -0
- geneval/visualization/visualizer.py +1096 -0
- gengeneeval-0.1.0.dist-info/METADATA +172 -0
- gengeneeval-0.1.0.dist-info/RECORD +31 -0
- gengeneeval-0.1.0.dist-info/WHEEL +4 -0
- gengeneeval-0.1.0.dist-info/entry_points.txt +3 -0
- gengeneeval-0.1.0.dist-info/licenses/LICENSE +9 -0
geneval/evaluator.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Comprehensive evaluator for gene expression data.
|
|
3
|
+
|
|
4
|
+
Computes all metrics between real and generated data, organized by conditions and splits.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import Dict, List, Optional, Union, Type, Any
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
import numpy as np
|
|
11
|
+
import warnings
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
|
|
14
|
+
from .data.loader import GeneExpressionDataLoader, load_data
|
|
15
|
+
from .metrics.base_metric import BaseMetric, MetricResult
|
|
16
|
+
from .metrics.correlation import (
|
|
17
|
+
PearsonCorrelation,
|
|
18
|
+
SpearmanCorrelation,
|
|
19
|
+
MeanPearsonCorrelation,
|
|
20
|
+
MeanSpearmanCorrelation,
|
|
21
|
+
)
|
|
22
|
+
from .metrics.distances import (
|
|
23
|
+
Wasserstein1Distance,
|
|
24
|
+
Wasserstein2Distance,
|
|
25
|
+
MMDDistance,
|
|
26
|
+
EnergyDistance,
|
|
27
|
+
MultivariateWasserstein,
|
|
28
|
+
MultivariateMMD,
|
|
29
|
+
)
|
|
30
|
+
from .results import EvaluationResult, SplitResult, ConditionResult
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Default metrics to compute
# Correlation and distribution-distance metric *classes* (not instances);
# GeneEvalEvaluator and MetricRegistry instantiate them lazily.
# Multivariate metrics are intentionally excluded here — they are added
# separately when requested.
DEFAULT_METRICS = [
    PearsonCorrelation,
    SpearmanCorrelation,
    MeanPearsonCorrelation,
    MeanSpearmanCorrelation,
    Wasserstein1Distance,
    Wasserstein2Distance,
    MMDDistance,
    EnergyDistance,
]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class GeneEvalEvaluator:
    """
    Main evaluator class for gene expression data.

    Computes comprehensive metrics between real and generated datasets,
    supporting multiple conditions, splits, and metric types.

    Parameters
    ----------
    data_loader : GeneExpressionDataLoader
        Loaded and aligned data loader
    metrics : List[BaseMetric or Type[BaseMetric]], optional
        Metrics to compute. If None, uses default set. An explicit empty
        list is honored (no per-gene metrics are added).
    aggregate_method : str
        How to aggregate per-gene values (mean, median, etc.)
    include_multivariate : bool
        Whether to include multivariate (whole-space) metrics
    verbose : bool
        Whether to print progress

    Examples
    --------
    >>> loader = load_data("real.h5ad", "generated.h5ad", ["perturbation"])
    >>> evaluator = GeneEvalEvaluator(loader)
    >>> results = evaluator.evaluate()
    >>> results.save("output/")
    """

    def __init__(
        self,
        data_loader: GeneExpressionDataLoader,
        metrics: Optional[List[Union[BaseMetric, Type[BaseMetric]]]] = None,
        aggregate_method: str = "mean",
        include_multivariate: bool = True,
        verbose: bool = True,
    ):
        self.data_loader = data_loader
        self.aggregate_method = aggregate_method
        self.include_multivariate = include_multivariate
        self.verbose = verbose

        # Initialize metrics.
        # BUGFIX: check `is not None` instead of truthiness so that a caller
        # passing metrics=[] (e.g. to run only the multivariate metrics) does
        # not silently fall back to DEFAULT_METRICS.
        self.metrics: List[BaseMetric] = []
        metric_classes = metrics if metrics is not None else DEFAULT_METRICS

        for m in metric_classes:
            if isinstance(m, type):
                # It's a class, instantiate it
                self.metrics.append(m())
            else:
                # It's already an instance
                self.metrics.append(m)

        # Add multivariate metrics if requested
        if include_multivariate:
            self.metrics.extend([
                MultivariateWasserstein(),
                MultivariateMMD(),
            ])

    def _log(self, msg: str):
        """Print message if verbose."""
        if self.verbose:
            print(msg)

    def evaluate(
        self,
        splits: Optional[List[str]] = None,
        save_dir: Optional[Union[str, Path]] = None,
    ) -> EvaluationResult:
        """
        Run full evaluation on all conditions and splits.

        Parameters
        ----------
        splits : List[str], optional
            Splits to evaluate. If None, evaluates all available splits.
            Unknown split names are dropped with a warning.
        save_dir : str or Path, optional
            If provided, save results to this directory

        Returns
        -------
        EvaluationResult
            Complete evaluation results
        """
        # Get available splits
        available_splits = self.data_loader.get_splits()

        if splits is None:
            splits = available_splits
        else:
            # Validate requested splits; keep only the ones the loader knows
            invalid = set(splits) - set(available_splits)
            if invalid:
                warnings.warn(f"Requested splits not found: {invalid}")
            splits = [s for s in splits if s in available_splits]

        self._log(f"Evaluating {len(splits)} splits: {splits}")
        self._log(f"Using {len(self.metrics)} metrics: {[m.name for m in self.metrics]}")

        # Create result container; metadata records provenance so saved
        # results are self-describing.
        result = EvaluationResult(
            gene_names=self.data_loader.gene_names,
            condition_columns=self.data_loader.condition_columns,
            metadata={
                "real_path": str(self.data_loader.real_path),
                "generated_path": str(self.data_loader.generated_path),
                "aggregate_method": self.aggregate_method,
                "metric_names": [m.name for m in self.metrics],
            }
        )

        # Evaluate each split
        for split in splits:
            # "all" presumably means "no split filtering" for the loader
            # (a None filter) — TODO confirm against iterate_conditions.
            split_key = split if split != "all" else None
            split_result = self._evaluate_split(split, split_key)
            result.add_split(split_result)

        # Compute aggregate metrics
        for split_result in result.splits.values():
            split_result.compute_aggregates()

        # Print summary
        if self.verbose:
            self._print_summary(result)

        # Save if requested
        if save_dir is not None:
            result.save(save_dir)
            self._log(f"Results saved to: {save_dir}")

        return result

    def _evaluate_split(
        self,
        split_name: str,
        split_filter: Optional[str]
    ) -> SplitResult:
        """Evaluate a single split: compute every metric for every condition."""
        split_result = SplitResult(split_name=split_name)

        conditions = list(self.data_loader.iterate_conditions(split_filter))
        self._log(f"\n Split '{split_name}': {len(conditions)} conditions")

        for i, (cond_key, real_data, gen_data, cond_info) in enumerate(conditions):
            # Progress heartbeat every 10 conditions
            if self.verbose and (i + 1) % 10 == 0:
                self._log(f"  Processing condition {i + 1}/{len(conditions)}")

            # Create condition result
            cond_result = ConditionResult(
                condition_key=cond_key,
                split=split_name,
                n_real_samples=real_data.shape[0],
                n_generated_samples=gen_data.shape[0],
                n_genes=real_data.shape[1],
                gene_names=self.data_loader.gene_names,
                # First condition column is treated as the perturbation label
                perturbation=cond_info.get(self.data_loader.condition_columns[0]),
                covariates=cond_info,
            )

            # Store mean profiles
            cond_result.real_mean = real_data.mean(axis=0)
            cond_result.generated_mean = gen_data.mean(axis=0)

            # Compute all metrics; a single failing metric must not abort
            # the whole evaluation, so failures become warnings.
            for metric in self.metrics:
                try:
                    metric_result = metric.compute(
                        real=real_data,
                        generated=gen_data,
                        gene_names=self.data_loader.gene_names,
                        aggregate_method=self.aggregate_method,
                        condition=cond_key,
                        split=split_name,
                    )
                    cond_result.add_metric(metric.name, metric_result)
                except Exception as e:
                    warnings.warn(
                        f"Failed to compute {metric.name} for {cond_key}: {e}"
                    )

            split_result.add_condition(cond_result)

        return split_result

    def _print_summary(self, result: EvaluationResult):
        """Print summary of results."""
        self._log("\n" + "=" * 60)
        self._log("EVALUATION SUMMARY")
        self._log("=" * 60)

        for split_name, split in result.splits.items():
            self._log(f"\nSplit: {split_name} ({split.n_conditions} conditions)")
            self._log("-" * 40)

            # Print aggregate metrics: each "<name>_mean" entry is paired
            # with its "<name>_std" counterpart (0 if missing).
            for key, value in sorted(split.aggregate_metrics.items()):
                if key.endswith("_mean"):
                    metric_name = key[:-5]
                    std_key = f"{metric_name}_std"
                    std = split.aggregate_metrics.get(std_key, 0)
                    self._log(f"  {metric_name}: {value:.4f} ± {std:.4f}")

        self._log("=" * 60)
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def evaluate(
    real_path: Union[str, Path],
    generated_path: Union[str, Path],
    condition_columns: List[str],
    split_column: Optional[str] = None,
    output_dir: Optional[Union[str, Path]] = None,
    metrics: Optional[List[Union[BaseMetric, Type[BaseMetric]]]] = None,
    include_multivariate: bool = True,
    verbose: bool = True,
    **loader_kwargs
) -> EvaluationResult:
    """
    One-call pipeline: load both datasets, build an evaluator, run it.

    Parameters
    ----------
    real_path : str or Path
        Path to real data h5ad file
    generated_path : str or Path
        Path to generated data h5ad file
    condition_columns : List[str]
        Columns to match between datasets
    split_column : str, optional
        Column indicating train/test split
    output_dir : str or Path, optional
        Directory to save results
    metrics : List, optional
        Metrics to compute
    include_multivariate : bool
        Whether to include multivariate metrics
    verbose : bool
        Print progress
    **loader_kwargs
        Additional arguments for data loader

    Returns
    -------
    EvaluationResult
        Complete evaluation results

    Examples
    --------
    >>> results = evaluate(
    ...     "real.h5ad",
    ...     "generated.h5ad",
    ...     condition_columns=["perturbation", "cell_type"],
    ...     split_column="split",
    ...     output_dir="evaluation_output/"
    ... )
    """
    # Step 1: load and align the two datasets.
    aligned_loader = load_data(
        real_path=real_path,
        generated_path=generated_path,
        condition_columns=condition_columns,
        split_column=split_column,
        **loader_kwargs
    )

    # Step 2: configure the evaluator and run it, saving if requested.
    runner = GeneEvalEvaluator(
        data_loader=aligned_loader,
        metrics=metrics,
        include_multivariate=include_multivariate,
        verbose=verbose,
    )
    return runner.evaluate(save_dir=output_dir)
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
class MetricRegistry:
    """
    Name-indexed registry of metric classes.

    Custom metrics can be registered at runtime and later looked up by the
    name their instances report.
    """

    # Shared mapping: metric name -> metric class.
    _metrics: Dict[str, Type[BaseMetric]] = {}

    @classmethod
    def register(cls, metric_class: Type[BaseMetric]):
        """Register *metric_class* under the name of a fresh instance."""
        cls._metrics[metric_class().name] = metric_class

    @classmethod
    def get(cls, name: str) -> Optional[Type[BaseMetric]]:
        """Return the class registered under *name*, or None if unknown."""
        return cls._metrics.get(name)

    @classmethod
    def list_all(cls) -> List[str]:
        """Return the names of every registered metric."""
        return [*cls._metrics]

    @classmethod
    def get_all(cls) -> List[Type[BaseMetric]]:
        """Return every registered metric class."""
        return [*cls._metrics.values()]
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
# Register default metrics
# Module-import side effect: populate the registry so CLI/config code can
# resolve metrics by name.
for metric_class in DEFAULT_METRICS:
    MetricRegistry.register(metric_class)

# Multivariate metrics are not part of DEFAULT_METRICS but should still be
# discoverable by name through the registry.
MetricRegistry.register(MultivariateWasserstein)
MetricRegistry.register(MultivariateMMD)
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from typing import Dict, Iterable, List, Optional, Tuple
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pandas as pd
|
|
8
|
+
from anndata import AnnData
|
|
9
|
+
from scipy import sparse
|
|
10
|
+
|
|
11
|
+
from ..utils.preprocessing import to_dense
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class BaseEvaluator(ABC):
    """
    Base class for evaluation of generated data against real datasets.

    Provides:
    - Variable/gene alignment between real and generated AnnData objects
    - Computation and application of control baselines per strata
    """

    def __init__(self, data, output: AnnData):
        """
        Parameters
        ----------
        data : object
            An object providing at least:
            - gene_expression_dataset.adata: AnnData
            - perturbation_key: str
            - split_key: str
            - control: str
            - condition_keys: Optional[List[str]]
        output : AnnData
            Generated data to evaluate.
        """
        self.data = data
        self.output = output

    # ---------- alignment utilities ----------

    def _align_varnames_like(self, real: AnnData, generated: AnnData) -> Tuple[AnnData, AnnData]:
        """
        Align real and generated AnnData to the common set of var_names (genes),
        preserving order based on the real AnnData.

        Raises
        ------
        ValueError
            If the two objects share no gene names.
        """
        real_genes = pd.Index(real.var_names.astype(str))
        gen_genes = pd.Index(generated.var_names.astype(str))
        common = real_genes.intersection(gen_genes)
        if len(common) == 0:
            raise ValueError("No overlapping genes between real and generated AnnData.")

        # Reindex both adatas to the common genes in the order of real
        real = real[:, real_genes.get_indexer(common)].copy()
        generated = generated[:, generated.var_names.astype(str).isin(common)].copy()
        # Reorder generated to match real
        generated = generated[:, pd.Index(generated.var_names.astype(str)).get_indexer(common)].copy()

        real.var_names = common
        generated.var_names = common
        return real, generated

    # ---------- baseline utilities ----------

    @staticmethod
    def _key_from_values(values: Iterable[object]) -> str:
        # stable string key for strata-tuples; "####" is an unlikely
        # substring of real metadata values, and an empty iterable maps to ""
        return "####".join([str(v) for v in values])

    def _compute_control_means(
        self,
        adata: AnnData,
        perturbation_col: str,
        control_value: str,
        strata_cols: Optional[List[str]] = None,
    ) -> Dict[str, np.ndarray]:
        """
        Compute per-strata control means across genes.

        Returns a dict mapping a strata-key -> mean vector (n_genes,).
        Returns an empty dict when no control observations exist.

        Raises
        ------
        KeyError
            If *perturbation_col* is missing from ``adata.obs``.
        """
        strata_cols = strata_cols or []
        obs = adata.obs

        if perturbation_col not in obs.columns:
            raise KeyError(f"'{perturbation_col}' not found in adata.obs.")

        # Compare as strings so categorical/object dtypes behave the same
        is_control = (obs[perturbation_col].astype(str) == str(control_value)).to_numpy()
        if not is_control.any():
            # no controls; return empty means map
            return {}

        ctrl = adata[is_control]
        if not strata_cols:
            # Single global baseline keyed by the empty strata-key ("")
            return {self._key_from_values([]): to_dense(ctrl.X).mean(axis=0)}

        # group by strata columns (as strings to be robust)
        df = ctrl.obs[strata_cols].astype(str)
        means: Dict[str, np.ndarray] = {}
        # compute mean per unique strata combination
        for _, row in df.drop_duplicates().iterrows():
            mask = np.ones(ctrl.n_obs, dtype=bool)
            for c in strata_cols:
                mask &= (df[c].to_numpy() == str(row[c]))
            if not mask.any():
                continue
            key = self._key_from_values([row[c] for c in strata_cols])
            means[key] = to_dense(ctrl.X[mask]).mean(axis=0)
        return means

    def _apply_baseline_per_strata(
        self,
        X,
        obs: pd.DataFrame,
        baseline: Dict[str, np.ndarray],
        strata_cols: Optional[List[str]] = None,
        mode: str = "subtract",
    ):
        """
        Apply per-strata baseline vectors to rows in X based on obs[strata_cols].

        mode: 'subtract' or 'add'

        Notes
        -----
        The input is copied, never mutated. When X is sparse and no strata
        columns are given, the result is returned as a dense array.
        Baseline keys that do not decode to len(strata_cols) values, or that
        match no rows, are skipped silently.
        """
        strata_cols = strata_cols or []
        if mode not in ("subtract", "add"):
            raise ValueError("mode must be 'subtract' or 'add'.")

        # Work on a copy so the caller's matrix is untouched.
        # (A dead `to_dense_first` flag set here in an earlier version was
        # never read; removed.)
        if sparse.issparse(X):
            X = X.tocsr(copy=True)
        else:
            X = np.array(X, copy=True)

        if not strata_cols:
            key = self._key_from_values([])
            b = baseline.get(key, None)
            if b is None:
                return X
            if sparse.issparse(X):
                # operate dense for simplicity
                X = X.toarray()
            if mode == "subtract":
                X -= b
            else:
                X += b
            return X

        # Apply per group
        df = obs[strata_cols].astype(str)
        # iterate groups in baseline for efficiency
        for key, b in baseline.items():
            # decode key into tuple of values
            parts = key.split("####") if key else []
            if len(parts) != len(strata_cols):
                # skip mismatched key
                continue
            mask = np.ones(df.shape[0], dtype=bool)
            for col, val in zip(strata_cols, parts):
                mask &= (df[col].to_numpy() == val)
            if not mask.any():
                continue

            if sparse.issparse(X):
                # operate in dense block then write back
                block = X[mask].toarray()
                if mode == "subtract":
                    block -= b
                else:
                    block += b
                X[mask] = sparse.csr_matrix(block)
            else:
                if mode == "subtract":
                    X[mask] -= b
                else:
                    X[mask] += b

        return X
|