gsmap3d-0.1.0a1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/spatial_ldsc/ldscore_quick_mode.py

@@ -0,0 +1,912 @@
"""
Unified spatial LDSC processor combining chunk production, parallel loading, and result accumulation.
"""

import gc
import logging
import queue
import sys
import threading
import time
import traceback
from collections import deque
from dataclasses import dataclass
from pathlib import Path

import anndata as ad
import jax.numpy as jnp
import numpy as np
import pandas as pd
from rich.console import Console
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TaskProgressColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from scipy.stats import norm
from statsmodels.stats.multitest import multipletests

from gsMap.config import SpatialLDSCConfig

from .io import prepare_trait_data

logger = logging.getLogger("gsMap.spatial_ldsc_processor")

def prepare_snp_data_for_blocks(data: dict, n_blocks: int) -> dict:
    """Prepare SNP-related data arrays for equal-sized blocks."""
    if 'chisq' in data:
        n_snps = len(data['chisq'])
    elif 'N' in data:
        n_snps = len(data['N'])
    else:
        raise ValueError("Cannot determine number of SNPs from data")

    block_size = n_snps // n_blocks
    n_snps_used = block_size * n_blocks
    n_dropped = n_snps - n_snps_used

    if n_dropped > 0:
        logger.info(f"Truncating SNP data: dropping {n_dropped} SNPs "
                    f"({n_dropped/n_snps*100:.3f}%) for {n_blocks} blocks of size {block_size}")

    truncated = {}
    snp_keys = ['baseline_ld_sum', 'w_ld', 'chisq', 'N', 'snp_positions']

    for key, value in data.items():
        if key in snp_keys and isinstance(value, (np.ndarray, jnp.ndarray)):
            truncated[key] = value[:n_snps_used]
        elif key == 'baseline_ld':
            truncated[key] = value.iloc[:n_snps_used]
        else:
            truncated[key] = value

    truncated['block_size'] = block_size
    truncated['n_blocks'] = n_blocks
    truncated['n_snps_used'] = n_snps_used
    truncated['n_snps_original'] = n_snps

    return truncated

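# A minimal sketch of the truncation arithmetic above (hypothetical numbers,
# not taken from the package): with 1,000,003 SNPs and 200 blocks,
# block_size = 1_000_003 // 200 = 5000, so 1,000,000 SNPs are kept and the
# trailing 3 are dropped, giving every block an identical size.
#
#     demo = {'chisq': np.zeros(1_000_003), 'w_ld': np.zeros(1_000_003)}
#     out = prepare_snp_data_for_blocks(demo, n_blocks=200)
#     assert out['block_size'] == 5000
#     assert out['n_snps_used'] == 1_000_000
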
@dataclass
class ComponentThroughput:
    """Track throughput for individual pipeline components."""
    total_batches: int = 0
    total_time: float = 0.0
    last_batch_time: float = 0.0

    def record_batch(self, elapsed_time: float):
        """Record a batch completion."""
        self.total_batches += 1
        self.total_time += elapsed_time
        self.last_batch_time = elapsed_time

    @property
    def average_time(self) -> float:
        """Average time per batch."""
        if self.total_batches > 0:
            return self.total_time / self.total_batches
        return 0.0

    @property
    def throughput(self) -> float:
        """Batches per second."""
        if self.average_time > 0:
            return 1.0 / self.average_time
        return 0.0

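# Usage sketch for ComponentThroughput (hypothetical timings): two batches of
# 0.2 s and 0.3 s give average_time == 0.25 and throughput == 4.0 batches/s.
#
#     t = ComponentThroughput()
#     t.record_batch(0.2)
#     t.record_batch(0.3)
#     assert abs(t.throughput - 4.0) < 1e-9
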
class ParallelLDScoreReader:
    """Multi-threaded reader for fetching LD score chunks from memory-mapped marker scores."""

    def __init__(
        self,
        processor,  # Reference to SpatialLDSCProcessor for data access
        num_workers: int = 4,
        output_queue: queue.Queue | None = None
    ):
        """Initialize reader pool."""
        self.processor = processor
        self.num_workers = num_workers

        # Queues for communication
        self.read_queue = queue.Queue()
        self.result_queue = output_queue if output_queue else queue.Queue(maxsize=num_workers * 8)

        # Throughput tracking
        self.throughput = ComponentThroughput()
        self.throughput_lock = threading.Lock()

        # Exception handling
        self.exception_queue = queue.Queue()
        self.has_error = threading.Event()

        # Start worker threads
        self.workers = []
        self.stop_workers = threading.Event()
        self._start_workers()

    def _start_workers(self):
        """Start worker threads."""
        for i in range(self.num_workers):
            worker = threading.Thread(
                target=self._worker,
                args=(i,),
                daemon=True
            )
            worker.start()
            self.workers.append(worker)
        logger.info(f"Started {self.num_workers} reader threads")

    def _worker(self, worker_id: int):
        """Worker thread for reading LD score chunks."""
        logger.debug(f"Reader worker {worker_id} started")
        chunk_idx = -1  # retained for the error report if a failure precedes unpacking

        while not self.stop_workers.is_set():
            try:
                # Get chunk request
                item = self.read_queue.get(timeout=1)
                if item is None:
                    break

                chunk_idx = item

                # Track timing
                start_time = time.time()

                # Fetch the chunk using the processor's method
                ldscore, spot_names, abs_start, abs_end = self.processor._fetch_ldscore_chunk(chunk_idx)

                # Track throughput
                elapsed = time.time() - start_time
                with self.throughput_lock:
                    self.throughput.record_batch(elapsed)

                # Put result for the computer
                self.result_queue.put({
                    'chunk_idx': chunk_idx,
                    'ldscore': ldscore,
                    'spot_names': spot_names,
                    'abs_start': abs_start,
                    'abs_end': abs_end,
                    'worker_id': worker_id,
                    'success': True
                })
                self.read_queue.task_done()

            except queue.Empty:
                continue
            except Exception as e:
                error_trace = traceback.format_exc()
                logger.error(f"Reader worker {worker_id} error on chunk {chunk_idx}: {e}\nTraceback:\n{error_trace}")
                self.exception_queue.put((worker_id, e, error_trace))
                self.has_error.set()
                self.result_queue.put({
                    'chunk_idx': chunk_idx,
                    'worker_id': worker_id,
                    'success': False,
                    'error': str(e)
                })
                break

        logger.debug(f"Reader worker {worker_id} stopped")

    def submit_chunk(self, chunk_idx: int):
        """Submit a chunk index for reading."""
        self.read_queue.put(chunk_idx)

    def get_result(self):
        """Get the next completed chunk."""
        return self.result_queue.get()

    def get_queue_sizes(self):
        """Get current queue sizes for monitoring."""
        return self.read_queue.qsize(), self.result_queue.qsize()

    def check_errors(self):
        """Re-raise the first error encountered by any worker."""
        if self.has_error.is_set():
            try:
                worker_id, exception, error_trace = self.exception_queue.get_nowait()
                raise RuntimeError(f"Reader worker {worker_id} failed: {exception}\nOriginal traceback:\n{error_trace}") from exception
            except queue.Empty:
                raise RuntimeError("Reader worker failed with unknown error")

    def close(self):
        """Clean up resources."""
        self.stop_workers.set()
        for _ in range(self.num_workers):
            self.read_queue.put(None)
        for worker in self.workers:
            worker.join(timeout=5)
        logger.info("Reader pool closed")

class ParallelLDScoreComputer:
    """Multi-threaded computer for processing LD scores with JAX."""

    def __init__(
        self,
        processor,  # Reference to SpatialLDSCProcessor
        process_chunk_jit_fn,  # JIT-compiled processing function
        num_workers: int = 4,
        input_queue: queue.Queue | None = None
    ):
        """Initialize computer pool."""
        self.processor = processor
        self.process_chunk_jit_fn = process_chunk_jit_fn
        self.num_workers = num_workers

        # Queues for communication
        self.compute_queue = input_queue if input_queue else queue.Queue(maxsize=num_workers * 2)

        # Throughput tracking
        self.throughput = ComponentThroughput()
        self.throughput_lock = threading.Lock()

        # Processing statistics
        self.total_cells_processed = 0
        self.total_chunks_processed = 0
        self.stats_lock = threading.Lock()

        # Exception handling
        self.exception_queue = queue.Queue()
        self.has_error = threading.Event()

        # Results storage
        self.results = []
        self.results_lock = threading.Lock()

        # Prepare static JAX arrays
        self._prepare_jax_arrays()

        # Start worker threads
        self.workers = []
        self.stop_workers = threading.Event()
        self._start_workers()

    def _prepare_jax_arrays(self):
        """Prepare static JAX arrays from data_truncated."""
        n_snps_used = self.processor.data_truncated['n_snps_used']

        baseline_ann = (self.processor.data_truncated['baseline_ld'].values.astype(np.float32) *
                        self.processor.data_truncated['N'].reshape(-1, 1).astype(np.float32) /
                        self.processor.data_truncated['Nbar'])
        baseline_ann = np.concatenate([baseline_ann,
                                       np.ones((n_snps_used, 1), dtype=np.float32)], axis=1)

        # Convert to JAX arrays
        self.baseline_ld_sum_jax = jnp.asarray(self.processor.data_truncated['baseline_ld_sum'], dtype=jnp.float32)
        self.chisq_jax = jnp.asarray(self.processor.data_truncated['chisq'], dtype=jnp.float32)
        self.N_jax = jnp.asarray(self.processor.data_truncated['N'], dtype=jnp.float32)
        self.baseline_ann_jax = jnp.asarray(baseline_ann, dtype=jnp.float32)
        self.w_ld_jax = jnp.asarray(self.processor.data_truncated['w_ld'], dtype=jnp.float32)
        self.Nbar = self.processor.data_truncated['Nbar']

        del baseline_ann
        gc.collect()

    def _start_workers(self):
        """Start compute worker threads."""
        for i in range(self.num_workers):
            worker = threading.Thread(
                target=self._compute_worker,
                args=(i,),
                daemon=True
            )
            worker.start()
            self.workers.append(worker)
        logger.info(f"Started {self.num_workers} compute workers")

    def _compute_worker(self, worker_id: int):
        """Compute worker thread."""
        logger.debug(f"Compute worker {worker_id} started")

        while not self.stop_workers.is_set():
            try:
                # Get data from the reader
                item = self.compute_queue.get(timeout=1)
                if item is None:
                    break

                # Skip failed chunks
                if not item.get('success', False):
                    logger.error(f"Skipping chunk {item.get('chunk_idx')} due to read error")
                    continue

                # Unpack data
                chunk_idx = item['chunk_idx']
                ldscore = item['ldscore']
                spot_names = item['spot_names']
                abs_start = item['abs_start']
                abs_end = item['abs_end']

                # Track timing
                start_time = time.time()

                # Convert to JAX and process
                spatial_ld_jax = jnp.asarray(ldscore, dtype=jnp.float32)

                # Process with the batched JIT function (no batch_size needed - processes all spots at once)
                betas, ses = self.process_chunk_jit_fn(
                    self.processor.config.n_blocks,
                    spatial_ld_jax,
                    self.baseline_ld_sum_jax,
                    self.chisq_jax,
                    self.N_jax,
                    self.baseline_ann_jax,
                    self.w_ld_jax,
                    self.Nbar
                )

                # Ensure computation completes
                betas.block_until_ready()
                ses.block_until_ready()

                # Convert to numpy
                betas_np = np.array(betas)
                ses_np = np.array(ses)

                # Track throughput and statistics
                elapsed = time.time() - start_time
                n_cells_in_chunk = abs_end - abs_start

                with self.throughput_lock:
                    self.throughput.record_batch(elapsed)

                with self.stats_lock:
                    self.total_cells_processed += n_cells_in_chunk
                    self.total_chunks_processed += 1

                # Store result
                with self.results_lock:
                    self.processor._add_chunk_result(
                        chunk_idx, betas_np, ses_np, spot_names,
                        abs_start, abs_end
                    )

                # Clean up
                del spatial_ld_jax, betas, ses

            except queue.Empty:
                continue
            except Exception as e:
                error_trace = traceback.format_exc()
                logger.error(f"Compute worker {worker_id} error: {e}\nTraceback:\n{error_trace}")
                self.exception_queue.put((worker_id, e, error_trace))
                self.has_error.set()
                break

        logger.debug(f"Compute worker {worker_id} stopped")

    def get_queue_size(self):
        """Get the compute queue size."""
        return self.compute_queue.qsize()

    def get_stats(self):
        """Get processing statistics."""
        with self.stats_lock:
            return self.total_cells_processed, self.total_chunks_processed

    def check_errors(self):
        """Re-raise the first error encountered by any worker."""
        if self.has_error.is_set():
            try:
                worker_id, exception, error_trace = self.exception_queue.get_nowait()
                raise RuntimeError(f"Compute worker {worker_id} failed: {exception}\nOriginal traceback:\n{error_trace}") from exception
            except queue.Empty:
                raise RuntimeError("Compute worker failed with unknown error")

    def close(self):
        """Close the compute pool."""
        self.stop_workers.set()
        for _ in range(self.num_workers):
            self.compute_queue.put(None)
        for worker in self.workers:
            worker.join(timeout=5)
        logger.info("Compute pool closed")

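# How the two pools chain together (a sketch of what process_all_chunks below
# does; `proc` and `jit_fn` are placeholders for a SpatialLDSCProcessor and its
# JIT-compiled kernel): the reader's result_queue is handed to the computer as
# its input queue, so fetched chunks flow straight into the regression workers.
#
#     reader = ParallelLDScoreReader(processor=proc, num_workers=4)
#     computer = ParallelLDScoreComputer(
#         processor=proc,
#         process_chunk_jit_fn=jit_fn,
#         num_workers=4,
#         input_queue=reader.result_queue,
#     )
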
class SpatialLDSCProcessor:
    """
    Unified processor for spatial LDSC that combines:
    - ChunkProducer: Loading spatial LD chunks
    - ParallelChunkLoader: Managing parallel chunk loading with adjacent fetching
    - QuickModeLDScore: Handling memory-mapped marker scores
    - ResultAccumulator: Validating, merging and saving results
    """

    def __init__(self,
                 config: SpatialLDSCConfig,
                 output_dir: Path,
                 marker_score_adata: ad.AnnData,
                 snp_gene_weight_adata: ad.AnnData,
                 baseline_ld: pd.DataFrame,
                 w_ld: pd.DataFrame,
                 n_loader_threads: int = 10):
        """
        Initialize the unified processor.

        Args:
            config: Configuration object
            output_dir: Output directory for results
            marker_score_adata: AnnData object containing marker scores in .X
            snp_gene_weight_adata: AnnData object containing SNP-gene weights
            baseline_ld: Baseline LD scores common to all traits
            w_ld: Weights common to all traits
            n_loader_threads: Number of parallel loader threads
        """
        self.config = config
        self.output_dir = output_dir
        self.marker_score_adata = marker_score_adata
        self.snp_gene_weight_adata = snp_gene_weight_adata
        self.n_spots = marker_score_adata.n_obs
        self.baseline_ld = baseline_ld
        self.w_ld = w_ld
        self.n_loader_threads = n_loader_threads

        # Trait-specific state
        self.trait_name = None
        self.data_truncated = None

        self.results = []
        self.processed_chunks = set()
        self.min_spot_start = float('inf')
        self.max_spot_end = 0

        logger.info(f"Detected marker scores shape: (n_spots={self.n_spots}, n_genes={self.marker_score_adata.n_vars})")

        # Find common genes and initialize indices
        gene_names_from_adata = np.array(self.marker_score_adata.var_names)
        self.spot_names_all = np.array(self.marker_score_adata.obs_names)

        if self.snp_gene_weight_adata is None:
            raise ValueError("snp_gene_weight_adata must be provided")

        marker_score_genes_series = pd.Series(gene_names_from_adata)
        common_genes_mask = marker_score_genes_series.isin(self.snp_gene_weight_adata.var.index)
        common_genes = gene_names_from_adata[common_genes_mask]

        self.marker_score_gene_indices = np.where(common_genes_mask)[0]
        self.weight_gene_indices = [self.snp_gene_weight_adata.var.index.get_loc(g) for g in common_genes]

        logger.info(f"Found {len(common_genes)} common genes")

        # Filter by sample if specified
        if self.config.sample_filter:
            logger.info(f"Filtering spots by sample: {self.config.sample_filter}")
            sample_info = self.marker_score_adata.obs.get('sample', self.marker_score_adata.obs.get('sample_name', None))

            if sample_info is None:
                raise ValueError("No 'sample' or 'sample_name' column found in obs")

            sample_info = sample_info.to_numpy()
            self.spot_indices = np.where(sample_info == self.config.sample_filter)[0]

            # Verify spots are contiguous for efficient slicing
            expected_range = list(range(self.spot_indices[0], self.spot_indices[-1] + 1))
            if self.spot_indices.tolist() != expected_range:
                raise ValueError("Spot indices for sample must be contiguous")

            self.sample_start_offset = self.spot_indices[0]
            self.spot_names_filtered = self.spot_names_all[self.spot_indices]
            logger.info(f"Found {len(self.spot_indices)} spots for sample '{self.config.sample_filter}'")
        else:
            self.spot_indices = np.arange(self.n_spots)
            self.spot_names_filtered = self.spot_names_all
            self.sample_start_offset = 0

        self.n_spots_filtered = len(self.spot_indices)

        # Set up chunking
        self.chunk_size = self.config.spots_per_chunk_quick_mode

        # Handle cell indices range if specified
        if self.config.cell_indices_range:
            start_cell, end_cell = self.config.cell_indices_range
            # Adjust for filtered spots
            start_cell = max(0, start_cell)
            end_cell = min(end_cell, self.n_spots_filtered)
            self.chunk_starts = list(range(start_cell, end_cell, self.chunk_size))
            self.range_end = end_cell  # the last chunk must not overrun the requested range
            logger.info(f"Processing cell range [{start_cell}, {end_cell})")
            self.total_cells_to_process = end_cell - start_cell
        else:
            self.chunk_starts = list(range(0, self.n_spots_filtered, self.chunk_size))
            self.range_end = self.n_spots_filtered
            self.total_cells_to_process = self.n_spots_filtered

        self.total_chunks = len(self.chunk_starts)
        logger.info(f"Total chunks to process: {self.total_chunks}")

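    # Chunking sketch (hypothetical sizes): 25,000 filtered spots with
    # spots_per_chunk_quick_mode = 10,000 give chunk_starts == [0, 10000, 20000],
    # i.e. two full chunks plus a final chunk of 5,000 spots.
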
    def setup_trait(self, trait_name: str, sumstats_file: str):
        """
        Set up the processor for a new trait.
        Loads sumstats, prepares data, and initializes trait-specific components.
        """
        self.trait_name = trait_name
        logger.info(f"Setting up processor for trait: {trait_name}")

        # Prepare trait data
        data, common_snps = prepare_trait_data(
            self.config,
            trait_name,
            sumstats_file,
            self.baseline_ld,
            self.w_ld,
            self.snp_gene_weight_adata
        )

        # Prepare blocks
        self.data_truncated = prepare_snp_data_for_blocks(data, self.config.n_blocks)

        logger.info(f"Initializing trait-specific components for trait: {self.trait_name}")
        snp_positions = self.data_truncated.get('snp_positions', None)
        if snp_positions is None:
            raise ValueError("snp_positions not found in data_truncated")

        # Extract the SNP-gene weight matrix for this trait's SNPs
        self.snp_gene_weight_sparse = self.snp_gene_weight_adata.X[snp_positions, :][:, self.weight_gene_indices]

        if hasattr(self.snp_gene_weight_sparse, 'tocsr'):
            self.snp_gene_weight_sparse = self.snp_gene_weight_sparse.tocsr()

        logger.info(f"SNP-gene weight matrix shape: {self.snp_gene_weight_sparse.shape}")

        # Reset per-trait results
        self.results = []
        self.processed_chunks = set()
        self.min_spot_start = float('inf')
        self.max_spot_end = 0

    def _fetch_ldscore_chunk(self, chunk_index: int) -> tuple[np.ndarray, pd.Index, int, int]:
        """
        Fetch the LD score chunk for a given index.

        Returns:
            Tuple of (ldscore_array, spot_names, absolute_start, absolute_end)
        """
        if chunk_index >= len(self.chunk_starts):
            raise ValueError(f"Invalid chunk index {chunk_index}")

        start = self.chunk_starts[chunk_index]
        end = min(start + self.chunk_size, self.range_end)  # clamp to the end of the requested cell range

        # Calculate absolute positions in the memmap
        memmap_start = self.sample_start_offset + start
        memmap_end = self.sample_start_offset + end

        # Load chunk from marker_score_adata
        mk_score_chunk = self.marker_score_adata.X[memmap_start:memmap_end, self.marker_score_gene_indices]
        mk_score_chunk = mk_score_chunk.T.astype(np.float32)

        # Compute LD scores via sparse matrix multiplication
        ldscore_chunk = self.snp_gene_weight_sparse @ mk_score_chunk

        if hasattr(ldscore_chunk, 'toarray'):
            ldscore_chunk = ldscore_chunk.toarray()

        # Get spot names
        spot_names = pd.Index(self.spot_names_filtered[start:end])

        # Calculate absolute positions in the original data
        absolute_start = self.spot_indices[start] if start < len(self.spot_indices) else start
        absolute_end = self.spot_indices[end - 1] + 1 if end > 0 else absolute_start

        return ldscore_chunk.astype(np.float32, copy=False), spot_names, absolute_start, absolute_end

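    # Shape walk-through for the sparse product in _fetch_ldscore_chunk above
    # (hypothetical sizes): a (n_snps × n_genes) CSR weight matrix times the
    # transposed (n_genes × n_spots_chunk) marker-score block yields an
    # (n_snps × n_spots_chunk) spatial LD score chunk — one column per spot.
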
    def process_all_chunks(self, process_chunk_jit_fn) -> pd.DataFrame:
        """
        Process all chunks using the parallel reader-computer pipeline.

        Args:
            process_chunk_jit_fn: JIT-compiled function for processing chunks

        Returns:
            Merged DataFrame with all results
        """
        # Create the reader-computer pipeline
        reader = ParallelLDScoreReader(
            processor=self,
            num_workers=self.config.ldsc_read_workers,
        )

        computer = ParallelLDScoreComputer(
            processor=self,
            process_chunk_jit_fn=process_chunk_jit_fn,
            num_workers=self.config.ldsc_compute_workers,
            input_queue=reader.result_queue  # Connect reader output to computer input
        )

        start_time = time.time()  # defined before try so the finally block cannot hit a NameError

        try:
            # Submit all chunks to the reader
            for chunk_idx in range(self.total_chunks):
                reader.submit_chunk(chunk_idx)

            # Build description with sample name and range info
            desc_parts = [f"Processing {self.total_chunks:,} chunks ({self.total_cells_to_process:,} cells)"]

            if self.config.sample_filter:
                desc_parts.append(f"Sample: {self.config.sample_filter}")

            if self.config.cell_indices_range:
                start_cell, end_cell = self.config.cell_indices_range
                desc_parts.append(f"Range: [{start_cell:,}-{end_cell:,})")

            description = " | ".join(desc_parts)

            # Optional JAX profiling hook (disabled):
            # if getattr(self.config, 'enable_jax_profiling', False):
            #     jax.profiler.start_trace("/tmp/jax-trace-ldsc")

            # Auto-detect terminal vs redirected output.
            # When output is redirected, the rich progress bar is disabled below.
            is_terminal = sys.stdout.isatty()
            console = Console(soft_wrap=True)

            # Log interval for non-terminal mode (every 10% or at least every 60 seconds)
            log_interval_pct = 10
            last_log_pct = 0
            last_log_time = 0

            with Progress(
                SpinnerColumn(),
                TextColumn("[bold blue]{task.description}"),
                BarColumn(),
                MofNCompleteColumn(),
                TaskProgressColumn(),
                TextColumn("[bold green]{task.fields[speed]} cells/s"),
                TextColumn("[dim]R→C: {task.fields[r_to_c_queue]}"),
                TimeElapsedColumn(),
                TimeRemainingColumn(),
                console=console,
                refresh_per_second=2,
                disable=not is_terminal  # Disable rich progress bar when not in a terminal
            ) as progress:
                task = progress.add_task(
                    description,
                    total=self.total_cells_to_process,
                    speed="0",
                    r_to_c_queue="0"
                )

                last_update_time = start_time
                last_chunks_processed = 0

                # For 10 s moving average
                speed_window = deque()  # Stores (timestamp, n_cells_processed)

                while last_chunks_processed < self.total_chunks:
                    # Check for errors
                    reader.check_errors()
                    computer.check_errors()

                    # Get current stats from the computer
                    n_cells_processed, n_chunks_processed = computer.get_stats()

                    # Update progress periodically
                    current_time = time.time()
                    if current_time - last_update_time > 0.5:  # Update every 0.5 seconds
                        # Get queue sizes
                        r_pending, r_to_c = reader.get_queue_sizes()

                        # Calculate speed in cells/s (10 s moving average)
                        speed_window.append((current_time, n_cells_processed))

                        # Remove entries older than 10 s
                        while speed_window and current_time - speed_window[0][0] > 10.0:
                            speed_window.popleft()

                        if len(speed_window) > 1:
                            t_diff = speed_window[-1][0] - speed_window[0][0]
                            c_diff = speed_window[-1][1] - speed_window[0][1]
                            speed_10s = c_diff / t_diff if t_diff > 0 else 0
                        else:
                            # Fall back to cumulative speed if the window has too little data
                            elapsed_total = current_time - start_time
                            speed_10s = n_cells_processed / elapsed_total if elapsed_total > 0 else 0

                        # Update progress bar (only effective in terminal mode)
                        progress.update(
                            task,
                            completed=n_cells_processed,
                            speed=f"{speed_10s:,.0f}",
                            r_to_c_queue=f"{r_to_c}"
                        )

                        # Log progress when output is redirected
                        if not is_terminal:
                            current_pct = (n_cells_processed / self.total_cells_to_process) * 100 if self.total_cells_to_process > 0 else 0
                            time_since_last_log = current_time - last_log_time

                            # Log every 10% or every 60 seconds
                            if current_pct >= last_log_pct + log_interval_pct or time_since_last_log >= 60:
                                elapsed = current_time - start_time
                                logger.info(
                                    f"Progress: {n_cells_processed:,}/{self.total_cells_to_process:,} cells "
                                    f"({current_pct:.1f}%) | {n_chunks_processed}/{self.total_chunks} chunks | "
                                    f"Speed: {speed_10s:,.0f} cells/s | Elapsed: {elapsed:.1f}s"
                                )
                                last_log_pct = (current_pct // log_interval_pct) * log_interval_pct
                                last_log_time = current_time

                        last_update_time = current_time

                    # Update last processed count
                    if n_chunks_processed > last_chunks_processed:
                        last_chunks_processed = n_chunks_processed

                    # A periodic gc.collect() is deliberately omitted here: it would block all threads.

                    # Small sleep to prevent busy waiting
                    time.sleep(0.5)

            # if getattr(self.config, 'enable_jax_profiling', False):
            #     jax.profiler.stop_trace()  # trace saved to /tmp/jax-trace-ldsc

        finally:
            # Clean up resources
            reader.close()
            computer.close()

            # Log overall speed
            total_elapsed = time.time() - start_time
            n_cells_total, _ = computer.get_stats()
            overall_speed = n_cells_total / total_elapsed if total_elapsed > 0 else 0
            logger.info(f"Processing complete. Overall speed: {overall_speed:,.0f} cells/s")

            # Note: the marker-score memmap is NOT closed here so it can be reused
            # across multiple traits; the caller closes it after all traits are processed.

        # Validate and merge results
        return self._validate_merge_and_save()

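    # Speed readout sketch: the deque in process_all_chunks keeps
    # (timestamp, cells_done) samples from the last 10 s, so the displayed rate
    # is delta_cells / delta_t over that window — e.g. samples (t=100 s, 50,000
    # cells) and (t=110 s, 90,000 cells) report (90,000 - 50,000) / 10 = 4,000 cells/s.
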
    def _add_chunk_result(self, chunk_idx: int, betas: np.ndarray, ses: np.ndarray,
                          spot_names: pd.Index, abs_start: int, abs_end: int):
        """Add a processed chunk result to the accumulator."""
        # Update coverage tracking
        self.min_spot_start = min(self.min_spot_start, abs_start)
        self.max_spot_end = max(self.max_spot_end, abs_end)

        # Store result
        self.results.append({
            'chunk_idx': chunk_idx,
            'betas': betas,
            'ses': ses,
            'spot_names': spot_names,
            'abs_start': abs_start,
            'abs_end': abs_end
        })
        self.processed_chunks.add(chunk_idx)

    def _validate_merge_and_save(self) -> pd.DataFrame:
        """
        Validate completeness, merge results, and save with an appropriate filename.

        Returns:
            Merged DataFrame with all results
        """
        if not self.results:
            raise ValueError("No results to merge")

        # Check completeness
        expected_chunks = set(range(self.total_chunks))
        missing_chunks = expected_chunks - self.processed_chunks

        if missing_chunks:
            logger.warning(f"Missing chunks: {sorted(missing_chunks)}")
            logger.warning(f"Processed {len(self.processed_chunks)}/{self.total_chunks} chunks")

        # Sort results by chunk index
        sorted_results = sorted(self.results, key=lambda x: x['chunk_idx'])

        # Merge all results
        dfs = []
        for result in sorted_results:
            betas = result['betas'].astype(np.float64)
            ses = result['ses'].astype(np.float64)

            # Calculate statistics (one-sided p-values from the upper normal tail)
            z_scores = betas / ses
            p_values = norm.sf(z_scores)
            neg_log10_p = -np.log10(np.maximum(p_values, 1e-300))

            chunk_df = pd.DataFrame({
                'spot': result['spot_names'],
                'beta': result['betas'],
                'se': result['ses'],
                'z': z_scores.astype(np.float32),
                'p': p_values,
                'neg_log10_p': neg_log10_p
            })
            dfs.append(chunk_df)

        merged_df = pd.concat(dfs, ignore_index=True)

        # Generate filename with cell range information
        filename = self._generate_output_filename()
        output_path = self.output_dir / filename

        # Save results
        logger.info(f"Saving results to {output_path}")
        merged_df.to_csv(output_path, index=False, compression='gzip')

        # Log statistics
        self._log_statistics(merged_df, output_path)

        return merged_df

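    # Note on the p-values above: norm.sf(z) is the one-sided upper-tail
    # probability P(Z > z), so e.g. z = 1.645 gives p ≈ 0.05 and z = 0 gives
    # p = 0.5; spots with negative betas are therefore never called significant.
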
    def _generate_output_filename(self) -> str:
        """Generate the output filename, including cell range information."""
        base_name = f"{self.config.project_name}_{self.trait_name}"

        # If a cell indices range was given, include it in the filename
        if self.config.cell_indices_range:
            start_cell, end_cell = self.config.cell_indices_range
            # Adjust for the actually processed range
            actual_start = max(self.min_spot_start, start_cell)
            actual_end = min(self.max_spot_end, end_cell)
            return f"{base_name}_cells_{actual_start}_{actual_end}.csv.gz"

        # Check if we have complete coverage
        if self.min_spot_start == 0 and self.max_spot_end == self.n_spots:
            return f"{base_name}.csv.gz"

        if self.config.sample_filter:
            return f"{base_name}_{self.config.sample_filter}_start{self.min_spot_start}_end{self.max_spot_end}_total{self.n_spots}.csv.gz"

        # Partial coverage without an explicit range
        return f"{base_name}_start{self.min_spot_start}_end{self.max_spot_end}_total{self.n_spots}.csv.gz"

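    # Filename sketch for the branches above (hypothetical project, trait, and
    # sample names):
    #     explicit range  -> myproj_height_cells_0_50000.csv.gz
    #     full coverage   -> myproj_height.csv.gz
    #     sample subset   -> myproj_height_E16.5_start0_end12000_total80000.csv.gz
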
    def _log_statistics(self, df: pd.DataFrame, output_path: Path):
        """Log a statistical summary of the results."""
        n_spots = len(df)
        bonferroni_threshold = 0.05 / n_spots
        n_bonferroni_sig = (df['p'] < bonferroni_threshold).sum()

        # FDR correction (Benjamini-Hochberg)
        reject, _, _, _ = multipletests(
            df['p'], alpha=0.001, method='fdr_bh'
        )
        n_fdr_sig = reject.sum()

        logger.info("=" * 70)
        logger.info("STATISTICAL SUMMARY")
        logger.info("=" * 70)
        logger.info(f"Total spots: {n_spots:,}")
        logger.info(f"Cell range processed: [{self.min_spot_start}, {self.max_spot_end})")
        logger.info(f"Max -log10(p): {df['neg_log10_p'].max():.2f}")
        logger.info("-" * 70)
        logger.info(f"Nominally significant (p < 0.05): {(df['p'] < 0.05).sum():,}")
        logger.info(f"Bonferroni threshold: {bonferroni_threshold:.2e}")
        logger.info(f"Bonferroni significant: {n_bonferroni_sig:,}")
        logger.info(f"FDR significant (alpha=0.001): {n_fdr_sig:,}")
        logger.info("=" * 70)
        logger.info(f"Results saved to: {output_path}")

        # Warn if incomplete
        if len(self.processed_chunks) < self.total_chunks:
            logger.warning(f"Only processed {len(self.processed_chunks)}/{self.total_chunks} chunks")
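
# End-to-end usage sketch (a sketch only: the trait-to-sumstats mapping below
# is an assumed dict, and process_chunk_jit_fn is assumed to be the JIT kernel
# from spatial_ldsc_jax; neither name is confirmed by this file):
#
#     processor = SpatialLDSCProcessor(config, output_dir, marker_score_adata,
#                                      snp_gene_weight_adata, baseline_ld, w_ld)
#     for trait_name, sumstats_file in trait_to_sumstats.items():  # hypothetical dict
#         processor.setup_trait(trait_name, sumstats_file)
#         results_df = processor.process_all_chunks(process_chunk_jit_fn)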