gsMap3D 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,912 @@
1
+ """
2
+ Unified spatial LDSC processor combining chunk production, parallel loading, and result accumulation.
3
+ """
4
+
5
+ import gc
6
+ import logging
7
+ import queue
8
+ import sys
9
+ import threading
10
+ import time
11
+ import traceback
12
+ from collections import deque
13
+ from dataclasses import dataclass
14
+ from pathlib import Path
15
+
16
+ import anndata as ad
17
+ import jax.numpy as jnp
18
+ import numpy as np
19
+ import pandas as pd
20
+ from rich.console import Console
21
+ from rich.progress import (
22
+ BarColumn,
23
+ MofNCompleteColumn,
24
+ Progress,
25
+ SpinnerColumn,
26
+ TaskProgressColumn,
27
+ TextColumn,
28
+ TimeElapsedColumn,
29
+ TimeRemainingColumn,
30
+ )
31
+ from scipy.stats import norm
32
+ from statsmodels.stats.multitest import multipletests
33
+
34
+ from gsMap.config import SpatialLDSCConfig
35
+
36
+ from .io import prepare_trait_data
37
+
38
+ logger = logging.getLogger("gsMap.spatial_ldsc_processor")
39
+
40
+
41
+ def prepare_snp_data_for_blocks(data: dict, n_blocks: int) -> dict:
42
+ """Prepare SNP-related data arrays for equal-sized blocks."""
43
+ if 'chisq' in data:
44
+ n_snps = len(data['chisq'])
45
+ elif 'N' in data:
46
+ n_snps = len(data['N'])
47
+ else:
48
+ raise ValueError("Cannot determine number of SNPs from data")
49
+
50
+ block_size = n_snps // n_blocks
51
+ n_snps_used = block_size * n_blocks
52
+ n_dropped = n_snps - n_snps_used
53
+
54
+ if n_dropped > 0:
55
+ logger.info(f"Truncating SNP data: dropping {n_dropped} SNPs "
56
+ f"({n_dropped/n_snps*100:.3f}%) for {n_blocks} blocks of size {block_size}")
57
+
58
+ truncated = {}
59
+ snp_keys = ['baseline_ld_sum', 'w_ld', 'chisq', 'N', 'snp_positions']
60
+
61
+ for key, value in data.items():
62
+ if key in snp_keys and isinstance(value, np.ndarray | jnp.ndarray):
63
+ truncated[key] = value[:n_snps_used]
64
+ elif key == 'baseline_ld':
65
+ truncated[key] = value.iloc[:n_snps_used]
66
+ else:
67
+ truncated[key] = value
68
+
69
+ truncated['block_size'] = block_size
70
+ truncated['n_blocks'] = n_blocks
71
+ truncated['n_snps_used'] = n_snps_used
72
+ truncated['n_snps_original'] = n_snps
73
+
74
+ return truncated
75
+
76
+
77
@dataclass
class ComponentThroughput:
    """Running throughput statistics for one pipeline component.

    Attributes:
        total_batches: Number of batches recorded so far.
        total_time: Cumulative wall-clock seconds across recorded batches.
        last_batch_time: Duration of the most recently recorded batch.
    """
    total_batches: int = 0
    total_time: float = 0.0
    last_batch_time: float = 0.0

    def record_batch(self, elapsed_time: float):
        """Fold one completed batch (taking ``elapsed_time`` seconds) into the stats."""
        self.total_time += elapsed_time
        self.total_batches += 1
        self.last_batch_time = elapsed_time

    @property
    def average_time(self) -> float:
        """Mean seconds per batch (0.0 before any batch is recorded)."""
        return self.total_time / self.total_batches if self.total_batches > 0 else 0.0

    @property
    def throughput(self) -> float:
        """Batches per second (0.0 when no timing data is available)."""
        avg = self.average_time
        return 1.0 / avg if avg > 0 else 0.0
