gsMap3D 0.1.0a1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/setup.py ADDED
@@ -0,0 +1,5 @@
+ #!/usr/bin/env python
+ import setuptools
+
+ if __name__ == "__main__":
+     setuptools.setup(name="gsMap")
gsMap/spatial_ldsc/io.py ADDED
@@ -0,0 +1,656 @@
+ import glob
+ import logging
+ import os
+ from pathlib import Path
+
+ import anndata as ad
+ import numpy as np
+ import pandas as pd
+ import psutil
+ import pyarrow.feather as feather
+ from statsmodels.stats.multitest import multipletests
+
+ from gsMap.config import SpatialLDSCConfig
+ from gsMap.latent2gene.memmap_io import MemMapDense
+
+ logger = logging.getLogger("gsMap.spatial_ldsc.io")
+
+
+ def _read_sumstats(fh, alleles=False, dropna=False):
+     """Parse GWAS summary statistics."""
+     fh = str(fh)
+     logger.info(f"Reading summary statistics from {fh} ...")
+
+     # Determine compression type
+     compression = None
+     if fh.endswith("gz"):
+         compression = "gzip"
+     elif fh.endswith("bz2"):
+         compression = "bz2"
+
+     # Define columns and dtypes
+     dtype_dict = {"SNP": str, "Z": float, "N": float, "A1": str, "A2": str}
+     usecols = ["SNP", "Z", "N"]
+     if alleles:
+         usecols += ["A1", "A2"]
+
+     # Read the file
+     try:
+         sumstats = pd.read_csv(
+             fh,
+             sep=r"\s+",
+             na_values=".",
+             usecols=usecols,
+             dtype=dtype_dict,
+             compression=compression,
+         )
+     except (AttributeError, ValueError) as e:
+         logger.error(f"Failed to parse sumstats file: {str(e.args)}")
+         raise ValueError("Improperly formatted sumstats file: " + str(e.args)) from e
+
+     # Drop NA values if specified
+     if dropna:
+         sumstats = sumstats.dropna(how="any")
+
+     logger.info(f"Read summary statistics for {len(sumstats)} SNPs.")
+
+     # Drop duplicates
+     m = len(sumstats)
+     sumstats = sumstats.drop_duplicates(subset="SNP")
+     if m > len(sumstats):
+         logger.info(f"Dropped {m - len(sumstats)} SNPs with duplicated rs numbers.")
+
+     return sumstats
+
+
+ def _read_chr_files(base_path, suffix, expected_count=22):
+     """Read chromosome files using glob pattern matching."""
+     # Create the pattern to search for files
+     file_pattern = f"{base_path}[1-9]*{suffix}*"
+
+     # Find all matching files
+     all_files = glob.glob(file_pattern)
+
+     # Extract chromosome numbers
+     chr_files = []
+     for file in all_files:
+         try:
+             # Extract the chromosome number from filename
+             file_name = os.path.basename(file)
+             base_name = os.path.basename(base_path)
+             chr_part = file_name.replace(base_name, "").split(suffix)[0]
+             chr_num = int(chr_part)
+             if 1 <= chr_num <= expected_count:
+                 chr_files.append((chr_num, file))
+         except (ValueError, IndexError):
+             continue
+
+     # Check if we have the expected number of chromosome files
+     if len(chr_files) != expected_count:
+         logger.warning(
+             f"❗ SEVERE WARNING ❗ Expected {expected_count} chromosome files, but found {len(chr_files)}! "
+             f"⚠️ For human GWAS data, all 22 autosomes must be present. Please verify your input files."
+         )
+
+     # Sort by chromosome number and return file paths
+     chr_files.sort()
+     return [file for _, file in chr_files]
+
+
+ def _read_file(file_path):
+     """Read a file based on its format/extension."""
+     file_path = str(file_path)
+     try:
+         if file_path.endswith(".feather"):
+             return pd.read_feather(file_path)
+         elif file_path.endswith(".parquet"):
+             return pd.read_parquet(file_path)
+         elif file_path.endswith(".gz"):
+             return pd.read_csv(file_path, compression="gzip", sep="\t")
+         elif file_path.endswith(".bz2"):
+             return pd.read_csv(file_path, compression="bz2", sep="\t")
+         else:
+             return pd.read_csv(file_path, sep="\t")
+     except Exception as e:
+         logger.error(f"Failed to read file {file_path}: {str(e)}")
+         raise
+
+
+ def _read_ref_ld_v2(ld_file):
+     """Read reference LD scores for all chromosomes."""
+     suffix = ".l2.ldscore"
+     logger.info(f"Reading LD score annotations from {ld_file}[1-22]{suffix}...")
+
+     # Get the chromosome files
+     chr_files = _read_chr_files(ld_file, suffix)
+
+     # Read and concatenate all files
+     df_list = [_read_file(file) for file in chr_files]
+
+     if not df_list:
+         logger.error(f"No LD score files found matching pattern: {ld_file}*{suffix}*")
+         raise FileNotFoundError(f"No LD score files found matching pattern: {ld_file}*{suffix}*")
+
+     ref_ld = pd.concat(df_list, axis=0)
+     logger.info(f"Loaded {len(ref_ld)} SNPs from LD score files")
+
+     # Set SNP as index
+     if "index" in ref_ld.columns:
+         ref_ld.rename(columns={"index": "SNP"}, inplace=True)
+     if "SNP" in ref_ld.columns:
+         ref_ld.set_index("SNP", inplace=True)
+
+     return ref_ld
+
+
+ def _read_w_ld(w_ld_dir):
+     """Read LD weights for all chromosomes."""
+     suffix = ".l2.ldscore"
+     w_file_pattern = str(Path(w_ld_dir) / "weights.")
+     logger.info(f"Reading LD score annotations from {w_file_pattern}[1-22]{suffix}...")
+
+     chr_files = _read_chr_files(w_file_pattern, suffix)
+
+     if not chr_files:
+         logger.error(f"No LD score files found matching pattern: {w_file_pattern}*{suffix}* inside {w_ld_dir}")
+         raise FileNotFoundError(f"No LD score files found matching pattern: {w_file_pattern}*{suffix}* inside {w_ld_dir}")
+
+     # Read and process each file
+     w_array = []
+     for file in chr_files:
+         x = _read_file(file)
+
+         # Sort if possible
+         if "CHR" in x.columns and "BP" in x.columns:
+             x = x.sort_values(by=["CHR", "BP"])
+
+         # Drop unnecessary columns
+         columns_to_drop = ["MAF", "CM", "Gene", "TSS", "CHR", "BP"]
+         columns_to_drop = [col for col in columns_to_drop if col in x.columns]
+         if columns_to_drop:
+             x = x.drop(columns=columns_to_drop, axis=1)
+
+         w_array.append(x)
+
+     # Concatenate and set column names
+     w_ld = pd.concat(w_array, axis=0)
+     logger.info(f"Loaded {len(w_ld)} SNPs from LD weight files")
+
+     # Set column names
+     w_ld.columns = (
+         ["SNP", "LD_weights"] + list(w_ld.columns[2:])
+         if len(w_ld.columns) > 2
+         else ["SNP", "LD_weights"]
+     )
+
+     return w_ld
+
+
+ # ============================================================================
+ # Memory monitoring
+ # ============================================================================
+
+ def log_memory_usage(message=""):
+     """Log current memory usage."""
+     try:
+         process = psutil.Process()
+         mem_info = process.memory_info()
+         rss_gb = mem_info.rss / 1024**3
+         logger.debug(f"Memory usage {message}: {rss_gb:.2f} GB")
+         return rss_gb
+     except:
+         return 0.0
+
+
+ # ============================================================================
+ # Data loading and preparation
+ # ============================================================================
+
+ def load_common_resources(config: SpatialLDSCConfig) -> tuple[pd.DataFrame, pd.DataFrame, ad.AnnData]:
+     """
+     Load resources common to all traits (weights, baseline, SNP-gene matrix).
+     Returns (baseline_ld, w_ld, snp_gene_weight_adata)
+     baseline_ld and w_ld are guaranteed to have the same index (intersection of available SNPs).
+     """
+     logger.info("Loading common resources...")
+     log_memory_usage("before loading common resources")
+
+     # 1. Load weights
+     w_ld = _read_w_ld(config.w_ld_dir)
+     w_ld.set_index("SNP", inplace=True)
+
+     # 2. Load SNP-gene weight matrix
+     logger.info(f"Loading SNP-gene weight matrix from {config.snp_gene_weight_adata_path}...")
+     snp_gene_weight_adata = ad.read_h5ad(config.snp_gene_weight_adata_path)
+
+     # 3. Construct baseline LD from snp_gene_weight_adata
+     X = snp_gene_weight_adata.X
+
+     # Compute base annotations
+     all_gene = X[:, :-1].sum(axis=1)
+     base = all_gene + X[:, -1]
+
+     baseline_ld = pd.DataFrame(
+         np.column_stack((base, all_gene)),
+         columns=["base", "all_gene"],
+         index=snp_gene_weight_adata.obs_names
+     )
+     baseline_ld.index.name = "SNP"
+     logger.info(f"Constructed baseline LD from SNP-gene weights. Shape: {baseline_ld.shape}")
+
+     # 4. Find common SNPs between baseline and weights
+     common_snps = baseline_ld.index.intersection(w_ld.index)
+
+     # 5. Load additional baselines and update common SNPs
+     if config.additional_baseline_h5ad_path_list:
+         logger.info(f"Loading {len(config.additional_baseline_h5ad_path_list)} additional baseline annotations...")
+
+         # We need to process additional baselines carefully to maintain the dataframe structure
+         # First, ensure we only work with currently common SNPs
+         baseline_ld = baseline_ld.loc[common_snps]
+
+         for i, h5ad_path in enumerate(config.additional_baseline_h5ad_path_list):
+             logger.info(f"Loading additional baseline {i+1}: {h5ad_path}")
+             add_adata = ad.read_h5ad(h5ad_path)
+
+             # Intersect with current common SNPs
+             common_in_add = common_snps.intersection(add_adata.obs_names)
+
+             if len(common_in_add) < len(common_snps):
+                 logger.warning(f"Additional baseline {h5ad_path} only has {len(common_in_add)}/{len(common_snps)} common SNPs. Intersecting...")
+                 common_snps = common_in_add
+                 baseline_ld = baseline_ld.loc[common_snps]
+
+             # Extract data
+             add_X = add_adata[common_snps].X
+             if hasattr(add_X, "toarray"):
+                 add_X = add_X.toarray()
+
+             add_df = pd.DataFrame(
+                 add_X,
+                 index=common_snps,
+                 columns=add_adata.var_names
+             )
+
+             # Concatenate
+             baseline_ld = pd.concat([baseline_ld, add_df], axis=1)
+
+     # Final subsetting
+     baseline_ld = baseline_ld.loc[common_snps]
+     w_ld = w_ld.loc[common_snps]
+
+     log_memory_usage("after loading common resources")
+     return baseline_ld, w_ld, snp_gene_weight_adata
+
+
+ def prepare_trait_data(config: SpatialLDSCConfig,
+                        trait_name: str,
+                        sumstats_file: str,
+                        baseline_ld: pd.DataFrame,
+                        w_ld: pd.DataFrame,
+                        snp_gene_weight_adata: ad.AnnData) -> tuple[dict, pd.Index]:
+     """
+     Prepare data for a specific trait using pre-loaded common resources.
+     """
+     logger.info(f"Preparing data for {trait_name}...")
+
+     # Load and process summary statistics
+     sumstats = _read_sumstats(fh=sumstats_file, alleles=False, dropna=False)
+     sumstats.set_index("SNP", inplace=True)
+     sumstats = sumstats.astype(np.float32)
+
+     # Filter by chi-squared
+     chisq_max = config.chisq_max
+     if chisq_max is None:
+         chisq_max = max(0.001 * sumstats.N.max(), 80)
+     sumstats["chisq"] = sumstats.Z ** 2
+
+     # Calculate genomic control lambda (λGC) before filtering
+     lambda_gc = np.median(sumstats.chisq) / 0.4559364
+     logger.info(f"Lambda GC (genomic control λ): {lambda_gc:.4f}")
+
+     sumstats = sumstats[sumstats.chisq < chisq_max]
+     logger.info(f"Filtered to {len(sumstats)} SNPs with chi^2 < {chisq_max}")
+
+     # Intersect trait sumstats with common resources
+     common_snps = baseline_ld.index.intersection(sumstats.index)
+     logger.info(f"Common SNPs: {len(common_snps)}")
+
+     if len(common_snps) < 200000:
+         logger.warning(f"WARNING: Only {len(common_snps)} common SNPs")
+
+     # Get SNP positions relative to the original snp_gene_weight_adata
+     # This is crucial for QuickMode to know which rows of the weight matrix to pick
+     snp_positions = snp_gene_weight_adata.obs_names.get_indexer(common_snps)
+
+     # Subset data
+     trait_baseline_ld = baseline_ld.loc[common_snps]
+     trait_w_ld = w_ld.loc[common_snps]
+     trait_sumstats = sumstats.loc[common_snps]
+
+     # Prepare data dictionary
+     data = {
+         'baseline_ld': trait_baseline_ld,
+         'baseline_ld_sum': trait_baseline_ld.sum(axis=1).values.astype(np.float32),
+         'w_ld': trait_w_ld.LD_weights.values.astype(np.float32),
+         'sumstats': trait_sumstats,
+         'chisq': trait_sumstats.chisq.values.astype(np.float32),
+         'N': trait_sumstats.N.values.astype(np.float32),
+         'Nbar': np.float32(trait_sumstats.N.mean()),
+         'snp_positions': snp_positions
+     }
+
+     return data, common_snps
+
+ def load_marker_scores_memmap_format(config: SpatialLDSCConfig) -> ad.AnnData:
+     """
+     Load marker scores memmap and wrap it in an AnnData object with metadata
+     from a reference h5ad file using configuration.
+
+     Args:
+         config: SpatialLDSCConfig containing paths and settings
+
+     Returns:
+         AnnData object with X backed by the memory map
+     """
+     memmap_path = Path(config.marker_scores_memmap_path)
+     metadata_path = Path(config.concatenated_latent_adata_path)
+     tmp_dir = config.memmap_tmp_dir
+     mode = 'r'  # Read-only mode for loading
+
+     if not metadata_path.exists():
+         raise FileNotFoundError(f"Metadata file not found: {metadata_path}")
+
+     # check complete
+     is_complete, _ = MemMapDense.check_complete(memmap_path)
+     if not is_complete:
+         raise ValueError(f"Marker score at {memmap_path} is incomplete or corrupted. Please recompute.")
+
+     # Load metadata source in backed mode
+     logger.info(f"Loading metadata from {metadata_path}")
+     src_adata = ad.read_h5ad(metadata_path, backed='r')
+
+     # Determine shape from metadata
+     shape = (src_adata.n_obs, src_adata.n_vars)
+
+     # Initialize MemMapDense
+     mm = MemMapDense(
+         memmap_path,
+         shape=shape,
+         mode=mode,
+         tmp_dir=tmp_dir
+     )
+
+     logger.info("Constructing AnnData wrapper...")
+     adata = ad.AnnData(
+         X=mm.memmap,
+         obs=src_adata.obs.copy(),
+         var=src_adata.var.copy(),
+         uns=src_adata.uns.copy(),
+         obsm=src_adata.obsm.copy(),
+         varm=src_adata.varm.copy()
+     )
+     # Close metadata source to release file handle
+     if src_adata.isbacked:
+         src_adata.file.close()
+
+     # Attach the manager to allow access to MemMapDense methods
+     adata.uns['memmap_manager'] = mm
+
+     return adata
+
+ def generate_expected_output_filename(config: SpatialLDSCConfig, trait_name: str) -> str | None:
+
+     base_name = f"{config.project_name}_{trait_name}"
+
+     # If we have cell indices range, include it in filename
+     if config.cell_indices_range:
+         start_cell, end_cell = config.cell_indices_range
+         return f"{base_name}_cells_{start_cell}_{end_cell}.csv.gz"
+
+     # If sample filter is set, filename will include sample info
+     # but we can't predict exact start/end without loading data
+     # For now, just check the simple complete case
+     # If using sample filter, we might not be able to easily predict output name
+     # without knowing the sample filtering results.
+     # But usually we don't skip in that case or logic is handled by caller.
+     if config.sample_filter:
+         return None
+
+     # Default case: complete coverage
+     return f"{base_name}.csv.gz"
+
+
+ def log_existing_result_statistics(result_path: Path, trait_name: str):
+
+     try:
+         # Read the existing result
+         logger.info(f"Reading existing result from: {result_path}")
+         df = pd.read_csv(result_path, compression='gzip')
+
+         n_spots = len(df)
+         bonferroni_threshold = 0.05 / n_spots
+         n_bonferroni_sig = (df['p'] < bonferroni_threshold).sum()
+
+         # FDR correction
+         reject, _, _, _ = multipletests(
+             df['p'], alpha=0.001, method='fdr_bh'
+         )
+         n_fdr_sig = reject.sum()
+
+         logger.info("=" * 70)
+         logger.info(f"EXISTING RESULT SUMMARY - {trait_name}")
+         logger.info("=" * 70)
+         logger.info(f"Total spots: {n_spots:,}")
+         logger.info(f"Max -log10(p): {df['neg_log10_p'].max():.2f}")
+         logger.info("-" * 70)
+         logger.info(f"Nominally significant (p < 0.05): {(df['p'] < 0.05).sum():,}")
+         logger.info(f"Bonferroni threshold: {bonferroni_threshold:.2e}")
+         logger.info(f"Bonferroni significant: {n_bonferroni_sig:,}")
+         logger.info(f"FDR significant (alpha=0.001): {n_fdr_sig:,}")
+         logger.info("=" * 70)
+
+     except Exception as e:
+         logger.warning(f"Could not read existing result statistics: {e}")
+
+
+
+
+ class LazyFeatherX:
+     """
+     A proxy for the 'X' matrix that slices directly from the Feather file
+     via memory mapping without loading the full data.
+     """
+
+     def __init__(self, arrow_table, feature_names, transpose=False):
+         self.table = arrow_table
+         self.feature_names = feature_names
+         self.transpose = transpose
+         if not transpose:
+             # Standard AnnData shape: (n_obs, n_vars)
+             self.shape = (self.table.num_rows, len(self.feature_names))
+         else:
+             # Transposed shape: (n_obs, n_vars) where n_obs = num_features, n_vars = num_rows
+             self.shape = (len(self.feature_names), self.table.num_rows)
+
+     def __getitem__(self, key):
+         """
+         Handles slicing: adata.X[start:end], adata.X[start:end, :], etc.
+         """
+         # Normalize the key to always be a tuple of (row_slice, col_slice)
+         if not isinstance(key, tuple):
+             row_key = key
+             col_key = slice(None)  # Select all columns
+         else:
+             row_key, col_key = key
+
+         if not self.transpose:
+             # --- 1. Handle Row Slicing ---
+             if isinstance(row_key, slice):
+                 start = row_key.start or 0
+                 stop = row_key.stop or self.shape[0]
+                 step = row_key.step or 1
+
+                 # Calculate length based on step (simplified for step=1)
+                 # For complex steps, we might need more logic, but basic slicing:
+                 if step == 1:
+                     length = stop - start
+                     sliced_table = self.table.slice(offset=start, length=length)
+                 else:
+                     # Fallback for stepped slicing: Read range, then step in pandas
+                     # (Slightly less efficient but works)
+                     length = stop - start
+                     sliced_table = self.table.slice(offset=start, length=length)
+                     # We will handle the stepping after conversion
+
+             elif isinstance(row_key, int):
+                 # Single row request
+                 sliced_table = self.table.slice(offset=row_key, length=1)
+             else:
+                 raise NotImplementedError("Only slice objects (start:stop) or integers are supported for rows.")
+
+             # --- 2. Handle Column Slicing ---
+             final_cols = self.feature_names
+
+             # Helper to map integer indices to column names
+             def get_col_names_by_indices(indices):
+                 return [self.feature_names[i] for i in indices]
+
+             if isinstance(col_key, slice):
+                 final_cols = self.feature_names[col_key]
+                 sliced_table = sliced_table.select(final_cols)
+
+             elif isinstance(col_key, list | np.ndarray):
+                 # Check if it's integers or strings
+                 if len(col_key) > 0:
+                     if isinstance(col_key[0], int | np.integer):
+                         final_cols = get_col_names_by_indices(col_key)
+                     else:
+                         # Assume strings
+                         final_cols = col_key
+                 sliced_table = sliced_table.select(final_cols)
+
+             elif isinstance(col_key, int):
+                 final_cols = [self.feature_names[col_key]]
+                 sliced_table = sliced_table.select(final_cols)
+
+             # --- 3. Materialize to NumPy ---
+             # FIX: We convert to Pandas first, because pyarrow.Table
+             # doesn't have a direct to_numpy() for 2D structures.
+             df = sliced_table.to_pandas()
+
+             # If we had a row step > 1, apply it now on the small DataFrame
+             if isinstance(row_key, slice) and row_key.step and row_key.step != 1:
+                 df = df.iloc[::row_key.step]
+
+             return df.to_numpy()
+
+         else:
+             # Transposed mode:
+             # row_key selects from feature_names (which were columns in feather)
+             # col_key selects from table rows
+
+             # --- 1. Determine which columns to read (Observations) ---
+             if isinstance(row_key, slice):
+                 selected_cols = self.feature_names[row_key]
+             elif isinstance(row_key, list | np.ndarray):
+                 if len(row_key) > 0 and isinstance(row_key[0], int | np.integer):
+                     selected_cols = [self.feature_names[i] for i in row_key]
+                 else:
+                     selected_cols = row_key
+             elif isinstance(row_key, int):
+                 selected_cols = [self.feature_names[row_key]]
+             else:
+                 raise NotImplementedError("Row key type not supported for transposed FeatherAnnData.")
+
+             # --- 2. Determine which rows to read (Variables) ---
+             if isinstance(col_key, slice):
+                 start = col_key.start or 0
+                 stop = col_key.stop or self.table.num_rows
+                 length = stop - start
+                 sliced_table = self.table.slice(offset=start, length=length)
+                 # Select columns and convert
+                 df = sliced_table.select(selected_cols).to_pandas()
+                 # Apply step for rows if needed
+                 if col_key.step and col_key.step != 1:
+                     df = df.iloc[::col_key.step]
+             elif isinstance(col_key, int):
+                 sliced_table = self.table.slice(offset=col_key, length=1)
+                 df = sliced_table.select(selected_cols).to_pandas()
+             elif isinstance(col_key, list | np.ndarray):
+                 # Use take for arbitrary row selection
+                 # Note: col_key should be integer indices
+                 sliced_table = self.table.take(col_key)
+                 df = sliced_table.select(selected_cols).to_pandas()
+             else:
+                 raise NotImplementedError("Column key type not supported for transposed FeatherAnnData.")
+
+             # The dataframe 'df' has table rows as index and selected_cols as columns.
+             # In transposed mode, we want selected_cols as rows and table rows as columns.
+             return df.to_numpy().T
+
+
+ class FeatherAnnData:
+     """
+     A minimal AnnData-like class backed by a Feather file.
+     Mimics the behavior of anndata.AnnData without loading X into memory.
+     """
+
+     def __init__(self, file_path, index_col=None, transpose=False):
+         # 1. Open with memory mapping (Zero RAM usage for data)
+         self._table = feather.read_table(file_path, memory_map=True)
+
+         # 2. Setup Index (Obs Names) and Columns (Var Names)
+         all_cols = self._table.column_names
+
+         if transpose:
+             if index_col:
+                 # Genes are rows in file, but we want them as vars
+                 self.var_names = self._table.column(index_col).to_pylist()
+                 # Cells are columns in file, we want them as obs
+                 self.obs_names = [c for c in all_cols if c != index_col]
+             else:
+                 self.obs_names = all_cols
+                 self.var_names = [str(i) for i in range(self._table.num_rows)]
+         else:
+             if index_col:
+                 # Load the index column specifically
+                 self.obs_names = self._table.column(index_col).to_pylist()
+                 # The variables (genes) are all columns MINUS the index column
+                 self.var_names = [c for c in all_cols if c != index_col]
+             else:
+                 # Fallback: Assume all columns are genes
+                 self.obs_names = [str(i) for i in range(self._table.num_rows)]
+                 self.var_names = all_cols
+
+         # 3. Setup Metadata DataFrames
+         self.obs = pd.DataFrame(index=self.obs_names)
+         self.var = pd.DataFrame(index=self.var_names)
+
+         # 4. Setup Attributes (n_obs, n_vars)
+         self.n_obs = len(self.obs_names)
+         self.n_vars = len(self.var_names)
+
+         # 5. Setup the Lazy X
+         if transpose:
+             # When transposed, 'feature_names' in LazyFeatherX refers to the obs columns
+             self.X = LazyFeatherX(self._table, self.obs_names, transpose=True)
+         else:
+             # Standard: 'feature_names' refers to the var columns
+             self.X = LazyFeatherX(self._table, self.var_names, transpose=False)
+
+         self.uns = {}
+
+         # 6. Setup Shape
+         self.shape = (self.n_obs, self.n_vars)
+
+     def __repr__(self):
+         return (f"FeatherAnnData object with n_obs × n_vars = {self.n_obs} × {self.n_vars}\n"
+                 f" obs: {list(self.obs.columns)}\n"
+                 f" var: {list(self.var.columns)}\n"
+                 f" uns: (Empty)\n"
+                 f" obsm: (Empty)\n"
+                 f" varm: (Empty)\n"
+                 f" layers: (Empty)\n"
+                 f" Backing: PyArrow Memory Mapping (Read-Only)")
+
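
A minimal usage sketch for the Feather-backed classes added in gsMap/spatial_ldsc/io.py (illustrative only; the file name marker_scores.feather and the index column "gene" are assumptions, not files shipped in this wheel):

    from gsMap.spatial_ldsc.io import FeatherAnnData

    # Hypothetical Feather table with genes as rows (identifier column "gene")
    # and cells as columns, loaded transposed so it behaves like cells x genes.
    adata = FeatherAnnData("marker_scores.feather", index_col="gene", transpose=True)

    print(adata.shape)            # (n_cells, n_genes): cells come from the file's columns
    block = adata.X[0:100, 0:50]  # NumPy array sliced lazily from the memory-mapped table
    print(block.shape)            # (100, 50)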