gengeneeval 0.1.1__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geneval/__init__.py +45 -1
- geneval/data/__init__.py +14 -0
- geneval/data/lazy_loader.py +562 -0
- geneval/evaluator.py +4 -0
- geneval/lazy_evaluator.py +424 -0
- geneval/metrics/__init__.py +19 -0
- geneval/metrics/reconstruction.py +243 -0
- {gengeneeval-0.1.1.dist-info → gengeneeval-0.2.1.dist-info}/METADATA +42 -5
- {gengeneeval-0.1.1.dist-info → gengeneeval-0.2.1.dist-info}/RECORD +12 -9
- {gengeneeval-0.1.1.dist-info → gengeneeval-0.2.1.dist-info}/WHEEL +1 -1
- {gengeneeval-0.1.1.dist-info → gengeneeval-0.2.1.dist-info}/entry_points.txt +0 -0
- {gengeneeval-0.1.1.dist-info → gengeneeval-0.2.1.dist-info}/licenses/LICENSE +0 -0
geneval/evaluator.py
CHANGED
@@ -27,11 +27,15 @@ from .metrics.distances import (
     MultivariateWasserstein,
     MultivariateMMD,
 )
+from .metrics.reconstruction import (
+    MSEDistance,
+)
 from .results import EvaluationResult, SplitResult, ConditionResult


 # Default metrics to compute
 DEFAULT_METRICS = [
+    MSEDistance,
     PearsonCorrelation,
     SpearmanCorrelation,
     MeanPearsonCorrelation,
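With this change, mean squared error is computed by default alongside the correlation metrics. For orientation, here is a minimal sketch of calling the new metric directly; the keyword arguments mirror the compute() call in lazy_evaluator.py below, while the shapes, gene names, and condition label are illustrative assumptions, not part of the diff:

    import numpy as np
    from geneval.metrics.reconstruction import MSEDistance

    # Illustrative data: rows are cells, columns are genes.
    real = np.random.rand(128, 50)
    generated = np.random.rand(96, 50)

    metric = MSEDistance()
    result = metric.compute(
        real=real,
        generated=generated,
        gene_names=[f"gene_{i}" for i in range(real.shape[1])],  # hypothetical names
        aggregate_method="mean",
        condition="example",  # illustrative condition key
    )
    print(metric.name, result.aggregate_value)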
geneval/lazy_evaluator.py
ADDED
@@ -0,0 +1,424 @@
+"""
+Memory-efficient evaluator for large-scale gene expression datasets.
+
+Uses lazy loading and batched processing to minimize memory footprint.
+"""
+from __future__ import annotations
+
+from typing import Dict, List, Optional, Union, Type, Any, Generator
+from pathlib import Path
+import numpy as np
+import warnings
+from dataclasses import dataclass, field
+import gc
+
+from .data.lazy_loader import (
+    LazyGeneExpressionDataLoader,
+    load_data_lazy,
+    ConditionBatch,
+)
+from .metrics.base_metric import BaseMetric, MetricResult
+from .metrics.correlation import (
+    PearsonCorrelation,
+    SpearmanCorrelation,
+    MeanPearsonCorrelation,
+    MeanSpearmanCorrelation,
+)
+from .metrics.distances import (
+    Wasserstein1Distance,
+    Wasserstein2Distance,
+    MMDDistance,
+    EnergyDistance,
+)
+from .metrics.reconstruction import (
+    MSEDistance,
+)
+
+# These multivariate metrics don't support batched computation
+from .metrics.distances import MultivariateWasserstein, MultivariateMMD
+
+
+# Metrics that support incremental/batched computation
+BATCHABLE_METRICS = [
+    MSEDistance,
+    PearsonCorrelation,
+    SpearmanCorrelation,
+]
+
+# Metrics that require full data
+NON_BATCHABLE_METRICS = [
+    Wasserstein1Distance,
+    Wasserstein2Distance,
+    MMDDistance,
+    EnergyDistance,
+    MultivariateWasserstein,
+    MultivariateMMD,
+]
+
+
+@dataclass
+class StreamingMetricAccumulator:
+    """Accumulates values for streaming mean/std computation."""
+    n: int = 0
+    sum: float = 0.0
+    sum_sq: float = 0.0
+
+    def add(self, value: float, count: int = 1):
+        """Add a value (or batch of values with same value)."""
+        self.n += count
+        self.sum += value * count
+        self.sum_sq += (value ** 2) * count
+
+    def add_batch(self, values: np.ndarray):
+        """Add multiple values."""
+        self.n += len(values)
+        self.sum += np.sum(values)
+        self.sum_sq += np.sum(values ** 2)
+
+    @property
+    def mean(self) -> float:
+        return self.sum / self.n if self.n > 0 else 0.0
+
+    @property
+    def std(self) -> float:
+        if self.n <= 1:
+            return 0.0
+        variance = (self.sum_sq / self.n) - (self.mean ** 2)
+        return np.sqrt(max(0, variance))
+
+
+@dataclass
+class StreamingConditionResult:
+    """Lightweight result for a single condition."""
+    condition_key: str
+    n_real_samples: int = 0
+    n_generated_samples: int = 0
+    metrics: Dict[str, float] = field(default_factory=dict)
+    real_mean: Optional[np.ndarray] = None
+    generated_mean: Optional[np.ndarray] = None
+
+
+@dataclass
+class StreamingEvaluationResult:
+    """Memory-efficient evaluation result that streams to disk."""
+    output_dir: Path
+    n_conditions: int = 0
+    metric_accumulators: Dict[str, StreamingMetricAccumulator] = field(default_factory=dict)
+    condition_keys: List[str] = field(default_factory=list)
+
+    def add_condition(self, result: StreamingConditionResult):
+        """Add a condition result and update accumulators."""
+        self.n_conditions += 1
+        self.condition_keys.append(result.condition_key)
+
+        for metric_name, value in result.metrics.items():
+            if metric_name not in self.metric_accumulators:
+                self.metric_accumulators[metric_name] = StreamingMetricAccumulator()
+            self.metric_accumulators[metric_name].add(value)
+
+    def get_summary(self) -> Dict[str, Dict[str, float]]:
+        """Get summary statistics."""
+        summary = {}
+        for name, acc in self.metric_accumulators.items():
+            summary[name] = {
+                "mean": acc.mean,
+                "std": acc.std,
+                "n": acc.n,
+            }
+        return summary
+
+    def save_summary(self):
+        """Save summary to output directory."""
+        import json
+
+        self.output_dir.mkdir(parents=True, exist_ok=True)
+
+        summary = {
+            "n_conditions": self.n_conditions,
+            "metrics": self.get_summary(),
+            "condition_keys": self.condition_keys,
+        }
+
+        with open(self.output_dir / "summary.json", "w") as f:
+            json.dump(summary, f, indent=2)
+
+
+class MemoryEfficientEvaluator:
+    """
+    Memory-efficient evaluator using lazy loading and batched processing.
+
+    Features:
+    - Lazy data loading (one condition at a time)
+    - Batched processing within conditions
+    - Streaming metric accumulation
+    - Periodic garbage collection
+    - Progress streaming to disk
+
+    Parameters
+    ----------
+    data_loader : LazyGeneExpressionDataLoader
+        Lazy data loader
+    metrics : List[BaseMetric], optional
+        Metrics to compute. Note: Some metrics (like MMD) may not support
+        batched computation and will use full condition data.
+    batch_size : int
+        Batch size for within-condition processing
+    gc_every_n_conditions : int
+        Run garbage collection every N conditions
+    verbose : bool
+        Print progress
+    """
+
+    def __init__(
+        self,
+        data_loader: LazyGeneExpressionDataLoader,
+        metrics: Optional[List[Union[BaseMetric, Type[BaseMetric]]]] = None,
+        batch_size: int = 256,
+        gc_every_n_conditions: int = 10,
+        verbose: bool = True,
+    ):
+        self.data_loader = data_loader
+        self.batch_size = batch_size
+        self.gc_every_n_conditions = gc_every_n_conditions
+        self.verbose = verbose
+
+        # Initialize metrics
+        self.metrics: List[BaseMetric] = []
+        metric_classes = metrics or [
+            MSEDistance,
+            PearsonCorrelation,
+            SpearmanCorrelation,
+            MeanPearsonCorrelation,
+            MeanSpearmanCorrelation,
+        ]
+
+        for m in metric_classes:
+            if isinstance(m, type):
+                self.metrics.append(m())
+            else:
+                self.metrics.append(m)
+
+    def _log(self, msg: str):
+        if self.verbose:
+            print(msg)
+
+    def evaluate(
+        self,
+        split: Optional[str] = None,
+        output_dir: Optional[Union[str, Path]] = None,
+        save_per_condition: bool = False,
+    ) -> StreamingEvaluationResult:
+        """
+        Run memory-efficient evaluation.
+
+        Parameters
+        ----------
+        split : str, optional
+            Split to evaluate
+        output_dir : str or Path, optional
+            Directory to save results. If provided, results are streamed to disk.
+        save_per_condition : bool
+            If True, save individual condition results to disk
+
+        Returns
+        -------
+        StreamingEvaluationResult
+            Evaluation result with aggregated metrics
+        """
+        if output_dir is not None:
+            output_dir = Path(output_dir)
+            output_dir.mkdir(parents=True, exist_ok=True)
+        else:
+            output_dir = Path(".")
+
+        result = StreamingEvaluationResult(output_dir=output_dir)
+
+        # Get conditions
+        conditions = self.data_loader.get_common_conditions(split)
+        self._log(f"Evaluating {len(conditions)} conditions")
+        self._log(f"Memory estimate: {self.data_loader.estimate_memory_usage()}")
+
+        # Iterate conditions (one at a time in memory)
+        for i, (cond_key, real_data, gen_data, cond_info) in enumerate(
+            self.data_loader.iterate_conditions(split)
+        ):
+            if self.verbose and (i + 1) % 10 == 0:
+                self._log(f"  Processing {i + 1}/{len(conditions)}: {cond_key}")
+
+            # Compute metrics for this condition
+            cond_result = self._evaluate_condition(
+                cond_key, real_data, gen_data, cond_info
+            )
+
+            # Add to streaming result
+            result.add_condition(cond_result)
+
+            # Optionally save per-condition result
+            if save_per_condition and output_dir:
+                self._save_condition_result(cond_result, output_dir)
+
+            # Periodic garbage collection
+            if (i + 1) % self.gc_every_n_conditions == 0:
+                gc.collect()
+
+        # Final summary
+        result.save_summary()
+
+        if self.verbose:
+            self._print_summary(result)
+
+        return result
+
+    def _evaluate_condition(
+        self,
+        cond_key: str,
+        real_data: np.ndarray,
+        gen_data: np.ndarray,
+        cond_info: Dict[str, str],
+    ) -> StreamingConditionResult:
+        """Evaluate a single condition."""
+        result = StreamingConditionResult(
+            condition_key=cond_key,
+            n_real_samples=real_data.shape[0],
+            n_generated_samples=gen_data.shape[0],
+        )
+
+        # Compute means
+        result.real_mean = real_data.mean(axis=0)
+        result.generated_mean = gen_data.mean(axis=0)
+
+        # Compute metrics
+        for metric in self.metrics:
+            try:
+                metric_result = metric.compute(
+                    real=real_data,
+                    generated=gen_data,
+                    gene_names=self.data_loader.gene_names,
+                    aggregate_method="mean",
+                    condition=cond_key,
+                )
+                result.metrics[metric.name] = metric_result.aggregate_value
+            except Exception as e:
+                warnings.warn(f"Failed to compute {metric.name} for {cond_key}: {e}")
+
+        return result
+
+    def _save_condition_result(
+        self,
+        result: StreamingConditionResult,
+        output_dir: Path,
+    ):
+        """Save a single condition result to disk."""
+        import json
+
+        condition_dir = output_dir / "conditions"
+        condition_dir.mkdir(exist_ok=True)
+
+        # Safe filename
+        safe_key = result.condition_key.replace("/", "_").replace("\\", "_")
+
+        data = {
+            "condition_key": result.condition_key,
+            "n_real": result.n_real_samples,
+            "n_generated": result.n_generated_samples,
+            "metrics": result.metrics,
+        }
+
+        with open(condition_dir / f"{safe_key}.json", "w") as f:
+            json.dump(data, f, indent=2)
+
+    def _print_summary(self, result: StreamingEvaluationResult):
+        """Print summary."""
+        self._log("\n" + "=" * 60)
+        self._log("EVALUATION SUMMARY (Memory-Efficient)")
+        self._log("=" * 60)
+        self._log(f"Conditions evaluated: {result.n_conditions}")
+        self._log("-" * 40)
+
+        for name, stats in result.get_summary().items():
+            self._log(f"  {name}: {stats['mean']:.4f} ± {stats['std']:.4f}")
+
+        self._log("=" * 60)
+
+
+def evaluate_lazy(
+    real_path: Union[str, Path],
+    generated_path: Union[str, Path],
+    condition_columns: List[str],
+    split_column: Optional[str] = None,
+    output_dir: Optional[Union[str, Path]] = None,
+    batch_size: int = 256,
+    use_backed: bool = False,
+    metrics: Optional[List[Union[BaseMetric, Type[BaseMetric]]]] = None,
+    verbose: bool = True,
+    save_per_condition: bool = False,
+    **kwargs
+) -> StreamingEvaluationResult:
+    """
+    Memory-efficient evaluation using lazy loading.
+
+    Use this function for large datasets that don't fit in memory.
+
+    Parameters
+    ----------
+    real_path : str or Path
+        Path to real data h5ad file
+    generated_path : str or Path
+        Path to generated data h5ad file
+    condition_columns : List[str]
+        Columns to match between datasets
+    split_column : str, optional
+        Column for train/test split
+    output_dir : str or Path, optional
+        Directory to save results
+    batch_size : int
+        Batch size for processing
+    use_backed : bool
+        Use memory-mapped file access (for very large files)
+    metrics : List, optional
+        Metrics to compute
+    verbose : bool
+        Print progress
+    save_per_condition : bool
+        Save individual condition results
+
+    Returns
+    -------
+    StreamingEvaluationResult
+        Aggregated evaluation results
+
+    Examples
+    --------
+    >>> # For large datasets that don't fit in memory
+    >>> results = evaluate_lazy(
+    ...     "real.h5ad",
+    ...     "generated.h5ad",
+    ...     condition_columns=["perturbation"],
+    ...     output_dir="eval_output/",
+    ...     batch_size=256,
+    ...     use_backed=True,  # Memory-mapped for very large files
+    ... )
+    >>> print(results.get_summary())
+    """
+    # Create lazy loader
+    with load_data_lazy(
+        real_path=real_path,
+        generated_path=generated_path,
+        condition_columns=condition_columns,
+        split_column=split_column,
+        batch_size=batch_size,
+        use_backed=use_backed,
+    ) as loader:
+        # Create evaluator
+        evaluator = MemoryEfficientEvaluator(
+            data_loader=loader,
+            metrics=metrics,
+            batch_size=batch_size,
+            verbose=verbose,
+        )
+
+        # Run evaluation
+        return evaluator.evaluate(
+            output_dir=output_dir,
+            save_per_condition=save_per_condition,
+        )
geneval/metrics/__init__.py
CHANGED
@@ -2,6 +2,7 @@
 Metrics module for gene expression evaluation.

 Provides per-gene and aggregate metrics for comparing distributions:
+- Reconstruction metrics (MSE, RMSE, MAE, R²)
 - Correlation metrics (Pearson, Spearman)
 - Distribution distances (Wasserstein, MMD, Energy)
 - Multivariate distances
@@ -27,13 +28,26 @@ from .distances import (
     MultivariateWasserstein,
     MultivariateMMD,
 )
+from .reconstruction import (
+    MSEDistance,
+    RMSEDistance,
+    MAEDistance,
+    R2Score,
+)

 # All available metrics
 ALL_METRICS = [
+    # Reconstruction
+    MSEDistance,
+    RMSEDistance,
+    MAEDistance,
+    R2Score,
+    # Correlation
     PearsonCorrelation,
     SpearmanCorrelation,
     MeanPearsonCorrelation,
     MeanSpearmanCorrelation,
+    # Distribution
     Wasserstein1Distance,
     Wasserstein2Distance,
     MMDDistance,
@@ -48,6 +62,11 @@ __all__ = [
     "MetricResult",
     "DistributionMetric",
     "CorrelationMetric",
+    # Reconstruction metrics
+    "MSEDistance",
+    "RMSEDistance",
+    "MAEDistance",
+    "R2Score",
     # Correlation metrics
     "PearsonCorrelation",
     "SpearmanCorrelation",