gsMap3D 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1265 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Marker score calculation using homogeneous neighbors
|
|
3
|
+
Implements weighted geometric mean calculation in log space with JAX acceleration
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import gc
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
import queue
|
|
10
|
+
import threading
|
|
11
|
+
import time
|
|
12
|
+
import traceback
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
from functools import partial
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
|
|
17
|
+
import anndata as ad
|
|
18
|
+
import jax.numpy as jnp
|
|
19
|
+
import numpy as np
|
|
20
|
+
import pandas as pd
|
|
21
|
+
import scanpy as sc
|
|
22
|
+
from jax import jit
|
|
23
|
+
|
|
24
|
+
# Progress bar imports
|
|
25
|
+
from rich.console import Console, Group
|
|
26
|
+
from rich.panel import Panel
|
|
27
|
+
from rich.progress import (
|
|
28
|
+
BarColumn,
|
|
29
|
+
MofNCompleteColumn,
|
|
30
|
+
Progress,
|
|
31
|
+
SpinnerColumn,
|
|
32
|
+
TaskProgressColumn,
|
|
33
|
+
TextColumn,
|
|
34
|
+
TimeElapsedColumn,
|
|
35
|
+
TimeRemainingColumn,
|
|
36
|
+
)
|
|
37
|
+
from rich.table import Table
|
|
38
|
+
|
|
39
|
+
from gsMap.config import DatasetType, MarkerScoreCrossSliceStrategy
|
|
40
|
+
|
|
41
|
+
from .connectivity import ConnectivityMatrixBuilder
|
|
42
|
+
from .memmap_io import MemMapDense
|
|
43
|
+
from .row_ordering_jax import optimize_row_order_jax
|
|
44
|
+
|
|
45
|
+
logger = logging.getLogger(__name__)
|
|
46
|
+
|
|
47
|
+
from .memmap_io import ComponentThroughput, ParallelMarkerScoreWriter, ParallelRankReader
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class ParallelMarkerScoreComputer:
    """Multi-threaded computer pool for marker score calculation (reusable across cell types).

    Worker threads pull ``(batch_idx, rank_data, rank_indices, original_shape,
    batch_metadata)`` tuples from ``compute_queue``, compute marker scores with
    one of the JAX kernels defined in this module, and push
    ``(batch_idx, marker_scores, batch_cell_indices)`` tuples to ``result_queue``.

    The per-cell-type context (neighbor weights, sorted cell indices, batch
    size) must be installed with :meth:`set_batch_context` before any batch
    reaches the workers; otherwise the worker records a ``RuntimeError``.
    """

    def __init__(
        self,
        global_log_gmean: np.ndarray,
        global_expr_frac: np.ndarray,
        homogeneous_neighbors: int,
        num_workers: int = 4,
        input_queue: queue.Queue = None,
        output_queue: queue.Queue = None,
        cross_slice_strategy: str = None,
        n_slices: int = 1,
        num_homogeneous_per_slice: int = None,
        no_expression_fraction: bool = False
    ):
        """
        Initialize computer pool and start its worker threads

        Args:
            global_log_gmean: Global log geometric mean
            global_expr_frac: Global expression fraction
            homogeneous_neighbors: Number of homogeneous neighbors
            num_workers: Number of compute workers
            input_queue: Optional input queue (from reader); a bounded queue of
                size ``num_workers * 2`` is created when omitted
            output_queue: Optional output queue (to writer); a bounded queue of
                size ``num_workers * 2`` is created when omitted
            cross_slice_strategy: Strategy for 3D ('per_slice_pool' or 'hierarchical_pool')
            n_slices: Number of slices for 3D data
            num_homogeneous_per_slice: Neighbors per slice for 3D strategies
                (falls back to ``homogeneous_neighbors`` when not given)
            no_expression_fraction: Skip expression fraction filtering if True
        """
        self.num_workers = num_workers
        self.homogeneous_neighbors = homogeneous_neighbors
        self.cross_slice_strategy = cross_slice_strategy
        self.n_slices = n_slices
        self.num_homogeneous_per_slice = num_homogeneous_per_slice or homogeneous_neighbors
        self.no_expression_fraction = no_expression_fraction

        # Store global statistics as JAX arrays
        self.global_log_gmean = jnp.array(global_log_gmean)
        self.global_expr_frac = jnp.array(global_expr_frac)

        # Queues for communication. Compare against None explicitly so that a
        # caller-supplied queue object is never replaced because it is falsy.
        self.compute_queue = (
            input_queue if input_queue is not None else queue.Queue(maxsize=num_workers * 2)
        )
        self.result_queue = (
            output_queue if output_queue is not None else queue.Queue(maxsize=num_workers * 2)
        )

        # Throughput tracking
        self.throughput = ComponentThroughput()
        self.throughput_lock = threading.Lock()

        # Current processing context
        self.neighbor_weights = None
        self.cell_indices_sorted = None
        self.batch_size = None
        self.n_cells = None

        # Exception handling
        self.exception_queue = queue.Queue()
        self.has_error = threading.Event()

        # Start worker threads
        self.workers = []
        self.stop_workers = threading.Event()
        self.active_cell_type = None  # Track current cell type being processed
        self._start_workers()

    def _start_workers(self):
        """Start compute worker threads"""
        for i in range(self.num_workers):
            worker = threading.Thread(
                target=self._compute_worker,
                args=(i,),
                daemon=True
            )
            worker.start()
            self.workers.append(worker)
        logger.info(f"Started {self.num_workers} compute workers")

    def _compute_worker(self, worker_id: int):
        """Compute worker thread

        Loops until ``stop_workers`` is set or a ``None`` sentinel is received.
        Any exception is logged, pushed to ``exception_queue`` and signals all
        workers to stop; it is surfaced to the caller by :meth:`check_errors`.
        """
        logger.info(f"Compute worker {worker_id} started")

        while not self.stop_workers.is_set():
            try:
                # Get data from reader; the timeout lets the loop re-check stop_workers
                item = self.compute_queue.get(timeout=1)
                if item is None:
                    break

                # Unpack data from reader
                batch_idx, rank_data, rank_indices, original_shape, batch_metadata = item

                # Track timing
                start_time = time.time()

                # The batch context should always be provided, but check for safety
                if self.batch_size is None or self.neighbor_weights is None:
                    logger.error(f"Compute worker {worker_id}: batch context not set")
                    raise RuntimeError("Batch context must be set before processing")

                batch_start = batch_metadata['batch_start']
                batch_end = batch_metadata['batch_end']
                actual_batch_size = batch_end - batch_start

                # Verify shape with an explicit raise: an `assert` would be
                # stripped when Python runs with -O and the check would vanish.
                expected_shape = (actual_batch_size, self.homogeneous_neighbors * self.n_slices)
                if not (original_shape == expected_shape):
                    raise ValueError(
                        f"Unexpected rank data shape: {original_shape}, expected {expected_shape}"
                    )

                # Get batch-specific data
                batch_weights = self.neighbor_weights[batch_start:batch_end]
                batch_cell_indices = self.cell_indices_sorted[batch_start:batch_end]

                # Convert to JAX for efficient computation
                rank_data_jax = jnp.array(rank_data)
                rank_indices_jax = jnp.array(rank_indices)

                # Use JAX fancy indexing
                batch_ranks = rank_data_jax[rank_indices_jax]

                # Compute marker scores using appropriate strategy
                if self.cross_slice_strategy == 'hierarchical_pool':
                    # Use hierarchical pooling (per-slice marker score average) for 3D data
                    marker_scores = compute_marker_scores_3d_hierarchical_pool_jax(
                        batch_ranks,
                        batch_weights,
                        actual_batch_size,
                        self.n_slices,
                        self.num_homogeneous_per_slice,
                        self.global_log_gmean,
                        self.global_expr_frac,
                        self.no_expression_fraction
                    )
                else:
                    # Use standard computation (includes mean pooling via weights)
                    marker_scores = compute_marker_scores_jax(
                        batch_ranks,
                        batch_weights,
                        actual_batch_size,
                        self.homogeneous_neighbors * self.n_slices,
                        self.global_log_gmean,
                        self.global_expr_frac,
                        self.no_expression_fraction
                    )

                # Convert back to numpy as float16 for memory efficiency
                marker_scores_np = np.array(marker_scores, dtype=np.float16)

                # Track throughput
                elapsed = time.time() - start_time
                with self.throughput_lock:
                    self.throughput.record_batch(elapsed)

                # Put result directly to writer queue
                self.result_queue.put((batch_idx, marker_scores_np, batch_cell_indices))
                self.compute_queue.task_done()

            except queue.Empty:
                continue
            except Exception as e:
                error_trace = traceback.format_exc()
                logger.error(f"Compute worker {worker_id} error: {e}\nTraceback:\n{error_trace}")
                self.exception_queue.put((worker_id, e, error_trace))
                self.has_error.set()
                self.stop_workers.set()  # Signal all workers to stop
                break

        logger.info(f"Compute worker {worker_id} stopped")

    def set_batch_context(self, neighbor_weights: np.ndarray, cell_indices_sorted: np.ndarray,
                          batch_size: int, n_cells: int):
        """Set context for processing batches of current cell type"""
        self.neighbor_weights = jnp.asarray(neighbor_weights)  # transfer to jax array only once
        self.cell_indices_sorted = cell_indices_sorted
        self.batch_size = batch_size
        self.n_cells = n_cells

    def reset_for_cell_type(self, cell_type: str):
        """Reset for processing a new cell type

        Clears the batch context and replaces the throughput tracker.
        """
        self.active_cell_type = cell_type
        self.neighbor_weights = None
        self.cell_indices_sorted = None
        self.batch_size = None
        self.n_cells = None
        with self.throughput_lock:
            self.throughput = ComponentThroughput()
        logger.debug(f"Reset computer throughput for {cell_type}")

    def get_queue_sizes(self):
        """Get current queue sizes for progress tracking

        Returns:
            Tuple ``(compute_queue size, result_queue size)``.
        """
        return self.compute_queue.qsize(), self.result_queue.qsize()

    def check_errors(self):
        """Check if any worker encountered an error

        Raises:
            RuntimeError: If a worker has flagged an error. The original
                exception is chained as the cause when it can be retrieved.
        """
        if not self.has_error.is_set():
            return
        try:
            worker_id, exception, error_trace = self.exception_queue.get_nowait()
        except queue.Empty:
            raise RuntimeError("Compute worker failed with unknown error")
        raise RuntimeError(
            f"Compute worker {worker_id} failed: {exception}\nOriginal traceback:\n{error_trace}"
        ) from exception

    def close(self):
        """Close compute pool

        Sets the stop flag, posts ``None`` sentinels to wake idle workers, and
        joins each worker with a 5 second timeout.
        """
        self.stop_workers.set()
        for _ in range(self.num_workers):
            try:
                # Never block here: once stop_workers is set the workers stop
                # consuming, so a blocking put() on a full bounded queue would
                # hang forever. The sentinel only speeds up shutdown; workers
                # also notice the stop flag after their 1 s get() timeout.
                self.compute_queue.put_nowait(None)
            except queue.Full:
                break
        for worker in self.workers:
            worker.join(timeout=5)
        logger.info("Compute pool closed")
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
@dataclass
class PipelineStats:
    """Counters and timing used to monitor one pipeline run.

    ``start_time`` and the three ``*_throughput`` trackers are always
    (re)initialised in ``__post_init__``, regardless of any value passed to
    the constructor.
    """
    # Number of batches the run is expected to process
    total_batches: int
    # Per-stage completion counters
    completed_reads: int = 0
    completed_computes: int = 0
    completed_writes: int = 0
    # Per-stage backlog (queue sizes)
    pending_compute: int = 0
    pending_write: int = 0
    pending_read: int = 0
    # Wall-clock timestamp (time.time()) at construction
    start_time: float = 0

    # Per-component throughput trackers
    reader_throughput: ComponentThroughput = None
    computer_throughput: ComponentThroughput = None
    writer_throughput: ComponentThroughput = None

    def __post_init__(self):
        """Stamp the start time and create fresh throughput trackers."""
        self.start_time = time.time()
        self.reader_throughput = ComponentThroughput()
        self.computer_throughput = ComponentThroughput()
        self.writer_throughput = ComponentThroughput()

    @property
    def elapsed_time(self):
        """Seconds elapsed since ``start_time``."""
        now = time.time()
        return now - self.start_time

    @property
    def throughput(self):
        """Completed writes per elapsed second, or 0 when no time has elapsed."""
        return self.completed_writes / self.elapsed_time if self.elapsed_time > 0 else 0
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class MarkerScoreMessageQueue:
    """Streamlined pipeline for marker score calculation (reusable across cell types).

    Wires a rank reader, a marker score computer and a marker score writer
    together. For each cell type call :meth:`reset_for_cell_type` and then
    :meth:`start`, which submits every batch to the reader and polls the
    writer until all batches are written.
    """

    def __init__(
        self,
        reader: ParallelRankReader,
        computer: ParallelMarkerScoreComputer,
        writer: ParallelMarkerScoreWriter,
        batch_size: int,
    ):
        """
        Initialize pipeline with shared pools

        Args:
            reader: Rank reader pool
            computer: Marker score computer pool
            writer: Marker score writer pool
            batch_size: Size of each batch
        """
        self.reader = reader
        self.computer = computer
        self.writer = writer
        self.batch_size = batch_size

        # Cell type specific parameters (set via reset_for_cell_type)
        self.n_batches = None
        self.n_cells = None
        self.active_cell_type = None
        self.stats = None

    def reset_for_cell_type(self, cell_type: str, n_cells: int):
        """Reset the queue for processing a new cell type

        Args:
            cell_type: Name of the cell type
            n_cells: Total number of cells for this cell type
        """
        # Reset all components for new cell type
        self.reader.reset_for_cell_type(cell_type)
        self.computer.reset_for_cell_type(cell_type)
        self.writer.reset_for_cell_type(cell_type)

        self.active_cell_type = cell_type
        self.n_cells = n_cells
        # Ceiling division: the last batch may be partial
        self.n_batches = (n_cells + self.batch_size - 1) // self.batch_size
        self.stats = PipelineStats(total_batches=self.n_batches)
        logger.debug(f"Reset MarkerScoreMessageQueue for {cell_type}: {self.n_batches} batches, {n_cells} cells")

    def _submit_batches(
        self,
        neighbor_indices: np.ndarray,
        neighbor_weights: np.ndarray,
        cell_indices_sorted: np.ndarray
    ):
        """Submit all batches to the reader"""
        # Set context for computer before any batch can reach it
        self.computer.set_batch_context(
            neighbor_weights=neighbor_weights,
            cell_indices_sorted=cell_indices_sorted,
            batch_size=self.batch_size,
            n_cells=self.n_cells
        )

        # Submit all batches with metadata
        for batch_idx in range(self.n_batches):
            batch_start = batch_idx * self.batch_size
            batch_end = min(batch_start + self.batch_size, self.n_cells)
            batch_neighbors = neighbor_indices[batch_start:batch_end]

            # Include metadata for computer
            batch_metadata = {
                'batch_start': batch_start,
                'batch_end': batch_end,
                'batch_idx': batch_idx
            }

            self.reader.submit_batch(batch_idx, batch_neighbors, batch_metadata)

    def _update_stats(self):
        """Update pipeline statistics"""
        # Update completed counts
        self.stats.completed_writes = self.writer.get_completed_count()

        # Update queue sizes (only the first element of each pair is used here)
        read_pending, _ = self.reader.get_queue_sizes()
        self.stats.pending_read = read_pending

        compute_pending, _ = self.computer.get_queue_sizes()
        self.stats.pending_compute = compute_pending

        self.stats.pending_write = self.writer.get_queue_size()

        # Update throughput stats
        self.stats.reader_throughput = self.reader.throughput
        self.stats.computer_throughput = self.computer.throughput
        self.stats.writer_throughput = self.writer.throughput

        # Estimate completed reads and computes based on queue states
        self.stats.completed_reads = self.n_batches - read_pending
        self.stats.completed_computes = self.stats.completed_writes + self.stats.pending_write

    @staticmethod
    def _queue_color(size: int) -> str:
        """Get the rich style name used to display a queue of the given size"""
        if size == 0:
            return "dim"
        if size < 5:
            return "green"
        if size < 10:
            return "yellow"
        if size < 20:
            return "bright_yellow"
        return "red"

    def _print_summary(self, console):
        """Print the completion summary and per-component throughput tables"""
        # Snapshot the time-dependent values once so the batches/s and cells/s
        # figures printed on the same row are consistent with each other.
        pipeline_throughput = self.stats.throughput
        pipeline_cells_per_sec = pipeline_throughput * self.batch_size if pipeline_throughput > 0 else 0

        summary_table = Table(title=f"[bold green]✓ Completed {self.active_cell_type}[/bold green]", box=None)
        summary_table.add_column("Property", style="dim")
        summary_table.add_column("Value", style="green", justify="right")

        summary_table.add_row("Total Batches", str(self.n_batches))
        summary_table.add_row("Time Elapsed", f"{self.stats.elapsed_time:.2f}s")
        summary_table.add_row("Pipeline Throughput", f"{pipeline_throughput:.2f} batches/s ({pipeline_cells_per_sec:.0f} cells/s)")

        perf_table = Table(title="[bold]Component Performance (per worker)[/bold]", show_header=True, header_style="bold blue")
        perf_table.add_column("Component")
        perf_table.add_column("Throughput", justify="right")
        perf_table.add_column("Workers", justify="right")
        perf_table.add_column("Total Throughput", justify="right", style="green")

        components = (
            ("Reader", self.stats.reader_throughput, self.reader.num_workers),
            ("Computer", self.stats.computer_throughput, self.computer.num_workers),
            ("Writer", self.stats.writer_throughput, self.writer.num_workers),
        )
        for label, tracker, num_workers in components:
            per_worker = tracker.throughput
            # Total cells/s = per-worker batches/s × batch size × worker count
            cells_per_sec = per_worker * self.batch_size * num_workers if per_worker > 0 else 0
            perf_table.add_row(label, f"{per_worker:.2f} b/s", str(num_workers), f"{cells_per_sec:.0f} c/s")

        console.print(Panel(
            Group(summary_table, perf_table),
            title="Pipeline Summary",
            border_style="green"
        ))

    def start(
        self,
        neighbor_indices: np.ndarray,
        neighbor_weights: np.ndarray,
        cell_indices_sorted: np.ndarray,
        enable_profiling: bool = False
    ):
        """Run pipeline with rich progress display

        Args:
            neighbor_indices: Neighbor indices for each cell
            neighbor_weights: Weights for each neighbor
            cell_indices_sorted: Sorted cell indices
            enable_profiling: Whether to enable profiling (requires ``viztracer``)

        Raises:
            RuntimeError: If :meth:`reset_for_cell_type` has not been called.
            Exception: Any error raised by the component ``check_errors``
                calls is logged, the workers are told to stop, and the error
                is re-raised.
        """
        # Ensure the queue has been reset for this cell type
        if self.active_cell_type is None or self.stats is None:
            raise RuntimeError("MarkerScoreMessageQueue must be reset before starting. Call reset_for_cell_type first.")

        # Optional profiling
        tracer = None
        if enable_profiling:
            import viztracer
            tracer = viztracer.VizTracer(
                output_file=f"marker_score_{self.active_cell_type}_{self.n_batches}.json",
                max_stack_depth=10,
            )
            tracer.start()

        console = Console()

        try:
            # Create progress bars
            with Progress(
                SpinnerColumn(),
                TextColumn("[bold blue]{task.description}"),
                BarColumn(),
                MofNCompleteColumn(),
                TaskProgressColumn(),
                TimeElapsedColumn(),
                TimeRemainingColumn(),
                console=console,
                refresh_per_second=10
            ) as progress:

                # Add single task for pipeline
                pipeline_task = progress.add_task(
                    f"[bold]{self.active_cell_type}[/bold]", total=self.n_batches
                )

                # Submit all batches to start the pipeline
                self._submit_batches(neighbor_indices, neighbor_weights, cell_indices_sorted)

                # Monitor progress
                while self.stats.completed_writes < self.n_batches:
                    # Check for errors in any component
                    self.reader.check_errors()
                    self.computer.check_errors()
                    self.writer.check_errors()

                    self._update_stats()

                    # Update progress based on completed writes (final stage)
                    progress.update(pipeline_task, completed=self.stats.completed_writes)

                    # Color code queue sizes based on fullness
                    compute_color = self._queue_color(self.stats.pending_compute)
                    write_color = self._queue_color(self.stats.pending_write)

                    # Update description with queue information
                    progress.update(
                        pipeline_task,
                        description=(
                            f"[bold]{self.active_cell_type}[/bold] | "
                            f"Queues: [{compute_color}]R→C:{self.stats.pending_compute}[/{compute_color}] "
                            f"[{write_color}]C→W:{self.stats.pending_write}[/{write_color}]"
                        )
                    )

                    time.sleep(0.1)

                # Final update
                progress.update(pipeline_task, completed=self.n_batches)

        except Exception as e:
            logger.error(f"Pipeline failed for {self.active_cell_type}: {e}")
            # Stop all workers to prevent hanging
            self.stop()
            raise

        finally:
            # Stop profiling if enabled
            if tracer is not None:
                tracer.stop()
                tracer.save()
                logger.info(f"Profiling saved to marker_score_{self.active_cell_type}_{self.n_batches}.json")

        # No threads to wait for - processing happens in component worker threads.
        # Reached only on success, because the except branch above re-raises.
        self._print_summary(console)

        # Final garbage collection (gc is imported at module level)
        gc.collect()

        logger.info(f"✓ Completed processing {self.active_cell_type}")

    def stop(self):
        """Stop all workers in the pipeline components

        Only sets each component's ``stop_workers`` event; it does not join
        the worker threads.
        """
        logger.info("Stopping MarkerScoreMessageQueue workers...")
        self.reader.stop_workers.set()
        self.computer.stop_workers.set()
        self.writer.stop_workers.set()
        logger.info("MarkerScoreMessageQueue workers stopped")