103
+
104
+
105
class ParallelLDScoreReader:
    """Multi-threaded reader for fetching LD score chunks from memory-mapped marker scores"""

    def __init__(
        self,
        processor,  # Reference to SpatialLDSCProcessor for data access
        num_workers: int = 4,
        output_queue: queue.Queue | None = None
    ):
        """Initialize reader pool.

        Args:
            processor: Owning processor; must expose ``_fetch_ldscore_chunk(chunk_idx)``.
            num_workers: Number of daemon reader threads to start.
            output_queue: Optional pre-built queue to push fetched chunks into;
                when None, a bounded queue (num_workers * 8 slots) is created so
                readers cannot run arbitrarily far ahead of the consumers.
        """
        self.processor = processor
        self.num_workers = num_workers

        # Queues for communication: read_queue carries chunk indices in,
        # result_queue carries fetched chunk payload dicts out.
        self.read_queue = queue.Queue()
        self.result_queue = output_queue if output_queue else queue.Queue(maxsize=num_workers * 8)

        # Throughput tracking (lock guards the shared ComponentThroughput instance)
        self.throughput = ComponentThroughput()
        self.throughput_lock = threading.Lock()

        # Exception handling: workers park (worker_id, exc, traceback) here and
        # set has_error; the driver surfaces the failure via check_errors().
        self.exception_queue = queue.Queue()
        self.has_error = threading.Event()

        # Start worker threads (daemon=True so they never block interpreter exit)
        self.workers = []
        self.stop_workers = threading.Event()
        self._start_workers()

    def _start_workers(self):
        """Start worker threads"""
        for i in range(self.num_workers):
            worker = threading.Thread(
                target=self._worker,
                args=(i,),
                daemon=True
            )
            worker.start()
            self.workers.append(worker)
        logger.info(f"Started {self.num_workers} reader threads")

    def _worker(self, worker_id: int):
        """Worker thread for reading LD score chunks.

        Loops until stop_workers is set or a None sentinel is dequeued. On
        failure it records the exception, emits a ``success: False`` result so
        the downstream consumer can skip the chunk, and exits.
        """
        logger.debug(f"Reader worker {worker_id} started")

        while not self.stop_workers.is_set():
            try:
                # Get chunk request; the 1s timeout lets the loop re-check
                # stop_workers periodically instead of blocking forever.
                item = self.read_queue.get(timeout=1)
                if item is None:
                    break

                chunk_idx = item

                # Track timing
                start_time = time.time()

                # Fetch the chunk using processor's method
                ldscore, spot_names, abs_start, abs_end = self.processor._fetch_ldscore_chunk(chunk_idx)


                # Track throughput
                elapsed = time.time() - start_time
                with self.throughput_lock:
                    self.throughput.record_batch(elapsed)

                # Put result for computer
                self.result_queue.put({
                    'chunk_idx': chunk_idx,
                    'ldscore': ldscore,
                    'spot_names': spot_names,
                    'abs_start': abs_start,
                    'abs_end': abs_end,
                    'worker_id': worker_id,
                    'success': True
                })
                self.read_queue.task_done()

            except queue.Empty:
                continue
            except Exception as e:
                error_trace = traceback.format_exc()
                logger.error(f"Reader worker {worker_id} error on chunk {chunk_idx}: {e}\nTraceback:\n{error_trace}")
                self.exception_queue.put((worker_id, e, error_trace))
                self.has_error.set()
                # chunk_idx may be unbound if the failure happened before a
                # request was dequeued; fall back to -1 in that case.
                self.result_queue.put({
                    'chunk_idx': chunk_idx if 'chunk_idx' in locals() else -1,
                    'worker_id': worker_id,
                    'success': False,
                    'error': str(e)
                })
                break

        logger.debug(f"Reader worker {worker_id} stopped")

    def submit_chunk(self, chunk_idx: int):
        """Submit chunk for reading"""
        self.read_queue.put(chunk_idx)

    def get_result(self):
        """Get next completed chunk (blocks until one is available)"""
        return self.result_queue.get()

    def get_queue_sizes(self):
        """Get current queue sizes for monitoring (pending reads, pending results)"""
        return self.read_queue.qsize(), self.result_queue.qsize()

    def check_errors(self):
        """Check if any worker encountered an error.

        Raises:
            RuntimeError: If a worker has failed; chains the original exception
                and embeds its formatted traceback.
        """
        if self.has_error.is_set():
            try:
                worker_id, exception, error_trace = self.exception_queue.get_nowait()
                raise RuntimeError(f"Reader worker {worker_id} failed: {exception}\nOriginal traceback:\n{error_trace}") from exception
            except queue.Empty:
                raise RuntimeError("Reader worker failed with unknown error")

    def close(self):
        """Clean up resources: signal stop, unblock workers with sentinels, join."""
        self.stop_workers.set()
        for _ in range(self.num_workers):
            self.read_queue.put(None)
        for worker in self.workers:
            worker.join(timeout=5)
        logger.info("Reader pool closed")
