geneval-0.2.0-py3-none-any.whl → geneval-0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geneval/__init__.py CHANGED
@@ -8,6 +8,7 @@ Features:
 - Multiple distance and correlation metrics (per-gene and aggregate)
 - Condition-based matching (perturbation, cell type, etc.)
 - Train/test split support
+- Memory-efficient lazy loading for large datasets
 - Publication-quality visualizations
 - Command-line interface
 
@@ -20,12 +21,22 @@ Quick Start:
 ...     output_dir="output/"
 ... )
 
+Memory-Efficient Mode (for large datasets):
+>>> from geneval import evaluate_lazy
+>>> results = evaluate_lazy(
+...     real_path="real.h5ad",
+...     generated_path="generated.h5ad",
+...     condition_columns=["perturbation"],
+...     batch_size=256,
+...     use_backed=True,  # Memory-mapped access
+... )
+
 CLI Usage:
     $ geneval --real real.h5ad --generated generated.h5ad \\
         --conditions perturbation cell_type --output results/
 """
 
-__version__ = "0.2.0"
+__version__ = "0.3.0"
 __author__ = "GenEval Team"
 
 # Main evaluation interface
@@ -35,12 +46,26 @@ from .evaluator import (
     MetricRegistry,
 )
 
+# Memory-efficient evaluation
+from .lazy_evaluator import (
+    evaluate_lazy,
+    MemoryEfficientEvaluator,
+    StreamingEvaluationResult,
+)
+
 # Data loading
 from .data.loader import (
     GeneExpressionDataLoader,
     load_data,
 )
 
+# Memory-efficient data loading
+from .data.lazy_loader import (
+    LazyGeneExpressionDataLoader,
+    load_data_lazy,
+    ConditionBatch,
+)
+
 # Results
 from .results import (
     EvaluationResult,
@@ -76,6 +101,14 @@ from .metrics.reconstruction import (
     R2Score,
 )
 
+# Accelerated computation
+from .metrics.accelerated import (
+    AccelerationConfig,
+    ParallelMetricComputer,
+    get_available_backends,
+    compute_metrics_accelerated,
+)
+
 # Visualization
 from .visualization.visualizer import (
     EvaluationVisualizer,
@@ -99,9 +132,17 @@ __all__ = [
     "evaluate",
     "GeneEvalEvaluator",
     "MetricRegistry",
+    # Memory-efficient evaluation
+    "evaluate_lazy",
+    "MemoryEfficientEvaluator",
+    "StreamingEvaluationResult",
     # Data loading
     "GeneExpressionDataLoader",
     "load_data",
+    # Memory-efficient data loading
+    "LazyGeneExpressionDataLoader",
+    "load_data_lazy",
+    "ConditionBatch",
     # Results
     "EvaluationResult",
     "SplitResult",
@@ -123,6 +164,16 @@ __all__ = [
    "EnergyDistance",
    "MultivariateWasserstein",
    "MultivariateMMD",
+    # Reconstruction metrics
+    "MSEDistance",
+    "RMSEDistance",
+    "MAEDistance",
+    "R2Score",
+    # Acceleration
+    "AccelerationConfig",
+    "ParallelMetricComputer",
+    "get_available_backends",
+    "compute_metrics_accelerated",
    # Visualization
    "EvaluationVisualizer",
    "visualize",
geneval/data/__init__.py CHANGED
@@ -2,6 +2,7 @@
 Data loading module for gene expression evaluation.
 
 Provides data loaders for paired real and generated datasets.
+Includes both standard and memory-efficient lazy loading options.
 """
 
 from .loader import (
@@ -9,15 +10,28 @@ from .loader import (
     load_data,
     DataLoaderError,
 )
+from .lazy_loader import (
+    LazyGeneExpressionDataLoader,
+    load_data_lazy,
+    LazyDataLoaderError,
+    ConditionBatch,
+)
 from .gene_expression_datamodule import (
     GeneExpressionDataModule,
     DataModuleError,
 )
 
 __all__ = [
+    # Standard loader
     "GeneExpressionDataLoader",
     "load_data",
     "DataLoaderError",
+    # Lazy loader (memory-efficient)
+    "LazyGeneExpressionDataLoader",
+    "load_data_lazy",
+    "LazyDataLoaderError",
+    "ConditionBatch",
+    # DataModule
    "GeneExpressionDataModule",
    "DataModuleError",
 ]
geneval/data/lazy_loader.py ADDED
@@ -0,0 +1,562 @@
+"""
+Memory-efficient lazy data loader for large-scale gene expression datasets.
+
+Provides lazy loading and batched iteration over AnnData h5ad files without
+loading entire datasets into memory. Supports backed mode for very large files.
+"""
+from __future__ import annotations
+
+from typing import Optional, List, Union, Dict, Tuple, Generator
+from pathlib import Path
+import warnings
+import numpy as np
+import pandas as pd
+from scipy import sparse
+from dataclasses import dataclass
+
+try:
+    import anndata as ad
+    import scanpy as sc
+except ImportError as e:
+    raise ImportError(
+        "anndata and scanpy are required. Install with: pip install anndata scanpy"
+    ) from e
+
+
+class LazyDataLoaderError(Exception):
+    """Custom exception for lazy data loading errors."""
+    pass
+
+
+@dataclass
+class ConditionBatch:
+    """Container for a batch of samples from a condition."""
+    condition_key: str
+    condition_info: Dict[str, str]
+    real_data: np.ndarray
+    generated_data: np.ndarray
+    batch_idx: int
+    n_batches: int
+    is_last_batch: bool
+
+
+class LazyGeneExpressionDataLoader:
+    """
+    Memory-efficient lazy data loader for paired gene expression datasets.
+
+    Unlike GeneExpressionDataLoader, this class:
+    - Uses backed mode for h5ad files to avoid loading entire datasets
+    - Supports batched iteration over conditions
+    - Only loads data into memory when explicitly requested
+    - Provides memory usage estimates
+
+    Parameters
+    ----------
+    real_path : str or Path
+        Path to real data h5ad file
+    generated_path : str or Path
+        Path to generated data h5ad file
+    condition_columns : List[str]
+        Columns to match between datasets
+    split_column : str, optional
+        Column indicating train/test split
+    batch_size : int
+        Maximum number of samples per batch when iterating
+    use_backed : bool
+        If True, use backed mode (memory-mapped). May be slower but uses minimal memory.
+        If False, loads the full file but processes it in batches.
+    min_samples_per_condition : int
+        Minimum samples required per condition to include
+    """
+
+    def __init__(
+        self,
+        real_path: Union[str, Path],
+        generated_path: Union[str, Path],
+        condition_columns: List[str],
+        split_column: Optional[str] = None,
+        batch_size: int = 256,
+        use_backed: bool = False,
+        min_samples_per_condition: int = 2,
+    ):
+        self.real_path = Path(real_path)
+        self.generated_path = Path(generated_path)
+        self.condition_columns = condition_columns
+        self.split_column = split_column
+        self.batch_size = batch_size
+        self.use_backed = use_backed
+        self.min_samples_per_condition = min_samples_per_condition
+
+        # Lazy-loaded references (backed or full)
+        self._real: Optional[ad.AnnData] = None
+        self._generated: Optional[ad.AnnData] = None
+
+        # Metadata (always loaded - lightweight)
+        self._real_obs: Optional[pd.DataFrame] = None
+        self._generated_obs: Optional[pd.DataFrame] = None
+        self._real_var_names: Optional[pd.Index] = None
+        self._generated_var_names: Optional[pd.Index] = None
+
+        # Gene alignment info
+        self._common_genes: Optional[List[str]] = None
+        self._real_gene_idx: Optional[np.ndarray] = None
+        self._gen_gene_idx: Optional[np.ndarray] = None
+
+        # Pre-computed condition indices for fast access
+        self._condition_indices: Optional[Dict[str, Dict[str, np.ndarray]]] = None
+
+        # State
+        self._is_initialized = False
+
+    def initialize(self) -> "LazyGeneExpressionDataLoader":
+        """
+        Initialize the loader by reading metadata only (not expression data).
+
+        This loads obs DataFrames and var_names to prepare for iteration,
+        but does not load the expression matrices.
+
+        Returns
+        -------
+        self
+            For method chaining
+        """
+        if self._is_initialized:
+            return self
+
+        # Validate paths
+        if not self.real_path.exists():
+            raise LazyDataLoaderError(f"Real data file not found: {self.real_path}")
+        if not self.generated_path.exists():
+            raise LazyDataLoaderError(f"Generated data file not found: {self.generated_path}")
+
+        # Load metadata only (obs and var_names)
+        # This is much faster and lighter than loading the full data
+        self._load_metadata()
+
+        # Validate columns
+        self._validate_columns()
+
+        # Compute gene alignment indices
+        self._compute_gene_alignment()
+
+        # Pre-compute condition indices
+        self._precompute_condition_indices()
+
+        self._is_initialized = True
+        return self
+
+    def _load_metadata(self):
+        """Load only metadata (obs, var_names) without expression data."""
+        # In backed mode the files are opened but X stays on disk.
+        # If backed mode is off (or fails), fall back to a full read;
+        # batched iteration still bounds peak usage during evaluation.
+        if self.use_backed:
+            try:
+                self._real = sc.read_h5ad(self.real_path, backed='r')
+                self._generated = sc.read_h5ad(self.generated_path, backed='r')
+            except Exception as e:
+                warnings.warn(f"Backed mode failed, falling back to standard loading: {e}")
+                self._real = None
+                self._generated = None
+
+        if self._real is None:
+            # Standard (in-memory) loading
+            self._real = sc.read_h5ad(self.real_path)
+            self._generated = sc.read_h5ad(self.generated_path)
+
+        # Cache lightweight metadata
+        self._real_obs = self._real.obs.copy()
+        self._generated_obs = self._generated.obs.copy()
+        self._real_var_names = pd.Index(self._real.var_names.astype(str))
+        self._generated_var_names = pd.Index(self._generated.var_names.astype(str))
+
+    def _validate_columns(self):
+        """Validate that required columns exist."""
+        for col in self.condition_columns:
+            if col not in self._real_obs.columns:
+                raise LazyDataLoaderError(
+                    f"Condition column '{col}' not found in real data. "
+                    f"Available: {list(self._real_obs.columns)}"
+                )
+            if col not in self._generated_obs.columns:
+                raise LazyDataLoaderError(
+                    f"Condition column '{col}' not found in generated data. "
+                    f"Available: {list(self._generated_obs.columns)}"
+                )
+
+    def _compute_gene_alignment(self):
+        """Pre-compute gene alignment indices."""
+        common = self._real_var_names.intersection(self._generated_var_names)
+
+        if len(common) == 0:
+            raise LazyDataLoaderError(
+                "No overlapping genes between real and generated data."
+            )
+
+        self._common_genes = common.tolist()
+        self._real_gene_idx = self._real_var_names.get_indexer(common)
+        self._gen_gene_idx = self._generated_var_names.get_indexer(common)
+
+        n_real_only = len(self._real_var_names) - len(common)
+        n_gen_only = len(self._generated_var_names) - len(common)
+
+        if n_real_only > 0 or n_gen_only > 0:
+            warnings.warn(
+                f"Gene alignment: {len(common)} common genes. "
+                f"Dropped {n_real_only} from real, {n_gen_only} from generated."
+            )
+
+    def _get_condition_key(self, row: pd.Series) -> str:
+        """Generate a unique key for a condition combination."""
+        return "####".join([str(row[c]) for c in self.condition_columns])
+
+    def _precompute_condition_indices(self):
+        """Pre-compute sample indices for each condition (lightweight)."""
+        self._condition_indices = {"real": {}, "generated": {}}
+
+        # Real data conditions
+        real_conditions = self._real_obs[self.condition_columns].astype(str).drop_duplicates()
+        for _, row in real_conditions.iterrows():
+            key = self._get_condition_key(row)
+            mask = np.ones(len(self._real_obs), dtype=bool)
+            for col in self.condition_columns:
+                mask &= (self._real_obs[col].astype(str) == str(row[col])).values
+
+            indices = np.where(mask)[0]
+            if len(indices) >= self.min_samples_per_condition:
+                self._condition_indices["real"][key] = indices
+
+        # Generated data conditions
+        gen_conditions = self._generated_obs[self.condition_columns].astype(str).drop_duplicates()
+        for _, row in gen_conditions.iterrows():
+            key = self._get_condition_key(row)
+            mask = np.ones(len(self._generated_obs), dtype=bool)
+            for col in self.condition_columns:
+                mask &= (self._generated_obs[col].astype(str) == str(row[col])).values
+
+            indices = np.where(mask)[0]
+            if len(indices) >= self.min_samples_per_condition:
+                self._condition_indices["generated"][key] = indices
+
+    def get_splits(self) -> List[str]:
+        """Get available splits."""
+        if not self._is_initialized:
+            self.initialize()
+
+        if self.split_column is None or self.split_column not in self._real_obs.columns:
+            return ["all"]
+
+        return list(self._real_obs[self.split_column].astype(str).unique())
+
+    def get_common_conditions(self, split: Optional[str] = None) -> List[str]:
+        """Get conditions present in both real and generated data."""
+        if not self._is_initialized:
+            self.initialize()
+
+        real_keys = set(self._condition_indices["real"].keys())
+        gen_keys = set(self._condition_indices["generated"].keys())
+
+        common = real_keys & gen_keys
+
+        # Filter by split if specified
+        if split is not None and split != "all" and self.split_column is not None:
+            filtered = set()
+            for key in common:
+                real_idx = self._condition_indices["real"][key]
+                split_vals = self._real_obs.iloc[real_idx][self.split_column].astype(str)
+                if (split_vals == split).any():
+                    filtered.add(key)
+            common = filtered
+
+        return sorted(common)
+
+    def _extract_data_subset(
+        self,
+        adata: ad.AnnData,
+        indices: np.ndarray,
+        gene_idx: np.ndarray,
+    ) -> np.ndarray:
+        """Extract and align a subset of data."""
+        # Row indexing works for both backed and in-memory AnnData;
+        # in backed mode only the selected rows are read from disk.
+        X = adata.X[indices][:, gene_idx]
+
+        # Convert to dense if sparse
+        if sparse.issparse(X):
+            X = X.toarray()
+
+        return np.asarray(X, dtype=np.float32)
+
+    def iterate_conditions(
+        self,
+        split: Optional[str] = None,
+    ) -> Generator[Tuple[str, np.ndarray, np.ndarray, Dict[str, str]], None, None]:
+        """
+        Iterate over conditions, loading one condition at a time.
+
+        This loads data for one condition, yields it, then releases memory
+        before loading the next condition.
+
+        Parameters
+        ----------
+        split : str, optional
+            Filter to this split only
+
+        Yields
+        ------
+        Tuple[str, np.ndarray, np.ndarray, Dict[str, str]]
+            (condition_key, real_data, generated_data, condition_info)
+        """
+        if not self._is_initialized:
+            self.initialize()
+
+        common_conditions = self.get_common_conditions(split)
+
+        for key in common_conditions:
+            real_indices = self._condition_indices["real"][key]
+            gen_indices = self._condition_indices["generated"][key]
+
+            # Filter real samples by split if needed (the split column is only
+            # consulted on the real obs; generated samples are kept as-is)
+            if split is not None and split != "all" and self.split_column is not None:
+                split_mask = self._real_obs.iloc[real_indices][self.split_column].astype(str) == split
+                real_indices = real_indices[split_mask.values]
+
+            if len(real_indices) < self.min_samples_per_condition:
+                continue
+
+            # Load data for this condition only
+            real_data = self._extract_data_subset(
+                self._real, real_indices, self._real_gene_idx
+            )
+            gen_data = self._extract_data_subset(
+                self._generated, gen_indices, self._gen_gene_idx
+            )
+
+            # Parse condition info back out of the key
+            parts = key.split("####")
+            condition_info = dict(zip(self.condition_columns, parts))
+
+            yield key, real_data, gen_data, condition_info
+
+    def iterate_conditions_batched(
+        self,
+        split: Optional[str] = None,
+        batch_size: Optional[int] = None,
+    ) -> Generator[ConditionBatch, None, None]:
+        """
+        Iterate over conditions in batches for memory efficiency.
+
+        Useful when even a single condition is too large to fit in memory.
+
+        Parameters
+        ----------
+        split : str, optional
+            Filter to this split only
+        batch_size : int, optional
+            Override the default batch size
+
+        Yields
+        ------
+        ConditionBatch
+            Batch of samples from a condition
+        """
+        if not self._is_initialized:
+            self.initialize()
+
+        batch_size = batch_size or self.batch_size
+        common_conditions = self.get_common_conditions(split)
+
+        for key in common_conditions:
+            real_indices = self._condition_indices["real"][key]
+            gen_indices = self._condition_indices["generated"][key]
+
+            # Filter by split
+            if split is not None and split != "all" and self.split_column is not None:
+                split_mask = self._real_obs.iloc[real_indices][self.split_column].astype(str) == split
+                real_indices = real_indices[split_mask.values]
+
+            if len(real_indices) < self.min_samples_per_condition:
+                continue
+
+            # Parse condition info
+            parts = key.split("####")
+            condition_info = dict(zip(self.condition_columns, parts))
+
+            # Calculate number of batches (use max of real/gen for alignment)
+            n_real = len(real_indices)
+            n_gen = len(gen_indices)
+            n_batches = max(
+                (n_real + batch_size - 1) // batch_size,
+                (n_gen + batch_size - 1) // batch_size
+            )
+
+            for batch_idx in range(n_batches):
+                start_real = batch_idx * batch_size
+                end_real = min(start_real + batch_size, n_real)
+                start_gen = batch_idx * batch_size
+                end_gen = min(start_gen + batch_size, n_gen)
+
+                # Handle the case where one dataset is smaller: wrap around to a
+                # contiguous slice (modulo the dataset size) so the batch stays
+                # sorted, duplicate-free, and non-empty.
+                if start_real >= n_real:
+                    wrap_start = start_real % n_real
+                    batch_real_idx = real_indices[wrap_start:wrap_start + batch_size]
+                else:
+                    batch_real_idx = real_indices[start_real:end_real]
+
+                if start_gen >= n_gen:
+                    wrap_start = start_gen % n_gen
+                    batch_gen_idx = gen_indices[wrap_start:wrap_start + batch_size]
+                else:
+                    batch_gen_idx = gen_indices[start_gen:end_gen]
+
+                if len(batch_real_idx) == 0 or len(batch_gen_idx) == 0:
+                    continue
+
+                real_data = self._extract_data_subset(
+                    self._real, batch_real_idx, self._real_gene_idx
+                )
+                gen_data = self._extract_data_subset(
+                    self._generated, batch_gen_idx, self._gen_gene_idx
+                )
+
+                yield ConditionBatch(
+                    condition_key=key,
+                    condition_info=condition_info,
+                    real_data=real_data,
+                    generated_data=gen_data,
+                    batch_idx=batch_idx,
+                    n_batches=n_batches,
+                    is_last_batch=(batch_idx == n_batches - 1),
+                )
+
+    @property
+    def gene_names(self) -> List[str]:
+        """Get common gene names."""
+        if not self._is_initialized:
+            self.initialize()
+        return self._common_genes
+
+    @property
+    def n_genes(self) -> int:
+        """Number of common genes."""
+        return len(self.gene_names)
+
+    def estimate_memory_usage(self) -> Dict[str, float]:
+        """
+        Estimate memory usage in MB for different loading strategies.
+
+        Returns
+        -------
+        Dict[str, float]
+            Memory estimates in MB
+        """
+        if not self._is_initialized:
+            self.initialize()
+
+        n_real = len(self._real_obs)
+        n_gen = len(self._generated_obs)
+        n_genes = self.n_genes
+
+        # 4 bytes per float32 element
+        bytes_per_element = 4
+
+        full_real = n_real * n_genes * bytes_per_element / 1e6
+        full_gen = n_gen * n_genes * bytes_per_element / 1e6
+
+        # Average condition size across both datasets
+        n_conditions = len(self.get_common_conditions())
+        avg_per_condition = (n_real + n_gen) / max(n_conditions, 1)
+        per_condition = avg_per_condition * n_genes * bytes_per_element / 1e6
+
+        # A batch holds one real block and one generated block
+        per_batch = self.batch_size * 2 * n_genes * bytes_per_element / 1e6
+
+        return {
+            "full_load_mb": full_real + full_gen,
+            "per_condition_mb": per_condition,
+            "per_batch_mb": per_batch,
+            "metadata_mb": (n_real + n_gen) * 100 / 1e6,  # assumes ~100 bytes per obs row
+        }
+
+    def close(self):
+        """Close backed file handles, if any."""
+        if self._real is not None and hasattr(self._real, 'file'):
+            try:
+                self._real.file.close()
+            except Exception:
+                pass
+        if self._generated is not None and hasattr(self._generated, 'file'):
+            try:
+                self._generated.file.close()
+            except Exception:
+                pass
+
+    def __enter__(self):
+        """Context manager entry."""
+        self.initialize()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Context manager exit."""
+        self.close()
+
+    def __repr__(self) -> str:
+        if not self._is_initialized:
+            return f"LazyGeneExpressionDataLoader(not initialized, backed={self.use_backed})"
+
+        return (
+            f"LazyGeneExpressionDataLoader("
+            f"real={len(self._real_obs)}x{len(self._real_var_names)}, "
+            f"gen={len(self._generated_obs)}x{len(self._generated_var_names)}, "
+            f"common_genes={self.n_genes}, "
+            f"batch_size={self.batch_size})"
+        )
+
+
+def load_data_lazy(
+    real_path: Union[str, Path],
+    generated_path: Union[str, Path],
+    condition_columns: List[str],
+    split_column: Optional[str] = None,
+    batch_size: int = 256,
+    use_backed: bool = False,
+    **kwargs
+) -> LazyGeneExpressionDataLoader:
+    """
+    Convenience function to create a lazy data loader.
+
+    Parameters
+    ----------
+    real_path : str or Path
+        Path to real data h5ad file
+    generated_path : str or Path
+        Path to generated data h5ad file
+    condition_columns : List[str]
+        Columns to match between datasets
+    split_column : str, optional
+        Column indicating train/test split
+    batch_size : int
+        Maximum samples per batch
+    use_backed : bool
+        Use memory-mapped access for very large files
+    **kwargs
+        Additional arguments for LazyGeneExpressionDataLoader
+
+    Returns
+    -------
+    LazyGeneExpressionDataLoader
+        Initialized lazy data loader
+    """
+    loader = LazyGeneExpressionDataLoader(
+        real_path=real_path,
+        generated_path=generated_path,
+        condition_columns=condition_columns,
+        split_column=split_column,
+        batch_size=batch_size,
+        use_backed=use_backed,
+        **kwargs
+    )
+    loader.initialize()
+    return loader
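
As a sanity check of the arithmetic in estimate_memory_usage: for a hypothetical pair of datasets with 100,000 cells each and 20,000 common genes, a full float32 load costs 100,000 × 20,000 × 4 / 1e6 = 8,000 MB per matrix (16,000 MB for the pair); one average condition across 100 conditions costs (100,000 + 100,000) / 100 × 20,000 × 4 / 1e6 = 160 MB; and one 256-sample batch over both datasets costs 256 × 2 × 20,000 × 4 / 1e6 ≈ 41 MB. That gap is what the batched iterator is designed to exploit.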
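
A short usage sketch against the API added above, to make the iteration contract concrete; the file paths, column names, and the 500 MB threshold are placeholders for illustration:

    import numpy as np
    from geneval.data.lazy_loader import load_data_lazy

    with load_data_lazy(
        real_path="real.h5ad",              # placeholder paths
        generated_path="generated.h5ad",
        condition_columns=["perturbation"],
        split_column="split",               # placeholder column name
        batch_size=256,
        use_backed=True,
    ) as loader:
        est = loader.estimate_memory_usage()
        if est["per_condition_mb"] < 500:   # arbitrary threshold
            # Whole conditions at a time; arrays are freed between iterations
            for key, real, gen, info in loader.iterate_conditions(split="test"):
                print(key, info, real.shape, gen.shape)
        else:
            # Fixed-size batches; aggregate per condition via is_last_batch
            sums = {}
            for batch in loader.iterate_conditions_batched(split="test"):
                s, n = sums.get(batch.condition_key, (0.0, 0))
                # Mean-expression gap per gene; robust to unequal batch sizes
                gap = float(np.abs(batch.real_data.mean(axis=0)
                                   - batch.generated_data.mean(axis=0)).mean())
                sums[batch.condition_key] = (s + gap, n + 1)
                if batch.is_last_batch:
                    s, n = sums.pop(batch.condition_key)
                    print(batch.condition_key, "mean abs gap:", s / n)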