|
|
565
|
+
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
@partial(jit, static_argnums=(2, 3, 6))
def compute_marker_scores_jax(
    log_ranks: jnp.ndarray,  # (B*N) × G matrix
    weights: jnp.ndarray,  # B × N weight matrix
    batch_size: int,
    num_neighbors: int,
    global_log_gmean: jnp.ndarray,  # G-dimensional vector
    global_expr_frac: jnp.ndarray,  # G-dimensional vector
    no_expression_fraction: bool = False  # Skip expression fraction filtering if True
) -> jnp.ndarray:
    """
    Marker scores from a weighted geometric mean of neighbor ranks (JIT-compiled).

    Steps:
        1. Row-normalize ``weights`` and take the weighted mean of the log
           ranks over each cell's neighbors.
        2. Exponentiate the difference to ``global_log_gmean``; ratios below
           1.0 are set to 0.
        3. Unless ``no_expression_fraction`` is True, zero every gene whose
           fraction of expressing neighbors (neighbors with weight > 0 only)
           does not exceed ``global_expr_frac``.
        4. Apply ``exp(score ** 1.5) - 1``.

    ``batch_size``, ``num_neighbors`` and ``no_expression_fraction`` are static
    JIT arguments.

    Returns:
        B × G marker scores as float16
    """
    gene_count = log_ranks.shape[1]

    # View the flat (B*N, G) input as (B, N, G)
    ranks_bng = jnp.reshape(log_ranks, (batch_size, num_neighbors, gene_count))

    # Weighted geometric mean, evaluated in log space
    unit_weights = weights / jnp.sum(weights, axis=1, keepdims=True)
    pooled_log_rank = jnp.einsum('bn,bng->bg', unit_weights, ranks_bng)

    # Ratio to the global geometric mean; keep only ratios >= 1
    ratio = jnp.exp(pooled_log_rank - global_log_gmean)
    score = jnp.where(ratio < 1.0, 0.0, ratio)

    if not no_expression_fraction:
        # A gene sitting at a neighbor's minimum log rank counts as non-expressed
        per_neighbor_floor = jnp.min(ranks_bng, axis=-1, keepdims=True)
        expressed = ranks_bng != per_neighbor_floor

        # Only neighbors carrying positive (normalized) weight take part
        neighbor_ok = unit_weights > 0  # (B, N)
        expressed_ok = jnp.where(neighbor_ok[:, :, None], expressed, 0)
        n_ok = jnp.sum(neighbor_ok, axis=1, keepdims=True)  # (B, 1)

        # Mean over valid neighbors; cells without any valid neighbor get 0
        local_frac = jnp.where(
            n_ok > 0,
            jnp.sum(expressed_ok.astype(jnp.float16), axis=1) / n_ok,
            0.0
        )

        score = jnp.where(local_frac > global_expr_frac, score, 0.0)

    # Final transform, returned as float16 for memory efficiency
    return (jnp.exp(score ** 1.5) - 1.0).astype(jnp.float16)