230
+
231
+
232
class ParallelLDScoreComputer:
    """Multi-threaded computer for processing LD scores with JAX"""

    def __init__(
        self,
        processor,  # Reference to SpatialLDSCProcessor
        process_chunk_jit_fn,  # JIT-compiled processing function
        num_workers: int = 4,
        input_queue: queue.Queue | None = None
    ):
        """Initialize computer pool.

        Args:
            processor: Owning processor; supplies ``data_truncated``, ``config``,
                and the ``_add_chunk_result`` result sink.
            process_chunk_jit_fn: JIT-compiled regression function, called as
                fn(n_blocks, spatial_ld, baseline_ld_sum, chisq, N, baseline_ann,
                w_ld, Nbar) and expected to return (betas, ses).
            num_workers: Number of daemon compute threads to start.
            input_queue: Optional queue to consume chunk payloads from (usually
                the reader's result_queue); when None a bounded queue is created.
        """
        self.processor = processor
        self.process_chunk_jit_fn = process_chunk_jit_fn
        self.num_workers = num_workers

        # Queues for communication
        self.compute_queue = input_queue if input_queue else queue.Queue(maxsize=num_workers * 2)

        # Throughput tracking (lock guards the shared ComponentThroughput)
        self.throughput = ComponentThroughput()
        self.throughput_lock = threading.Lock()

        # Processing statistics (guarded by stats_lock)
        self.total_cells_processed = 0
        self.total_chunks_processed = 0
        self.stats_lock = threading.Lock()

        # Exception handling: workers park failures here and set has_error
        self.exception_queue = queue.Queue()
        self.has_error = threading.Event()

        # Results storage
        # NOTE(review): this list appears unused — chunk output is handed to
        # processor._add_chunk_result under results_lock instead.
        self.results = []
        self.results_lock = threading.Lock()

        # Prepare static JAX arrays shared by every chunk's regression
        self._prepare_jax_arrays()

        # Start worker threads
        self.workers = []
        self.stop_workers = threading.Event()
        self._start_workers()

    def _prepare_jax_arrays(self):
        """Prepare static JAX arrays from data_truncated.

        Builds the baseline annotation matrix (baseline LD scaled by per-SNP
        sample size over Nbar, with an appended intercept column of ones) and
        moves the trait-invariant vectors onto the JAX device once, so compute
        workers only transfer the per-chunk spatial LD.
        """
        n_snps_used = self.processor.data_truncated['n_snps_used']

        baseline_ann = (self.processor.data_truncated['baseline_ld'].values.astype(np.float32) *
                        self.processor.data_truncated['N'].reshape(-1, 1).astype(np.float32) /
                        self.processor.data_truncated['Nbar'])
        # Intercept column of ones appended last
        baseline_ann = np.concatenate([baseline_ann,
                                       np.ones((n_snps_used, 1), dtype=np.float32)], axis=1)

        # Convert to JAX arrays
        self.baseline_ld_sum_jax = jnp.asarray(self.processor.data_truncated['baseline_ld_sum'], dtype=jnp.float32)
        self.chisq_jax = jnp.asarray(self.processor.data_truncated['chisq'], dtype=jnp.float32)
        self.N_jax = jnp.asarray(self.processor.data_truncated['N'], dtype=jnp.float32)
        self.baseline_ann_jax = jnp.asarray(baseline_ann, dtype=jnp.float32)
        self.w_ld_jax = jnp.asarray(self.processor.data_truncated['w_ld'], dtype=jnp.float32)
        self.Nbar = self.processor.data_truncated['Nbar']

        # Free the host-side copy promptly; it can be large
        del baseline_ann
        gc.collect()

    def _start_workers(self):
        """Start compute worker threads"""
        for i in range(self.num_workers):
            worker = threading.Thread(
                target=self._compute_worker,
                args=(i,),
                daemon=True
            )
            worker.start()
            self.workers.append(worker)
        logger.info(f"Started {self.num_workers} compute workers")

    def _compute_worker(self, worker_id: int):
        """Compute worker thread.

        Consumes chunk payloads until stop_workers is set or a None sentinel is
        received; skips payloads marked unsuccessful by the reader. Results are
        pushed to the processor; on any exception the worker records it and exits.
        """
        logger.debug(f"Compute worker {worker_id} started")

        while not self.stop_workers.is_set():
            try:
                # Get data from reader; timeout lets the loop re-check stop_workers
                item = self.compute_queue.get(timeout=1)
                if item is None:
                    break

                # Skip failed chunks
                if not item.get('success', False):
                    logger.error(f"Skipping chunk {item.get('chunk_idx')} due to read error")
                    continue

                # Unpack data
                chunk_idx = item['chunk_idx']
                ldscore = item['ldscore']
                spot_names = item['spot_names']
                abs_start = item['abs_start']
                abs_end = item['abs_end']

                # Track timing
                start_time = time.time()

                # Convert to JAX and process
                spatial_ld_jax = jnp.asarray(ldscore, dtype=jnp.float32)

                # Process with batched JIT function (no batch_size needed - processes all spots at once)
                betas, ses = self.process_chunk_jit_fn(
                    self.processor.config.n_blocks,
                    spatial_ld_jax,
                    self.baseline_ld_sum_jax,
                    self.chisq_jax,
                    self.N_jax,
                    self.baseline_ann_jax,
                    self.w_ld_jax,
                    self.Nbar
                )

                # Ensure computation completes (JAX dispatch is asynchronous)
                betas.block_until_ready()
                ses.block_until_ready()

                # Convert to numpy
                betas_np = np.array(betas)
                ses_np = np.array(ses)

                # Track throughput and statistics
                elapsed = time.time() - start_time
                n_cells_in_chunk = abs_end - abs_start

                with self.throughput_lock:
                    self.throughput.record_batch(elapsed)

                with self.stats_lock:
                    self.total_cells_processed += n_cells_in_chunk
                    self.total_chunks_processed += 1

                # Store result (serialized so the processor's accumulator
                # state is never mutated concurrently)
                with self.results_lock:
                    self.processor._add_chunk_result(
                        chunk_idx, betas_np, ses_np, spot_names,
                        abs_start, abs_end
                    )

                # Drop references so JAX device buffers can be released
                del spatial_ld_jax, betas, ses

            except queue.Empty:
                continue
            except Exception as e:
                error_trace = traceback.format_exc()
                logger.error(f"Compute worker {worker_id} error: {e}\nTraceback:\n{error_trace}")
                self.exception_queue.put((worker_id, e, error_trace))
                self.has_error.set()
                break

        logger.debug(f"Compute worker {worker_id} stopped")

    def get_queue_size(self):
        """Get compute queue size"""
        return self.compute_queue.qsize()

    def get_stats(self):
        """Get processing statistics as (total_cells_processed, total_chunks_processed)"""
        with self.stats_lock:
            return self.total_cells_processed, self.total_chunks_processed

    def check_errors(self):
        """Check if any worker encountered an error.

        Raises:
            RuntimeError: If a worker has failed; chains the original exception
                and embeds its formatted traceback.
        """
        if self.has_error.is_set():
            try:
                worker_id, exception, error_trace = self.exception_queue.get_nowait()
                raise RuntimeError(f"Compute worker {worker_id} failed: {exception}\nOriginal traceback:\n{error_trace}") from exception
            except queue.Empty:
                raise RuntimeError("Compute worker failed with unknown error")

    def close(self):
        """Close compute pool: signal stop, unblock workers with sentinels, join."""
        self.stop_workers.set()
        for _ in range(self.num_workers):
            self.compute_queue.put(None)
        for worker in self.workers:
            worker.join(timeout=5)
        logger.info("Compute pool closed")
415
+
416
+
417
+ class SpatialLDSCProcessor:
418
+ """
419
+ Unified processor for spatial LDSC that combines:
420
+ - ChunkProducer: Loading spatial LD chunks
421
+ - ParallelChunkLoader: Managing parallel chunk loading with adjacent fetching
422
+ - QuickModeLDScore: Handling memory-mapped marker scores
423
+ - ResultAccumulator: Validating, merging and saving results
424
+ """
425
+
426
    def __init__(self,
                 config: SpatialLDSCConfig,
                 output_dir: Path,
                 marker_score_adata: ad.AnnData,
                 snp_gene_weight_adata: ad.AnnData,
                 baseline_ld: pd.DataFrame,
                 w_ld: pd.DataFrame,
                 n_loader_threads: int = 10):
        """
        Initialize the unified processor.

        Args:
            config: Configuration object
            output_dir: Output directory for results
            marker_score_adata: AnnData object containing marker scores in .X
            snp_gene_weight_adata: AnnData object containing SNP-gene weights
            baseline_ld: Baseline LD scores common to all traits
            w_ld: Weights common to all traits
            n_loader_threads: Number of parallel loader threads

        Raises:
            ValueError: If snp_gene_weight_adata is None, no sample column can
                be found when sample_filter is set, or the filtered sample's
                spots are not contiguous.
        """
        self.config = config
        self.output_dir = output_dir
        self.marker_score_adata = marker_score_adata
        self.snp_gene_weight_adata = snp_gene_weight_adata
        self.n_spots = marker_score_adata.n_obs
        self.baseline_ld = baseline_ld
        self.w_ld = w_ld
        self.n_loader_threads = n_loader_threads

        # Trait specific state (populated by setup_trait)
        self.trait_name = None
        self.data_truncated = None

        # Result accumulator state (reset per trait in setup_trait)
        self.results = []
        self.processed_chunks = set()
        self.min_spot_start = float('inf')
        self.max_spot_end = 0

        logger.info(f"Detected Marker scores shape: (n_spots={self.n_spots}, n_genes={self.marker_score_adata.n_vars})")

        # Find common genes and initialize indices
        gene_names_from_adata = np.array(self.marker_score_adata.var_names)
        self.spot_names_all = np.array(self.marker_score_adata.obs_names)

        if self.snp_gene_weight_adata is None:
            raise ValueError("snp_gene_weight_adata must be provided")

        # Intersect marker-score genes with SNP-gene weight genes; keep both
        # sides' positional indices so chunks can be column-sliced cheaply.
        marker_score_genes_series = pd.Series(gene_names_from_adata)
        common_genes_mask = marker_score_genes_series.isin(self.snp_gene_weight_adata.var.index)
        common_genes = gene_names_from_adata[common_genes_mask]

        self.marker_score_gene_indices = np.where(common_genes_mask)[0]
        self.weight_gene_indices = [self.snp_gene_weight_adata.var.index.get_loc(g) for g in common_genes]

        logger.info(f"Found {len(common_genes)} common genes")

        # Filter by sample if specified
        if self.config.sample_filter:
            logger.info(f"Filtering spots by sample: {self.config.sample_filter}")
            # Try 'sample' first, fall back to 'sample_name'
            sample_info = self.marker_score_adata.obs.get('sample', self.marker_score_adata.obs.get('sample_name', None))

            if sample_info is None:
                raise ValueError("No 'sample' or 'sample_name' column found in obs")

            sample_info = sample_info.to_numpy()
            self.spot_indices = np.where(sample_info == self.config.sample_filter)[0]

            # Verify spots are contiguous for efficient slicing
            # (chunk fetches below assume a single contiguous memmap range)
            expected_range = list(range(self.spot_indices[0], self.spot_indices[-1] + 1))
            if self.spot_indices.tolist() != expected_range:
                raise ValueError("Spot indices for sample must be contiguous")

            self.sample_start_offset = self.spot_indices[0]
            self.spot_names_filtered = self.spot_names_all[self.spot_indices]
            logger.info(f"Found {len(self.spot_indices)} spots for sample '{self.config.sample_filter}'")
        else:
            # No filter: process every spot, offset zero
            self.spot_indices = np.arange(self.n_spots)
            self.spot_names_filtered = self.spot_names_all
            self.sample_start_offset = 0

        self.n_spots_filtered = len(self.spot_indices)

        # Set up chunking
        self.chunk_size = self.config.spots_per_chunk_quick_mode

        # Handle cell indices range if specified (indices are relative to the
        # filtered spot set)
        if self.config.cell_indices_range:
            start_cell, end_cell = self.config.cell_indices_range
            # Adjust for filtered spots
            start_cell = max(0, start_cell)
            end_cell = min(end_cell, self.n_spots_filtered)
            self.chunk_starts = list(range(start_cell, end_cell, self.chunk_size))
            logger.info(f"Processing cell range [{start_cell}, {end_cell})")
            self.total_cells_to_process = end_cell - start_cell
        else:
            self.chunk_starts = list(range(0, self.n_spots_filtered, self.chunk_size))
            self.total_cells_to_process = self.n_spots_filtered

        self.total_chunks = len(self.chunk_starts)
        logger.info(f"Total chunks to process: {self.total_chunks}")