|
|
624
|
+
|
|
625
|
+
|
|
626
|
+
@partial(jit, static_argnums=(2, 3, 4, 7))
def compute_marker_scores_3d_hierarchical_pool_jax(
    log_ranks: jnp.ndarray,           # (B*N) × G matrix where N = n_slices * num_homogeneous_per_slice
    weights: jnp.ndarray,             # B × N weight matrix
    batch_size: int,
    n_slices: int,
    num_homogeneous_per_slice: int,
    global_log_gmean: jnp.ndarray,    # G-dimensional vector
    global_expr_frac: jnp.ndarray,    # G-dimensional vector
    no_expression_fraction: bool = False  # Skip expression fraction filtering if True
) -> jnp.ndarray:
    """
    JAX-accelerated marker score computation with hierarchical pooling for 3D spatial data.

    A marker score is computed independently for every slice, and the per-slice
    scores are then averaged (NaN-aware mean), ignoring slices whose weights sum
    to zero.

    Args:
        log_ranks: Flattened log ranks (batch_size * total_neighbors, n_genes),
            where total_neighbors = n_slices * num_homogeneous_per_slice
        weights: Neighbor weights (batch_size, total_neighbors); a weight of 0
            marks a neighbor as invalid
        batch_size: Number of cells in batch (static under jit)
        n_slices: Number of slices (1 + 2 * n_adjacent_slices) (static under jit)
        num_homogeneous_per_slice: Number of homogeneous neighbors per slice
            (static under jit)
        global_log_gmean: Global log geometric mean, one value per gene
        global_expr_frac: Global expression fraction, one value per gene
        no_expression_fraction: If True, the expression-fraction filter is
            skipped (static under jit)

    Returns:
        (batch_size, n_genes) float16 marker scores averaged across valid
        slices; cells for which every slice is invalid get a score of 0
    """
    n_genes = log_ranks.shape[1]

    # Split the neighbor axis into (slice, neighbor-within-slice):
    # log ranks -> (batch_size, n_slices, num_homogeneous_per_slice, n_genes)
    # weights   -> (batch_size, n_slices, num_homogeneous_per_slice)
    log_ranks_4d = log_ranks.reshape(batch_size, n_slices, num_homogeneous_per_slice, n_genes)
    weights_3d = weights.reshape(batch_size, n_slices, num_homogeneous_per_slice)

    # Normalize weights within each slice so they sum to 1. Slices whose weights
    # sum to 0 are divided by 1 instead, which avoids 0/0 and keeps them all-zero.
    weights_sum = weights_3d.sum(axis=2, keepdims=True)  # (batch_size, n_slices, 1)
    weights_normalized = weights_3d / jnp.where(weights_sum > 0, weights_sum, 1.0)

    # Weighted geometric mean in log space for each slice: (batch_size, n_slices, n_genes)
    weighted_log_mean = jnp.einsum('bsn,bsng->bsg', weights_normalized, log_ranks_4d)

    # Per-slice marker score relative to the global geometric mean; scores below 1 are zeroed
    marker_score_per_slice = jnp.exp(weighted_log_mean - global_log_gmean[None, None, :])
    marker_score_per_slice = jnp.where(marker_score_per_slice < 1.0, 0.0, marker_score_per_slice)

    # Apply expression fraction filter for each slice (only if not disabled)
    if not no_expression_fraction:
        # A gene whose log rank equals the neighbor's minimum log rank is treated as non-expressed
        min_log_rank = log_ranks_4d.min(axis=-1, keepdims=True)
        is_expressed = (log_ranks_4d != min_log_rank)

        # Valid neighbors within each slice are those with a positive weight
        valid_mask = weights_3d > 0  # (batch_size, n_slices, num_homogeneous_per_slice)

        # Drop invalid neighbors, then count the valid ones per slice
        is_expressed_masked = jnp.where(valid_mask[:, :, :, None], is_expressed, 0)
        valid_counts = valid_mask.sum(axis=2, keepdims=True)  # (batch_size, n_slices, 1)

        # Fraction of valid neighbors expressing each gene: (batch_size, n_slices, n_genes).
        # valid_counts broadcasts over the gene axis; slices without valid neighbors get 0.
        expr_frac = jnp.where(
            valid_counts > 0,
            is_expressed_masked.astype(jnp.float16).sum(axis=2) / valid_counts,
            0.0
        )

        # Keep a score only where the local fraction exceeds the global fraction
        frac_mask = expr_frac > global_expr_frac[None, None, :]
        marker_score_per_slice = jnp.where(frac_mask, marker_score_per_slice, 0.0)

    marker_score_per_slice = jnp.exp(marker_score_per_slice ** 1.5) - 1.0

    # Pool across slices with a NaN-aware mean. A slice is invalid when its weights
    # sum to zero, which could happen when:
    # 1. there are no neighbors from this slice (e.g. for the outermost slices)
    # 2. the adjacent slice does not have enough high quality cells left after filtering
    invalid_slice_mask = (weights_sum == 0).squeeze(-1)  # (batch_size, n_slices)

    # Invalid slices are set to NaN so that jnp.nanmean ignores them
    marker_score_per_slice = jnp.where(
        invalid_slice_mask[:, :, None],  # Broadcast to (batch_size, n_slices, n_genes)
        jnp.nan,
        marker_score_per_slice
    )

    # Mean across slices (axis=1), ignoring NaN values: (batch_size, n_genes)
    marker_score = jnp.nanmean(marker_score_per_slice, axis=1)

    # When every slice of a cell is invalid the mean is NaN - set those scores to 0
    marker_score = jnp.where(jnp.isnan(marker_score), 0.0, marker_score)

    # Return as float16 for memory efficiency
    return marker_score.astype(jnp.float16)
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
class MarkerScoreCalculator:
    """Main class for calculating marker scores.

    Orchestrates loading of the inputs (AnnData, rank memory map, global
    statistics), construction of homogeneous neighbors per cell type and the
    reader -> computer -> writer pipeline that fills the output memory map.
    """

    def __init__(self, config):
        """
        Initialize with configuration

        Args:
            config: LatentToGeneConfig object
        """
        self.config = config
        self.connectivity_builder = ConnectivityMatrixBuilder(config)

    def load_global_stats(self, mean_frac_path: str) -> tuple[np.ndarray, np.ndarray]:
        """Load pre-calculated global geometric mean and expression fraction from parquet

        Args:
            mean_frac_path: Path to the parquet file holding the 'G_Mean' and 'frac' columns

        Returns:
            Tuple of (global_log_gmean, global_expr_frac), both float32 arrays

        Raises:
            FileNotFoundError: If the parquet file does not exist
        """
        logger.info("Loading global statistics from parquet...")
        parquet_path = Path(mean_frac_path)

        if not parquet_path.exists():
            raise FileNotFoundError(f"Global stats file not found: {parquet_path}")

        # Load the dataframe
        mean_frac_df = pd.read_parquet(parquet_path)

        # Extract global log geometric mean and expression fraction
        global_log_gmean = mean_frac_df['G_Mean'].values.astype(np.float32)
        global_expr_frac = mean_frac_df['frac'].values.astype(np.float32)

        logger.info(f"Loaded global stats for {len(global_log_gmean)} genes")

        return global_log_gmean, global_expr_frac

    def _display_input_summary(self, adata, cell_types, n_cells, n_genes):
        """Display summary of input data and cell types

        Uses ``self.console``, which is created in ``calculate_marker_scores``.
        """
        table = Table(title="[bold cyan]Marker Score Input Summary[/bold cyan]", show_header=True, header_style="bold magenta")
        table.add_column("Property", style="dim")
        table.add_column("Value", style="green")

        table.add_row("Total Cells", str(n_cells))
        table.add_row("Total Genes", str(n_genes))
        table.add_row("Cell Types", str(len(cell_types)))
        table.add_row("Dataset Type", str(self.config.dataset_type.value))
        table.add_row("Annotation Key", str(self.config.annotation))

        self.console.print(table)

        # Display cell type breakdown
        ct_table = Table(title="[bold cyan]Cell Type Breakdown[/bold cyan]", show_header=True, header_style="bold blue")
        ct_table.add_column("Cell Type", style="dim")
        ct_table.add_column("Count", justify="right")

        ct_counts = adata.obs[self.config.annotation].value_counts()
        for ct in cell_types:
            count = ct_counts.get(ct, 0)
            # Cell types below the minimum size are highlighted in red
            style = "green" if count >= self.config.min_cells_per_type else "red"
            ct_table.add_row(ct, f"[{style}]{count}[/{style}]")

        self.console.print(ct_table)

    def _load_input_data(
        self,
        adata_path: str,
        rank_memmap_path: str,
        mean_frac_path: str
    ) -> tuple[ad.AnnData, MemMapDense, np.ndarray, np.ndarray, int, int, np.ndarray]:
        """Load input data: AnnData, rank memory map, global statistics, and high quality mask

        Returns:
            Tuple of (adata, rank_memmap, global_log_gmean, global_expr_frac, n_cells, n_genes, high_quality_mask)

        Raises:
            FileNotFoundError: If the AnnData file or the global stats file is missing
            ValueError: If high quality filtering is enabled but the 'High_quality'
                column is absent from ``adata.obs``
        """
        # Load concatenated AnnData
        logger.info(f"Loading concatenated AnnData from {adata_path}")
        if not Path(adata_path).exists():
            raise FileNotFoundError(f"Concatenated AnnData not found: {adata_path}")
        adata = sc.read_h5ad(adata_path)

        # Load pre-calculated global statistics
        global_log_gmean, global_expr_frac = self.load_global_stats(mean_frac_path)

        # Open rank memory map; its shape and dtype come from the sidecar metadata file
        rank_memmap_path = Path(rank_memmap_path)
        meta_path = rank_memmap_path.with_suffix('.meta.json')
        with open(meta_path) as f:
            meta = json.load(f)

        rank_memmap = MemMapDense(
            path=rank_memmap_path,
            shape=tuple(meta['shape']),
            dtype=np.dtype(meta['dtype']),
            mode='r',
            tmp_dir=self.config.memmap_tmp_dir
        )

        logger.info(f"Opened rank memory map from {rank_memmap_path}")
        n_cells = adata.n_obs
        n_cells_rank, n_genes = rank_memmap.shape[0], rank_memmap.shape[1]

        logger.info(f"AnnData dimensions: {n_cells} cells × {adata.n_vars} genes")
        logger.info(f"Rank MemMap dimensions: {n_cells_rank} cells × {n_genes} genes")

        # Cells should match exactly since filtering is done before rank memmap creation
        assert n_cells == n_cells_rank, \
            f"Cell count mismatch: AnnData has {n_cells} cells, Rank MemMap has {n_cells_rank} cells. " \
            f"This indicates the filtering was not applied consistently during rank calculation."

        # Load high quality mask based on configuration
        if self.config.high_quality_neighbor_filter:
            if 'High_quality' not in adata.obs.columns:
                raise ValueError("High_quality column not found in AnnData obs. Please ensure QC was applied during find_latent_representation step.")
            high_quality_mask = adata.obs['High_quality'].values.astype(bool)
            logger.info(f"Loaded high quality mask: {high_quality_mask.sum()}/{len(high_quality_mask)} cells marked as high quality")
        else:
            # Create all-True mask when high quality filtering is disabled
            high_quality_mask = np.ones(n_cells, dtype=bool)
            logger.info("High quality filtering disabled - using all cells")

        return adata, rank_memmap, global_log_gmean, global_expr_frac, n_cells, n_genes, high_quality_mask

    def _prepare_embeddings(
        self,
        adata: ad.AnnData
    ) -> tuple[np.ndarray | None, np.ndarray, np.ndarray, np.ndarray | None]:
        """Prepare and normalize embeddings based on dataset type

        Returns:
            Tuple of (coords, emb_niche, emb_indv, slice_ids). ``coords`` and
            ``slice_ids`` are None for non-spatial datasets; ``emb_niche`` is a
            dummy all-ones column when no niche embedding is available.
        """
        logger.info("Loading shared data structures...")

        coords = None
        emb_niche = None
        slice_ids = None

        # Compare on the enum's value so the check works whether the dataset type
        # is a plain Enum, a str-mixin Enum or already a string.
        dataset_type = self.config.dataset_type
        dataset_label = getattr(dataset_type, 'value', dataset_type)

        if dataset_label in ('spatial2D', 'spatial3D'):
            # Load spatial coordinates for spatial datasets
            coords = adata.obsm[self.config.spatial_key]

            # Slice IDs are required for both spatial2D and spatial3D
            assert 'slice_id' in adata.obs.columns
            slice_ids = adata.obs['slice_id'].values.astype(np.int32)

            # Try to load niche embeddings if they exist
            if self.config.latent_representation_niche in adata.obsm:
                emb_niche = adata.obsm[self.config.latent_representation_niche]

        # Load cell embeddings for all dataset types
        emb_indv = adata.obsm[self.config.latent_representation_cell].astype(np.float16)

        # If emb_niche is missing (scRNA-seq or spatial without niche),
        # create a dummy (N, 1) array of ones.
        # Normalizing a vector of 1s results in 1.0, so cosine similarity will be 1.0 everywhere.
        if emb_niche is None:
            logger.info("No niche embeddings found. Using dummy embeddings (all ones).")
            emb_niche = np.ones((emb_indv.shape[0], 1), dtype=np.float32)

        # Normalize embeddings
        logger.info("Normalizing embeddings...")

        # L2 normalize niche embeddings (always exists now)
        emb_niche_norm = np.linalg.norm(emb_niche, axis=1, keepdims=True)
        emb_niche = emb_niche / (emb_niche_norm + 1e-8)

        # L2 normalize individual embeddings. The arithmetic is done in float32:
        # in float16 the 1e-8 guard rounds to zero (so an all-zero row would give
        # NaN) and squared norms can overflow. The result is stored as float16,
        # the same dtype the embeddings are kept in.
        emb_indv_f32 = emb_indv.astype(np.float32)
        emb_indv_norm = np.linalg.norm(emb_indv_f32, axis=1, keepdims=True)
        emb_indv = (emb_indv_f32 / (emb_indv_norm + 1e-8)).astype(np.float16)

        return coords, emb_niche, emb_indv, slice_ids

    def _get_cell_types(self, adata: ad.AnnData) -> np.ndarray:
        """Get cell types from annotation key

        Returns:
            Array of unique cell types (``["all"]`` when no annotation key is configured)
        """
        annotation_key = self.config.annotation

        if annotation_key is not None:
            assert annotation_key in adata.obs.columns, f"Annotation key '{annotation_key}' not found in adata.obs"
            annotations = adata.obs[annotation_key]

            # Get unique cell types, excluding NaN values
            cell_types = annotations.dropna().unique()

            # Check if there are any NaN values and handle them
            nan_count = annotations.isna().sum()
            if nan_count > 0:
                logger.warning(f"Found {nan_count} cells with NaN annotation in '{annotation_key}', these will be skipped")
        else:
            logger.warning(f"Annotation {annotation_key} not found, processing all cells as one type")
            cell_types = ["all"]
            # NOTE(review): annotation_key is None here, so this creates an obs column
            # labelled None - confirm that this survives writing the AnnData back to disk.
            adata.obs[annotation_key] = "all"

        logger.info(f"Processing {len(cell_types)} cell types")
        return cell_types

    def _initialize_pipeline(
        self,
        rank_memmap: MemMapDense,
        output_memmap: MemMapDense,
        global_log_gmean: np.ndarray,
        global_expr_frac: np.ndarray
    ):
        """Initialize the processing pipeline with shared pools and queues

        Creates ``self.reader``, ``self.computer``, ``self.writer`` and
        ``self.marker_score_queue``; the three pools are chained through two
        bounded queues.
        """
        logger.info("Initializing shared processing pools with direct queue connections...")

        # Create shared queues to connect components with configured sizes
        reader_to_computer_queue = queue.Queue(maxsize=self.config.mkscore_compute_workers * self.config.compute_input_queue_size)
        computer_to_writer_queue = queue.Queue(maxsize=self.config.writer_queue_size)

        self.reader = ParallelRankReader(
            rank_memmap,
            num_workers=self.config.rank_read_workers,
            output_queue=reader_to_computer_queue  # Direct connection to computer
        )

        # Determine 3D strategy parameters. The number of homogeneous neighbors is
        # the same in both cases; for the pooling strategies it applies per slice.
        cross_slice_strategy = None
        n_slices = 1
        num_homogeneous_per_slice = self.config.homogeneous_neighbors

        pooling_strategies = [
            MarkerScoreCrossSliceStrategy.PER_SLICE_POOL,
            MarkerScoreCrossSliceStrategy.HIERARCHICAL_POOL
        ]
        if (self.config.dataset_type == DatasetType.SPATIAL_3D and
                self.config.cross_slice_marker_score_strategy in pooling_strategies):
            cross_slice_strategy = self.config.cross_slice_marker_score_strategy
            # Central slice plus n_adjacent_slices on each side
            n_slices = 1 + 2 * self.config.n_adjacent_slices

        self.computer = ParallelMarkerScoreComputer(
            global_log_gmean,
            global_expr_frac,
            self.config.homogeneous_neighbors,
            num_workers=self.config.mkscore_compute_workers,
            input_queue=reader_to_computer_queue,  # Input from reader
            output_queue=computer_to_writer_queue,  # Output to writer
            cross_slice_strategy=cross_slice_strategy,
            n_slices=n_slices,
            num_homogeneous_per_slice=num_homogeneous_per_slice,
            no_expression_fraction=self.config.no_expression_fraction
        )

        self.writer = ParallelMarkerScoreWriter(
            output_memmap,
            num_workers=self.config.mkscore_write_workers,
            input_queue=computer_to_writer_queue  # Input from computer
        )

        logger.info(f"Processing pools initialized: {self.config.rank_read_workers} readers, "
                   f"{self.config.mkscore_compute_workers} computers, "
                   f"{self.config.mkscore_write_workers} writers")

        self.marker_score_queue = MarkerScoreMessageQueue(
            reader=self.reader,
            computer=self.computer,
            writer=self.writer,
            batch_size=self.config.mkscore_batch_size
        )

    def _find_homogeneous_spots(
        self,
        adata: ad.AnnData,
        cell_type: str,
        annotation_key: str,
        coords: np.ndarray | None,
        emb_niche: np.ndarray,
        emb_indv: np.ndarray,
        slice_ids: np.ndarray | None,
        rank_shape: tuple[int, int],
        high_quality_mask: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray | None, np.ndarray, int] | None:
        """
        Prepare batch data for a cell type

        Builds the connectivity matrix for the cells of ``cell_type``, reorders
        the rows for cache efficiency and stores the neighbor data in ``adata``.

        Returns:
            Tuple of (neighbor_indices, cell_sims, niche_sims, cell_indices_sorted, n_cells)
            or None if cell type should be skipped (fewer cells than
            ``config.min_cells_per_type``)
        """
        # Get cells of this type, excluding NaN values
        annotations = adata.obs[annotation_key]
        cell_mask = (annotations == cell_type) & (annotations.notna())
        cell_indices = np.where(cell_mask)[0]
        n_cells = len(cell_indices)

        # Check minimum cells; under-populated cell types are skipped entirely
        min_cells = self.config.min_cells_per_type
        if n_cells < min_cells:
            logger.warning(f"Skipping {cell_type}: only {n_cells} cells (min: {min_cells})")
            return None

        logger.info(f"Processing {cell_type}: {n_cells} cells")

        # Build connectivity matrix
        logger.info("Building connectivity matrix...")
        neighbor_indices, cell_sims, niche_sims = self.connectivity_builder.build_connectivity_matrix(
            coords=coords,
            emb_niche=emb_niche,
            emb_indv=emb_indv,
            cell_mask=cell_mask,
            high_quality_mask=high_quality_mask,
            slice_ids=slice_ids,
            k_central=self.config.spatial_neighbors,
            k_adjacent=self.config.adjacent_slice_spatial_neighbors,
            n_adjacent_slices=self.config.n_adjacent_slices
        )
        gc.collect()

        # Validate neighbor indices against the number of rows in the rank matrix
        max_valid_idx = rank_shape[0] - 1
        assert neighbor_indices.max() <= max_valid_idx, \
            f"Neighbor indices exceed bounds (max: {neighbor_indices.max()}, limit: {max_valid_idx})"

        # Optimize row order using JAX implementation; only the first
        # `homogeneous_neighbors` columns are used to derive the ordering
        logger.info("Optimizing row order for cache efficiency...")
        n_homogeneous = self.config.homogeneous_neighbors
        row_order = optimize_row_order_jax(
            neighbor_indices=neighbor_indices[:, :n_homogeneous],
            cell_indices=cell_indices,
            neighbor_weights=cell_sims[:, :n_homogeneous],
        )

        # Apply the same permutation to every per-cell array
        neighbor_indices = neighbor_indices[row_order]
        cell_sims = cell_sims[row_order]
        if niche_sims is not None:
            niche_sims = niche_sims[row_order]
        cell_indices_sorted = cell_indices[row_order]

        # Save homogeneous neighbor data to adata
        has_real_niche_embedding = self.config.latent_representation_niche is not None
        self._save_homogeneous_data_to_adata(
            adata=adata,
            neighbor_indices=neighbor_indices,
            cell_sims=cell_sims,
            niche_sims=niche_sims,
            cell_indices_sorted=cell_indices_sorted,
            has_real_niche_embedding=has_real_niche_embedding
        )

        # Warn about cells with at most 5 valid (positive-similarity) homogeneous neighbors
        homo_neighbor_count = np.count_nonzero(cell_sims > 0, axis=1)
        zero_homo_neighbor_mask = (homo_neighbor_count <= 5)
        if np.any(zero_homo_neighbor_mask):
            logger.warning(f"Cell type {cell_type}: {zero_homo_neighbor_mask.sum()} cells can't find enough homogeneous neighbors")

        return neighbor_indices, cell_sims, niche_sims, cell_indices_sorted, n_cells

    def _save_homogeneous_data_to_adata(
        self,
        adata: ad.AnnData,
        neighbor_indices: np.ndarray,
        cell_sims: np.ndarray,
        niche_sims: np.ndarray | None,
        cell_indices_sorted: np.ndarray,
        has_real_niche_embedding: bool
    ):
        """
        Save homogeneous neighbor data to adata obsm and obs.

        Args:
            adata: AnnData object to save data to
            neighbor_indices: (n_cells, k) array of neighbor indices
            cell_sims: (n_cells, k) array of cell similarity scores
            niche_sims: (n_cells, k) array of niche similarity scores or None
            cell_indices_sorted: (n_cells,) array of cell indices in sorted order
            has_real_niche_embedding: Whether real niche embedding was provided
        """
        n_neighbors = neighbor_indices.shape[1]
        store_niche = has_real_niche_embedding and niche_sims is not None

        # (Re)initialize obsm matrices if they don't exist or their width does not match
        if 'gsMap_homo_indices' not in adata.obsm.keys() or (adata.obsm['gsMap_homo_indices'].shape[1] != n_neighbors):
            adata.obsm['gsMap_homo_indices'] = np.zeros((adata.n_obs, n_neighbors), dtype=neighbor_indices.dtype)
        if 'gsMap_homo_cell_sims' not in adata.obsm.keys() or (adata.obsm['gsMap_homo_cell_sims'].shape[1] != n_neighbors):
            adata.obsm['gsMap_homo_cell_sims'] = np.zeros((adata.n_obs, cell_sims.shape[1]), dtype=cell_sims.dtype)

        # Store the neighbor indices and cell_sims for this cell type
        adata.obsm['gsMap_homo_indices'][cell_indices_sorted] = neighbor_indices
        adata.obsm['gsMap_homo_cell_sims'][cell_indices_sorted] = cell_sims

        # Only store niche_sims if niche embedding was provided (not dummy)
        if store_niche:
            if 'gsMap_homo_niche_sims' not in adata.obsm.keys() or (adata.obsm['gsMap_homo_niche_sims'].shape[1] != n_neighbors):
                adata.obsm['gsMap_homo_niche_sims'] = np.zeros((adata.n_obs, niche_sims.shape[1]), dtype=niche_sims.dtype)
            adata.obsm['gsMap_homo_niche_sims'][cell_indices_sorted] = niche_sims

        # Initialize obs columns if they don't exist
        if 'gsMap_homo_neighbor_count' not in adata.obs.columns:
            adata.obs['gsMap_homo_neighbor_count'] = 0
        if 'gsMap_homo_cell_sims_median' not in adata.obs.columns:
            adata.obs['gsMap_homo_cell_sims_median'] = np.nan
        if has_real_niche_embedding and 'gsMap_homo_niche_sims_median' not in adata.obs.columns:
            adata.obs['gsMap_homo_niche_sims_median'] = np.nan

        # Calculate statistics for each cell (only consider valid neighbors where cell_sims > 0)
        valid_mask = cell_sims > 0  # (n_cells, k) boolean mask
        obs_labels = adata.obs.index[cell_indices_sorted]

        # Count of valid homogeneous neighbors
        homo_neighbor_count = valid_mask.sum(axis=1)
        adata.obs.loc[obs_labels, 'gsMap_homo_neighbor_count'] = homo_neighbor_count

        # Median of cell similarity (only for valid neighbors)
        cell_sims_masked = np.where(valid_mask, cell_sims, np.nan)
        cell_sims_median = np.nanmedian(cell_sims_masked, axis=1)
        adata.obs.loc[obs_labels, 'gsMap_homo_cell_sims_median'] = cell_sims_median

        # Median of niche similarity (only for valid neighbors, only if niche embedding provided)
        if store_niche:
            niche_sims_masked = np.where(valid_mask, niche_sims, np.nan)
            niche_sims_median = np.nanmedian(niche_sims_masked, axis=1)
            adata.obs.loc[obs_labels, 'gsMap_homo_niche_sims_median'] = niche_sims_median

    def _calculate_marker_scores_by_cell_type(
        self,
        adata: ad.AnnData,
        cell_type: str,
        coords: np.ndarray | None,
        emb_niche: np.ndarray,
        emb_indv: np.ndarray,
        annotation_key: str,
        slice_ids: np.ndarray | None = None,
        high_quality_mask: np.ndarray = None
    ):
        """Process a single cell type with shared pools

        Cell types for which ``_find_homogeneous_spots`` returns None (too few
        cells) are skipped without touching the processing queue.
        """
        # Find homogeneous spots
        spots = self._find_homogeneous_spots(
            adata, cell_type, annotation_key, coords, emb_niche,
            emb_indv, slice_ids, self.reader.shape, high_quality_mask
        )
        if spots is None:
            # Cell type was skipped (already logged by _find_homogeneous_spots)
            return

        neighbor_indices, cell_sims, _niche_sims, cell_indices_sorted, n_cells = spots

        # Reset the message queue for this cell type
        self.marker_score_queue.reset_for_cell_type(cell_type, n_cells)

        # Run the marker_score_queue to compute marker score
        self.marker_score_queue.start(
            neighbor_indices=neighbor_indices,
            neighbor_weights=cell_sims,
            cell_indices_sorted=cell_indices_sorted,
            enable_profiling=self.config.enable_profiling
        )

    def calculate_marker_scores(
        self,
        adata_path: str,
        rank_memmap_path: str,
        mean_frac_path: str,
        output_path: str | Path | None = None
    ) -> str | Path:
        """
        Main execution function for marker score calculation

        Args:
            adata_path: Path to concatenated latent adata
            rank_memmap_path: Path to rank memory map
            mean_frac_path: Path to mean expression fraction parquet
            output_path: Optional output path for marker scores; defaults to
                ``config.marker_scores_memmap_path``

        Returns:
            Path to output marker score memory map file
        """
        logger.info("Starting marker score calculation...")
        self.console = Console()

        self.console.print(Panel.fit(
            "[bold cyan]Marker Score Calculation Stage[/bold cyan]",
            subtitle="gsMap Stage 2",
            border_style="cyan"
        ))

        # Use config path if not specified
        output_path = Path(self.config.marker_scores_memmap_path if output_path is None else output_path)

        # Load all input data
        adata, rank_memmap, global_log_gmean, global_expr_frac, n_cells, n_genes, high_quality_mask = self._load_input_data(
            adata_path, rank_memmap_path, mean_frac_path
        )

        # Initialize output memory map
        output_memmap = MemMapDense(
            output_path,
            shape=(n_cells, n_genes),
            dtype=np.float16,  # Use float16 to save memory
            mode='w',
            num_write_workers=self.config.mkscore_write_workers,
            tmp_dir=self.config.memmap_tmp_dir
        )

        # Get cell types to process
        cell_types = self._get_cell_types(adata)
        annotation_key = self.config.annotation

        # Display summary
        self._display_input_summary(adata, cell_types, n_cells, n_genes)

        # Prepare embeddings
        coords, emb_niche, emb_indv, slice_ids = self._prepare_embeddings(adata)

        # Initialize processing pipeline
        self._initialize_pipeline(rank_memmap, output_memmap, global_log_gmean, global_expr_frac)

        # Process each cell type
        for cell_type in cell_types:
            self._calculate_marker_scores_by_cell_type(
                adata,
                cell_type,
                coords,
                emb_niche,
                emb_indv,
                annotation_key,
                slice_ids,
                high_quality_mask
            )

        # Save updated AnnData with neighbor matrices (overwrites the input file)
        logger.info(f"Saving updated AnnData with homogeneous spot matrices to {adata_path}")
        adata.write_h5ad(adata_path)

        # Close all shared pools after all cell types are processed
        logger.info("Closing shared processing pools...")
        self.reader.close()
        self.computer.close()
        self.writer.close()

        # Close memory maps
        rank_memmap.close()
        output_memmap.close()

        self.console.print(Panel.fit(
            "[bold green]✓ Marker score calculation complete![/bold green]",
            border_style="green"
        ))

        logger.info(f"Results saved to {output_path}")

        return str(output_path)
|