526
+
527
+
528
    def setup_trait(self, trait_name: str, sumstats_file: str):
        """
        Setup processor for a new trait.
        Loads sumstats, prepares data, and initializes trait-specific components.

        Args:
            trait_name: Label used in logging and output filenames.
            sumstats_file: Path to the GWAS summary statistics for this trait.

        Raises:
            ValueError: If 'snp_positions' is missing from the prepared data.
        """
        self.trait_name = trait_name
        logger.info(f"Setting up processor for trait: {trait_name}")

        # Prepare trait data (aligns sumstats with baseline/weight LD scores
        # and the SNP-gene weight matrix; see io.prepare_trait_data)
        data, common_snps = prepare_trait_data(
            self.config,
            trait_name,
            sumstats_file,
            self.baseline_ld,
            self.w_ld,
            self.snp_gene_weight_adata
        )

        # Prepare blocks: truncate per-SNP arrays to a multiple of n_blocks
        self.data_truncated = prepare_snp_data_for_blocks(data, self.config.n_blocks)

        # snp_positions maps this trait's retained SNPs to rows of the
        # SNP-gene weight matrix
        logger.info(f"Initializing trait-specific components for trait: {self.trait_name}")
        snp_positions = self.data_truncated.get('snp_positions', None)
        if snp_positions is None:
            raise ValueError("snp_positions not found in data_truncated")

        # Extract SNP-gene weight matrix for this trait's SNPs
        self.snp_gene_weight_sparse = self.snp_gene_weight_adata.X[snp_positions, :][:, self.weight_gene_indices]

        # CSR layout gives fast sparse @ dense products in _fetch_ldscore_chunk
        if hasattr(self.snp_gene_weight_sparse, 'tocsr'):
            self.snp_gene_weight_sparse = self.snp_gene_weight_sparse.tocsr()

        logger.info(f"SNP-gene weight matrix shape: {self.snp_gene_weight_sparse.shape}")

        # Reset results accumulated for any previously processed trait
        self.results = []
        self.processed_chunks = set()
        self.min_spot_start = float('inf')
        self.max_spot_end = 0
568
+
569
+
570
+
571
+
572
    def _fetch_ldscore_chunk(self, chunk_index: int) -> tuple[np.ndarray, pd.Index, int, int]:
        """
        Fetch LD score chunk for given index.

        Slices this chunk's spots out of the marker-score matrix and projects
        them through the trait's SNP-gene weight matrix, yielding per-SNP
        spatial LD scores for the chunk.

        Returns:
            Tuple of (ldscore_array, spot_names, absolute_start, absolute_end)

        Raises:
            ValueError: If chunk_index is out of range.
        """
        if chunk_index >= len(self.chunk_starts):
            raise ValueError(f"Invalid chunk index {chunk_index}")

        # Chunk bounds relative to the filtered spot set
        start = self.chunk_starts[chunk_index]
        end = min(start + self.chunk_size, self.n_spots_filtered)

        # Calculate absolute positions in memmap (offset by the sample's first spot)
        memmap_start = self.sample_start_offset + start
        memmap_end = self.sample_start_offset + end

        # Load chunk from marker_score_adata, then transpose to genes x spots
        # for the sparse product below
        mk_score_chunk = self.marker_score_adata.X[memmap_start:memmap_end, self.marker_score_gene_indices]
        mk_score_chunk = mk_score_chunk.T.astype(np.float32)

        # Compute LD scores via sparse matrix multiplication:
        # (snps x genes) @ (genes x spots) -> snps x spots
        ldscore_chunk = self.snp_gene_weight_sparse @ mk_score_chunk

        # Densify if the product came back sparse
        if hasattr(ldscore_chunk, 'toarray'):
            ldscore_chunk = ldscore_chunk.toarray()

        # Get spot names
        spot_names = pd.Index(self.spot_names_filtered[start:end])

        # Calculate absolute positions in original data
        # NOTE(review): assumes spot_indices is contiguous (enforced in
        # __init__ for the sample_filter path) — confirm for other callers
        absolute_start = self.spot_indices[start] if start < len(self.spot_indices) else start
        absolute_end = self.spot_indices[end - 1] + 1 if end > 0 else absolute_start

        return ldscore_chunk.astype(np.float32, copy=False), spot_names, absolute_start, absolute_end
607
+
608
+
609
+ def process_all_chunks(self, process_chunk_jit_fn) -> pd.DataFrame:
610
+ """
611
+ Process all chunks using parallel reader-computer pipeline.
612
+
613
+ Args:
614
+ process_chunk_jit_fn: JIT-compiled function for processing chunks
615
+
616
+ Returns:
617
+ Merged DataFrame with all results
618
+ """
619
+ # Create the reader-computer pipeline
620
+ reader = ParallelLDScoreReader(
621
+ processor=self,
622
+ num_workers=self.config.ldsc_read_workers,
623
+ )
624
+
625
+ computer = ParallelLDScoreComputer(
626
+ processor=self,
627
+ process_chunk_jit_fn=process_chunk_jit_fn,
628
+ num_workers=self.config.ldsc_compute_workers,
629
+ input_queue=reader.result_queue # Connect reader output to computer input
630
+ )
631
+
632
+ try:
633
+ # Submit all chunks to reader
634
+ for chunk_idx in range(self.total_chunks):
635
+ reader.submit_chunk(chunk_idx)
636
+
637
+ # Build description with sample name and range info
638
+ desc_parts = [f"Processing {self.total_chunks:,} chunks ({self.total_cells_to_process:,} cells)"]
639
+
640
+ if hasattr(self.config, 'sample_filter') and self.config.sample_filter:
641
+ desc_parts.append(f"Sample: {self.config.sample_filter}")
642
+
643
+ if self.config.cell_indices_range:
644
+ start_cell, end_cell = self.config.cell_indices_range
645
+ desc_parts.append(f"Range: [{start_cell:,}-{end_cell:,})")
646
+
647
+ description = " | ".join(desc_parts)
648
+
649
+ # # Start JAX profiling if needed
650
+ # # if hasattr(self.config, 'enable_jax_profiling') and self.config.enable_jax_profiling:
651
+ # print('starting jax profiler...')
652
+ # print('starting jax profiler...')
653
+ # print('starting jax profiler...')
654
+ # print('starting jax profiler...')
655
+ # print('starting jax profiler...')
656
+ # jax.profiler.start_trace("/tmp/jax-trace-ldsc")
657
+ # Auto-detect terminal vs redirected output
658
+ # When output is redirected, Console will disable progress bar animations
659
+ is_terminal = sys.stdout.isatty()
660
+ console = Console(soft_wrap=True)
661
+
662
+ # Log interval for non-terminal mode (every 10% or at least every 60 seconds)
663
+ log_interval_pct = 10
664
+ last_log_pct = 0
665
+ last_log_time = 0
666
+
667
+ with Progress(
668
+ SpinnerColumn(),
669
+ TextColumn("[bold blue]{task.description}"),
670
+ BarColumn(),
671
+ MofNCompleteColumn(),
672
+ TaskProgressColumn(),
673
+ TextColumn("[bold green]{task.fields[speed]} cells/s"),
674
+ TextColumn("[dim]R→C: {task.fields[r_to_c_queue]}"),
675
+ TimeElapsedColumn(),
676
+ TimeRemainingColumn(),
677
+ console=console,
678
+ refresh_per_second=2,
679
+ disable=not is_terminal # Disable rich progress bar when not in terminal
680
+ ) as progress:
681
+ task = progress.add_task(
682
+ description,
683
+ total=self.total_cells_to_process,
684
+ speed="0",
685
+ r_to_c_queue="0"
686
+ )
687
+
688
+ start_time = time.time()
689
+ last_update_time = start_time
690
+ last_chunks_processed = 0
691
+
692
+ # For 10s moving average
693
+ speed_window = deque() # Stores (timestamp, n_cells_processed)
694
+
695
+ while last_chunks_processed < self.total_chunks:
696
+ # Check for errors
697
+ reader.check_errors()
698
+ computer.check_errors()
699
+
700
+ # Get current stats from computer
701
+ n_cells_processed, n_chunks_processed = computer.get_stats()
702
+
703
+ # Update progress periodically
704
+ current_time = time.time()
705
+ if current_time - last_update_time > 0.5: # Update every 0.5 seconds
706
+ # Get queue sizes
707
+ r_pending, r_to_c = reader.get_queue_sizes()
708
+
709
+ # Calculate speed in cells/s (10s moving average)
710
+ speed_window.append((current_time, n_cells_processed))
711
+
712
+ # Remove entries older than 10s
713
+ while speed_window and current_time - speed_window[0][0] > 10.0:
714
+ speed_window.popleft()
715
+
716
+ if len(speed_window) > 1:
717
+ t_diff = speed_window[-1][0] - speed_window[0][0]
718
+ c_diff = speed_window[-1][1] - speed_window[0][1]
719
+ speed_10s = c_diff / t_diff if t_diff > 0 else 0
720
+ else:
721
+ # Fallback to cumulative speed if not enough data for window
722
+ elapsed_total = current_time - start_time
723
+ speed_10s = n_cells_processed / elapsed_total if elapsed_total > 0 else 0
724
+
725
+ # Update progress bar (only effective in terminal mode)
726
+ progress.update(
727
+ task,
728
+ completed=n_cells_processed,
729
+ speed=f"{speed_10s:,.0f}",
730
+ r_to_c_queue=f"{r_to_c}"
731
+ )
732
+
733
+ # Log progress when output is redirected
734
+ if not is_terminal:
735
+ current_pct = (n_cells_processed / self.total_cells_to_process) * 100 if self.total_cells_to_process > 0 else 0
736
+ time_since_last_log = current_time - last_log_time
737
+
738
+ # Log every 10% or every 60 seconds
739
+ if current_pct >= last_log_pct + log_interval_pct or time_since_last_log >= 60:
740
+ elapsed = current_time - start_time
741
+ logger.info(
742
+ f"Progress: {n_cells_processed:,}/{self.total_cells_to_process:,} cells "
743
+ f"({current_pct:.1f}%) | {n_chunks_processed}/{self.total_chunks} chunks | "
744
+ f"Speed: {speed_10s:,.0f} cells/s | Elapsed: {elapsed:.1f}s"
745
+ )
746
+ last_log_pct = (current_pct // log_interval_pct) * log_interval_pct
747
+ last_log_time = current_time
748
+
749
+ last_update_time = current_time
750
+
751
+ # Update last processed count
752
+ if n_chunks_processed > last_chunks_processed:
753
+ last_chunks_processed = n_chunks_processed
754
+
755
+ # This would block all threads, so we avoid it
756
+ # # Periodic memory check
757
+ # if n_chunks_processed % 100 == 0:
758
+ # gc.collect()
759
+ # Small sleep to prevent busy waiting
760
+ time.sleep(0.5)
761
+
762
+ # # if hasattr(self.config, 'enable_jax_profiling') and self.config.enable_jax_profiling:
763
+ # jax.profiler.stop_trace()
764
+ # print(f'stopped jax profiler, trace saved to /tmp/jax-trace-ldsc')
765
+ # print(f'stopped jax profiler, trace saved to /tmp/jax-trace-ldsc')
766
+ # print(f'stopped jax profiler, trace saved to /tmp/jax-trace-ldsc')
767
+ # # logger.info("JAX profiling trace saved to /tmp/jax-trace-ldsc")
768
+
769
+ finally:
770
+ # Clean up resources
771
+ reader.close()
772
+ computer.close()
773
+
774
+ # Log overall speed
775
+ total_elapsed = time.time() - start_time
776
+ n_cells_total, _ = computer.get_stats()
777
+ overall_speed = n_cells_total / total_elapsed if total_elapsed > 0 else 0
778
+ logger.info(f"Processing complete. Overall speed: {overall_speed:,.0f} cells/s")
779
+
780
+ # Note: mkscore_memmap is NOT closed here to allow reuse across multiple traits.
781
+ # It will be closed by the caller after all traits are processed.
782
+
783
+ # Validate and merge results
784
+ return self._validate_merge_and_save()
785
+
786
+ def _add_chunk_result(self, chunk_idx: int, betas: np.ndarray, ses: np.ndarray,
787
+ spot_names: pd.Index, abs_start: int, abs_end: int):
788
+ """Add processed chunk result to accumulator."""
789
+ # Update coverage tracking
790
+ self.min_spot_start = min(self.min_spot_start, abs_start)
791
+ self.max_spot_end = max(self.max_spot_end, abs_end)
792
+
793
+ # Store result
794
+ self.results.append({
795
+ 'chunk_idx': chunk_idx,
796
+ 'betas': betas,
797
+ 'ses': ses,
798
+ 'spot_names': spot_names,
799
+ 'abs_start': abs_start,
800
+ 'abs_end': abs_end
801
+ })
802
+ self.processed_chunks.add(chunk_idx)
803
+
804
    def _validate_merge_and_save(self) -> pd.DataFrame:
        """
        Validate completeness, merge results, and save with appropriate filename.

        Returns:
            Merged DataFrame with all results

        Raises:
            ValueError: If no chunk results have been accumulated.
        """
        if not self.results:
            raise ValueError("No results to merge")

        # Check completeness: warn (but do not fail) when chunks are missing
        expected_chunks = set(range(self.total_chunks))
        missing_chunks = expected_chunks - self.processed_chunks

        if missing_chunks:
            logger.warning(f"Missing chunks: {sorted(missing_chunks)}")
            logger.warning(f"Processed {len(self.processed_chunks)}/{self.total_chunks} chunks")

        # Sort results by chunk index so rows come out in original spot order
        sorted_results = sorted(self.results, key=lambda x: x['chunk_idx'])

        # Merge all results
        dfs = []
        for result in sorted_results:
            betas = result['betas'].astype(np.float64)
            ses = result['ses'].astype(np.float64)

            # Calculate statistics: one-sided (upper-tail) p-value from z
            z_scores = betas / ses
            p_values = norm.sf(z_scores)
            # Floor p at 1e-300 so -log10 stays finite
            log10_p = -np.log10(np.maximum(p_values, 1e-300))

            # NOTE(review): 'beta'/'se' store the original (likely float32)
            # arrays while 'z'/'p' derive from float64 copies — confirm this
            # mixed precision is intended
            chunk_df = pd.DataFrame({
                'spot': result['spot_names'],
                'beta': result['betas'],
                'se': result['ses'],
                'z': z_scores.astype(np.float32),
                'p': p_values,
                'neg_log10_p': log10_p
            })
            dfs.append(chunk_df)

        merged_df = pd.concat(dfs, ignore_index=True)

        # Generate filename with cell range information
        filename = self._generate_output_filename()
        output_path = self.output_dir / filename

        # Save results (gzip-compressed CSV)
        logger.info(f"Saving results to {output_path}")
        merged_df.to_csv(output_path, index=False, compression='gzip')

        # Log statistics
        self._log_statistics(merged_df, output_path)

        return merged_df
860
+
861
+ def _generate_output_filename(self) -> str:
862
+ """Generate output filename including cell range information."""
863
+ base_name = f"{self.config.project_name}_{self.trait_name}"
864
+
865
+ # If we have cell indices range, include it in filename
866
+ if self.config.cell_indices_range:
867
+ start_cell, end_cell = self.config.cell_indices_range
868
+ # Adjust for actual processed range
869
+ actual_start = max(self.min_spot_start, start_cell)
870
+ actual_end = min(self.max_spot_end, end_cell)
871
+ return f"{base_name}_cells_{actual_start}_{actual_end}.csv.gz"
872
+
873
+ # Check if we have complete coverage
874
+ if self.min_spot_start == 0 and self.max_spot_end == self.n_spots:
875
+ return f"{base_name}.csv.gz"
876
+
877
+ if self.config.sample_filter:
878
+ return f"{base_name}_{self.config.sample_filter}_start{self.min_spot_start}_end{self.max_spot_end}_total{self.n_spots}.csv.gz"
879
+
880
+ # Partial coverage without explicit range
881
+ return f"{base_name}_start{self.min_spot_start}_end{self.max_spot_end}_total{self.n_spots}.csv.gz"
882
+
883
+ def _log_statistics(self, df: pd.DataFrame, output_path: Path):
884
+ """Log statistical summary of results."""
885
+ n_spots = len(df)
886
+ bonferroni_threshold = 0.05 / n_spots
887
+ n_bonferroni_sig = (df['p'] < bonferroni_threshold).sum()
888
+
889
+ # FDR correction
890
+ reject, _, _, _ = multipletests(
891
+ df['p'], alpha=0.001, method='fdr_bh'
892
+ )
893
+ n_fdr_sig = reject.sum()
894
+
895
+ logger.info("=" * 70)
896
+ logger.info("STATISTICAL SUMMARY")
897
+ logger.info("=" * 70)
898
+ logger.info(f"Total spots: {n_spots:,}")
899
+ logger.info(f"Cell range processed: [{self.min_spot_start}, {self.max_spot_end})")
900
+ logger.info(f"Max -log10(p): {df['neg_log10_p'].max():.2f}")
901
+ logger.info("-" * 70)
902
+ logger.info(f"Nominally significant (p < 0.05): {(df['p'] < 0.05).sum():,}")
903
+ logger.info(f"Bonferroni threshold: {bonferroni_threshold:.2e}")
904
+ logger.info(f"Bonferroni significant: {n_bonferroni_sig:,}")
905
+ logger.info(f"FDR significant (alpha=0.001): {n_fdr_sig:,}")
906
+ logger.info("=" * 70)
907
+ logger.info(f"Results saved to: {output_path}")
908
+
909
+ # Warn if incomplete
910
+ if len(self.processed_chunks) < self.total_chunks:
911
+ logger.warning(f"WARNING: Only processed {len(self.processed_chunks)}/{self.total_chunks} chunks")
912
+