gsMap3D 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1265 @@
+ """
+ Marker score calculation using homogeneous neighbors.
+ Implements weighted geometric mean calculation in log space with JAX acceleration.
+ """
+
+ import gc
+ import json
+ import logging
+ import queue
+ import threading
+ import time
+ import traceback
+ from dataclasses import dataclass
+ from functools import partial
+ from pathlib import Path
+
+ import anndata as ad
+ import jax.numpy as jnp
+ import numpy as np
+ import pandas as pd
+ import scanpy as sc
+ from jax import jit
+
+ # Progress bar imports
+ from rich.console import Console, Group
+ from rich.panel import Panel
+ from rich.progress import (
+     BarColumn,
+     MofNCompleteColumn,
+     Progress,
+     SpinnerColumn,
+     TaskProgressColumn,
+     TextColumn,
+     TimeElapsedColumn,
+     TimeRemainingColumn,
+ )
+ from rich.table import Table
+
+ from gsMap.config import DatasetType, MarkerScoreCrossSliceStrategy
+
+ from .connectivity import ConnectivityMatrixBuilder
+ from .memmap_io import (
+     ComponentThroughput,
+     MemMapDense,
+     ParallelMarkerScoreWriter,
+     ParallelRankReader,
+ )
+ from .row_ordering_jax import optimize_row_order_jax
+
+ logger = logging.getLogger(__name__)
+
+
+ class ParallelMarkerScoreComputer:
+     """Multi-threaded computer pool for marker score calculation (reusable across cell types)."""
+
+     def __init__(
+         self,
+         global_log_gmean: np.ndarray,
+         global_expr_frac: np.ndarray,
+         homogeneous_neighbors: int,
+         num_workers: int = 4,
+         input_queue: queue.Queue | None = None,
+         output_queue: queue.Queue | None = None,
+         cross_slice_strategy: str | None = None,
+         n_slices: int = 1,
+         num_homogeneous_per_slice: int | None = None,
+         no_expression_fraction: bool = False,
+     ):
+         """
+         Initialize the computer pool.
+
+         Args:
+             global_log_gmean: Global log geometric mean
+             global_expr_frac: Global expression fraction
+             homogeneous_neighbors: Number of homogeneous neighbors
+             num_workers: Number of compute workers
+             input_queue: Optional input queue (from reader)
+             output_queue: Optional output queue (to writer)
+             cross_slice_strategy: Strategy for 3D ('per_slice_pool' or 'hierarchical_pool')
+             n_slices: Number of slices for 3D data
+             num_homogeneous_per_slice: Neighbors per slice for 3D strategies
+             no_expression_fraction: Skip expression fraction filtering if True
+         """
+         self.num_workers = num_workers
+         self.homogeneous_neighbors = homogeneous_neighbors
+         self.cross_slice_strategy = cross_slice_strategy
+         self.n_slices = n_slices
+         self.num_homogeneous_per_slice = num_homogeneous_per_slice or homogeneous_neighbors
+         self.no_expression_fraction = no_expression_fraction
+
+         # Store global statistics as JAX arrays
+         self.global_log_gmean = jnp.array(global_log_gmean)
+         self.global_expr_frac = jnp.array(global_expr_frac)
+
+         # Queues for communication
+         self.compute_queue = input_queue if input_queue else queue.Queue(maxsize=num_workers * 2)
+         self.result_queue = output_queue if output_queue else queue.Queue(maxsize=num_workers * 2)
+
+         # Throughput tracking
+         self.throughput = ComponentThroughput()
+         self.throughput_lock = threading.Lock()
+
+         # Current processing context
+         self.neighbor_weights = None
+         self.cell_indices_sorted = None
+         self.batch_size = None
+         self.n_cells = None
+
+         # Exception handling
+         self.exception_queue = queue.Queue()
+         self.has_error = threading.Event()
+
+         # Start worker threads
+         self.workers = []
+         self.stop_workers = threading.Event()
+         self.active_cell_type = None  # Track the current cell type being processed
+         self._start_workers()
+
+     def _start_workers(self):
+         """Start compute worker threads."""
+         for i in range(self.num_workers):
+             worker = threading.Thread(
+                 target=self._compute_worker,
+                 args=(i,),
+                 daemon=True,
+             )
+             worker.start()
+             self.workers.append(worker)
+         logger.info(f"Started {self.num_workers} compute workers")
+
+     def _compute_worker(self, worker_id: int):
+         """Compute worker thread."""
+         logger.info(f"Compute worker {worker_id} started")
+
+         while not self.stop_workers.is_set():
+             try:
+                 # Get data from the reader
+                 item = self.compute_queue.get(timeout=1)
+                 if item is None:
+                     break
+
+                 # Unpack data from the reader
+                 batch_idx, rank_data, rank_indices, original_shape, batch_metadata = item
+
+                 # Track timing
+                 start_time = time.time()
+
+                 # The batch context should always be set before batches arrive; check for safety
+                 if self.batch_size is None or self.neighbor_weights is None:
+                     logger.error(f"Compute worker {worker_id}: batch context not set")
+                     raise RuntimeError("Batch context must be set before processing")
+
+                 batch_start = batch_metadata['batch_start']
+                 batch_end = batch_metadata['batch_end']
+                 actual_batch_size = batch_end - batch_start
+
+                 # Verify shape
+                 assert original_shape == (actual_batch_size, self.homogeneous_neighbors * self.n_slices), \
+                     f"Unexpected rank data shape: {original_shape}, expected ({actual_batch_size}, {self.homogeneous_neighbors * self.n_slices})"
+
+                 # Get batch-specific data
+                 batch_weights = self.neighbor_weights[batch_start:batch_end]
+                 batch_cell_indices = self.cell_indices_sorted[batch_start:batch_end]
+
+                 # Convert to JAX for efficient computation
+                 rank_data_jax = jnp.array(rank_data)
+                 rank_indices_jax = jnp.array(rank_indices)
+
+                 # Use JAX fancy indexing
+                 batch_ranks = rank_data_jax[rank_indices_jax]
+
+                 # Compute marker scores using the appropriate strategy
+                 if self.cross_slice_strategy == 'hierarchical_pool':
+                     # Hierarchical pooling (per-slice marker score average) for 3D data
+                     marker_scores = compute_marker_scores_3d_hierarchical_pool_jax(
+                         batch_ranks,
+                         batch_weights,
+                         actual_batch_size,
+                         self.n_slices,
+                         self.num_homogeneous_per_slice,
+                         self.global_log_gmean,
+                         self.global_expr_frac,
+                         self.no_expression_fraction,
+                     )
+                 else:
+                     # Standard computation (includes mean pooling via weights)
+                     marker_scores = compute_marker_scores_jax(
+                         batch_ranks,
+                         batch_weights,
+                         actual_batch_size,
+                         self.homogeneous_neighbors * self.n_slices,
+                         self.global_log_gmean,
+                         self.global_expr_frac,
+                         self.no_expression_fraction,
+                     )
+
+                 # Convert back to numpy as float16 for memory efficiency
+                 marker_scores_np = np.array(marker_scores, dtype=np.float16)
+
+                 # Track throughput
+                 elapsed = time.time() - start_time
+                 with self.throughput_lock:
+                     self.throughput.record_batch(elapsed)
+
+                 # Put the result directly onto the writer queue
+                 self.result_queue.put((batch_idx, marker_scores_np, batch_cell_indices))
+                 self.compute_queue.task_done()
+
+             except queue.Empty:
+                 continue
+             except Exception as e:
+                 error_trace = traceback.format_exc()
+                 logger.error(f"Compute worker {worker_id} error: {e}\nTraceback:\n{error_trace}")
+                 self.exception_queue.put((worker_id, e, error_trace))
+                 self.has_error.set()
+                 self.stop_workers.set()  # Signal all workers to stop
+                 break
+
+         logger.info(f"Compute worker {worker_id} stopped")
+
+
219
+ def set_batch_context(self, neighbor_weights: np.ndarray, cell_indices_sorted: np.ndarray,
220
+ batch_size: int, n_cells: int):
221
+ """Set context for processing batches of current cell type"""
222
+ self.neighbor_weights = jnp.asarray(neighbor_weights) # transfer to jax array only once
223
+ self.cell_indices_sorted = cell_indices_sorted
224
+ self.batch_size = batch_size
225
+ self.n_cells = n_cells
226
+
227
+ def reset_for_cell_type(self, cell_type: str):
228
+ """Reset for processing a new cell type"""
229
+ self.active_cell_type = cell_type
230
+ self.neighbor_weights = None
231
+ self.cell_indices_sorted = None
232
+ self.batch_size = None
233
+ self.n_cells = None
234
+ with self.throughput_lock:
235
+ self.throughput = ComponentThroughput()
236
+ logger.debug(f"Reset computer throughput for {cell_type}")
237
+
238
+ def get_queue_sizes(self):
239
+ """Get current queue sizes for progress tracking"""
240
+ return self.compute_queue.qsize(), self.result_queue.qsize()
241
+
242
+ def check_errors(self):
243
+ """Check if any worker encountered an error"""
244
+ if self.has_error.is_set():
245
+ try:
246
+ worker_id, exception, error_trace = self.exception_queue.get_nowait()
247
+ raise RuntimeError(f"Compute worker {worker_id} failed: {exception}\nOriginal traceback:\n{error_trace}") from exception
248
+ except queue.Empty:
249
+ raise RuntimeError("Compute worker failed with unknown error")
250
+
251
+ def close(self):
252
+ """Close compute pool"""
253
+ self.stop_workers.set()
254
+ for _ in range(self.num_workers):
255
+ self.compute_queue.put(None)
256
+ for worker in self.workers:
257
+ worker.join(timeout=5)
258
+ logger.info("Compute pool closed")
259
+
260
+
261
+
262
+
263
+
264
+ @dataclass
+ class PipelineStats:
+     """Statistics for pipeline monitoring."""
+     total_batches: int
+     completed_reads: int = 0
+     completed_computes: int = 0
+     completed_writes: int = 0
+     pending_compute: int = 0
+     pending_write: int = 0
+     pending_read: int = 0
+     start_time: float = 0
+
+     # Component throughput tracking
+     reader_throughput: ComponentThroughput | None = None
+     computer_throughput: ComponentThroughput | None = None
+     writer_throughput: ComponentThroughput | None = None
+
+     def __post_init__(self):
+         self.start_time = time.time()
+         self.reader_throughput = ComponentThroughput()
+         self.computer_throughput = ComponentThroughput()
+         self.writer_throughput = ComponentThroughput()
+
+     @property
+     def elapsed_time(self):
+         return time.time() - self.start_time
+
+     @property
+     def throughput(self):
+         if self.elapsed_time > 0:
+             return self.completed_writes / self.elapsed_time
+         return 0
+
+
+ class MarkerScoreMessageQueue:
+     """Streamlined pipeline for marker score calculation (reusable across cell types)."""
+
+     def __init__(
+         self,
+         reader: ParallelRankReader,
+         computer: ParallelMarkerScoreComputer,
+         writer: ParallelMarkerScoreWriter,
+         batch_size: int,
+     ):
+         """
+         Initialize the pipeline with shared pools.
+
+         Args:
+             reader: Rank reader pool
+             computer: Marker score computer pool
+             writer: Marker score writer pool
+             batch_size: Size of each batch
+         """
+         self.reader = reader
+         self.computer = computer
+         self.writer = writer
+         self.batch_size = batch_size
+
+         # Cell-type-specific parameters (set via reset_for_cell_type)
+         self.n_batches = None
+         self.n_cells = None
+         self.active_cell_type = None
+         self.stats = None
+
+     def reset_for_cell_type(self, cell_type: str, n_cells: int):
+         """Reset the queue for processing a new cell type.
+
+         Args:
+             cell_type: Name of the cell type
+             n_cells: Total number of cells for this cell type
+         """
+         # Reset all components for the new cell type
+         self.reader.reset_for_cell_type(cell_type)
+         self.computer.reset_for_cell_type(cell_type)
+         self.writer.reset_for_cell_type(cell_type)
+
+         self.active_cell_type = cell_type
+         self.n_cells = n_cells
+         self.n_batches = (n_cells + self.batch_size - 1) // self.batch_size
+         self.stats = PipelineStats(total_batches=self.n_batches)
+         logger.debug(f"Reset MarkerScoreMessageQueue for {cell_type}: {self.n_batches} batches, {n_cells} cells")
+
+     def _submit_batches(
+         self,
+         neighbor_indices: np.ndarray,
+         neighbor_weights: np.ndarray,
+         cell_indices_sorted: np.ndarray,
+     ):
+         """Submit all batches to the reader."""
+         # Set the context for the computer
+         self.computer.set_batch_context(
+             neighbor_weights=neighbor_weights,
+             cell_indices_sorted=cell_indices_sorted,
+             batch_size=self.batch_size,
+             n_cells=self.n_cells,
+         )
+
+         # Submit all batches with metadata
+         for batch_idx in range(self.n_batches):
+             batch_start = batch_idx * self.batch_size
+             batch_end = min(batch_start + self.batch_size, self.n_cells)
+             batch_neighbors = neighbor_indices[batch_start:batch_end]
+
+             # Include metadata for the computer
+             batch_metadata = {
+                 'batch_start': batch_start,
+                 'batch_end': batch_end,
+                 'batch_idx': batch_idx,
+             }
+
+             self.reader.submit_batch(batch_idx, batch_neighbors, batch_metadata)
+
+     def _update_stats(self):
+         """Update pipeline statistics."""
+         # Update completed counts
+         self.stats.completed_writes = self.writer.get_completed_count()
+
+         # Update queue sizes
+         read_pending, read_ready = self.reader.get_queue_sizes()
+         self.stats.pending_read = read_pending
+
+         compute_pending, compute_ready = self.computer.get_queue_sizes()
+         self.stats.pending_compute = compute_pending
+
+         self.stats.pending_write = self.writer.get_queue_size()
+
+         # Update throughput stats
+         self.stats.reader_throughput = self.reader.throughput
+         self.stats.computer_throughput = self.computer.throughput
+         self.stats.writer_throughput = self.writer.throughput
+
+         # Estimate completed reads and computes from the queue states
+         self.stats.completed_reads = self.n_batches - read_pending
+         self.stats.completed_computes = self.stats.completed_writes + self.stats.pending_write
+
+     def start(
+         self,
+         neighbor_indices: np.ndarray,
+         neighbor_weights: np.ndarray,
+         cell_indices_sorted: np.ndarray,
+         enable_profiling: bool = False,
+     ):
+         """Run the pipeline with a rich progress display.
+
+         Args:
+             neighbor_indices: Neighbor indices for each cell
+             neighbor_weights: Weights for each neighbor
+             cell_indices_sorted: Sorted cell indices
+             enable_profiling: Whether to enable profiling
+         """
+         # Ensure the queue has been reset for this cell type
+         if self.active_cell_type is None or self.stats is None:
+             raise RuntimeError("MarkerScoreMessageQueue must be reset before starting. Call reset_for_cell_type first.")
+
+         # Optional profiling
+         tracer = None
+         if enable_profiling:
+             import viztracer
+             tracer = viztracer.VizTracer(
+                 output_file=f"marker_score_{self.active_cell_type}_{self.n_batches}.json",
+                 max_stack_depth=10,
+             )
+             tracer.start()
+
+         console = Console()
+
+         # Define the queue color mapping based on queue size
+         def get_queue_color(size: int) -> str:
+             """Get a display color based on queue size."""
+             if size == 0:
+                 return "dim"
+             elif size < 5:
+                 return "green"
+             elif size < 10:
+                 return "yellow"
+             elif size < 20:
+                 return "bright_yellow"
+             else:
+                 return "red"
+
+         try:
+             # Create progress bars
+             with Progress(
+                 SpinnerColumn(),
+                 TextColumn("[bold blue]{task.description}"),
+                 BarColumn(),
+                 MofNCompleteColumn(),
+                 TaskProgressColumn(),
+                 TimeElapsedColumn(),
+                 TimeRemainingColumn(),
+                 console=console,
+                 refresh_per_second=10,
+             ) as progress:
+
+                 # Add a single task for the pipeline
+                 pipeline_task = progress.add_task(
+                     f"[bold]{self.active_cell_type}[/bold]", total=self.n_batches
+                 )
+
+                 # Submit all batches to start the pipeline
+                 self._submit_batches(neighbor_indices, neighbor_weights, cell_indices_sorted)
+
+                 # Monitor progress
+                 while self.stats.completed_writes < self.n_batches:
+                     # Check for errors in any component
+                     self.reader.check_errors()
+                     self.computer.check_errors()
+                     self.writer.check_errors()
+
+                     self._update_stats()
+
+                     # Update progress based on completed writes (the final stage)
+                     progress.update(pipeline_task, completed=self.stats.completed_writes)
+
+                     # Color-code queue sizes based on fullness
+                     compute_color = get_queue_color(self.stats.pending_compute)
+                     write_color = get_queue_color(self.stats.pending_write)
+
+                     # Update the description with queue information
+                     progress.update(
+                         pipeline_task,
+                         description=(
+                             f"[bold]{self.active_cell_type}[/bold] | "
+                             f"Queues: [{compute_color}]R→C:{self.stats.pending_compute}[/{compute_color}] "
+                             f"[{write_color}]C→W:{self.stats.pending_write}[/{write_color}]"
+                         ),
+                     )
+
+                     time.sleep(0.1)
+
+                 # Final update
+                 progress.update(pipeline_task, completed=self.n_batches)
+
+         except Exception as e:
+             logger.error(f"Pipeline failed for {self.active_cell_type}: {e}")
+             # Stop all workers to prevent hanging
+             self.stop()
+             raise
+
+         finally:
+             # Stop profiling if enabled
+             if tracer is not None:
+                 tracer.stop()
+                 tracer.save()
+                 logger.info(f"Profiling saved to marker_score_{self.active_cell_type}_{self.n_batches}.json")
+
+         # No threads to wait for - processing happens in the component worker threads
+
+         # Print a summary with component throughputs.
+         # Calculate cells per second for each component and for the pipeline as a whole
+         pipeline_cells_per_sec = self.stats.throughput * self.batch_size if self.stats.throughput > 0 else 0
+         reader_cells_per_sec = self.stats.reader_throughput.throughput * self.batch_size * self.reader.num_workers if self.stats.reader_throughput.throughput > 0 else 0
+         computer_cells_per_sec = self.stats.computer_throughput.throughput * self.batch_size * self.computer.num_workers if self.stats.computer_throughput.throughput > 0 else 0
+         writer_cells_per_sec = self.stats.writer_throughput.throughput * self.batch_size * self.writer.num_workers if self.stats.writer_throughput.throughput > 0 else 0
+
+         # Create the summary table
+         summary_table = Table(title=f"[bold green]✓ Completed {self.active_cell_type}[/bold green]", box=None)
+         summary_table.add_column("Property", style="dim")
+         summary_table.add_column("Value", style="green", justify="right")
+
+         summary_table.add_row("Total Batches", str(self.n_batches))
+         summary_table.add_row("Time Elapsed", f"{self.stats.elapsed_time:.2f}s")
+         summary_table.add_row("Pipeline Throughput", f"{self.stats.throughput:.2f} batches/s ({pipeline_cells_per_sec:.0f} cells/s)")
+
+         perf_table = Table(title="[bold]Component Performance (per worker)[/bold]", show_header=True, header_style="bold blue")
+         perf_table.add_column("Component")
+         perf_table.add_column("Throughput", justify="right")
+         perf_table.add_column("Workers", justify="right")
+         perf_table.add_column("Total Throughput", justify="right", style="green")
+
+         perf_table.add_row("Reader", f"{self.stats.reader_throughput.throughput:.2f} b/s", str(self.reader.num_workers), f"{reader_cells_per_sec:.0f} c/s")
+         perf_table.add_row("Computer", f"{self.stats.computer_throughput.throughput:.2f} b/s", str(self.computer.num_workers), f"{computer_cells_per_sec:.0f} c/s")
+         perf_table.add_row("Writer", f"{self.stats.writer_throughput.throughput:.2f} b/s", str(self.writer.num_workers), f"{writer_cells_per_sec:.0f} c/s")
+
+         console.print(Panel(
+             Group(summary_table, perf_table),
+             title="Pipeline Summary",
+             border_style="green",
+         ))
+
+         # Final garbage collection (gc is imported at module level)
+         gc.collect()
+
+         logger.info(f"✓ Completed processing {self.active_cell_type}")
+
+     def stop(self):
+         """Stop all workers in the pipeline components."""
+         logger.info("Stopping MarkerScoreMessageQueue workers...")
+         self.reader.stop_workers.set()
+         self.computer.stop_workers.set()
+         self.writer.stop_workers.set()
+         logger.info("MarkerScoreMessageQueue workers stopped")
+
+
+ @partial(jit, static_argnums=(2, 3, 6))
+ def compute_marker_scores_jax(
+     log_ranks: jnp.ndarray,  # (B*N) × G matrix
+     weights: jnp.ndarray,  # B × N weight matrix
+     batch_size: int,
+     num_neighbors: int,
+     global_log_gmean: jnp.ndarray,  # G-dimensional vector
+     global_expr_frac: jnp.ndarray,  # G-dimensional vector
+     no_expression_fraction: bool = False,  # Skip expression fraction filtering if True
+ ) -> jnp.ndarray:
+     """
+     JAX-accelerated marker score computation.
+
+     Returns:
+         B × G marker scores
+     """
+     n_genes = log_ranks.shape[1]
+
+     # Reshape to batch format
+     log_ranks_3d = log_ranks.reshape(batch_size, num_neighbors, n_genes)
+
+     # Compute the weighted geometric mean in log space
+     weights = weights / weights.sum(axis=1, keepdims=True)  # Normalize weights
+     weighted_log_mean = jnp.einsum('bn,bng->bg', weights, log_ranks_3d)
+
+     # Calculate the marker score
+     marker_score = jnp.exp(weighted_log_mean - global_log_gmean)
+     marker_score = jnp.where(marker_score < 1.0, 0.0, marker_score)
+
+     # Apply the expression fraction filter (only if not disabled)
+     if not no_expression_fraction:
+         # Compute the expression fraction (mean of is_expressed across neighbors);
+         # treat the minimum log rank as non-expressed
+         is_expressed = (log_ranks_3d != log_ranks_3d.min(axis=-1, keepdims=True))
+
+         # Mask for valid neighbors (where weights > 0)
+         valid_mask = weights > 0  # Shape: (batch_size, num_neighbors)
+
+         # Apply the mask and compute the mean only over valid neighbors
+         is_expressed_masked = jnp.where(valid_mask[:, :, None], is_expressed, 0)
+         valid_counts = valid_mask.sum(axis=1, keepdims=True)  # Count of valid neighbors per cell
+
+         # Compute the mean only over valid neighbors (avoid division by zero)
+         expr_frac = jnp.where(
+             valid_counts > 0,
+             is_expressed_masked.astype(jnp.float16).sum(axis=1) / valid_counts,
+             0.0,
+         )
+
+         frac_mask = expr_frac > global_expr_frac
+         marker_score = jnp.where(frac_mask, marker_score, 0.0)
+
+     marker_score = jnp.exp(marker_score ** 1.5) - 1.0
+
+     # Return as float16 for memory efficiency
+     return marker_score.astype(jnp.float16)
+
+
+ @partial(jit, static_argnums=(2, 3, 4, 7))
+ def compute_marker_scores_3d_hierarchical_pool_jax(
+     log_ranks: jnp.ndarray,  # (B*N) × G matrix where N = n_slices * num_homogeneous_per_slice
+     weights: jnp.ndarray,  # B × N weight matrix
+     batch_size: int,
+     n_slices: int,
+     num_homogeneous_per_slice: int,
+     global_log_gmean: jnp.ndarray,  # G-dimensional vector
+     global_expr_frac: jnp.ndarray,  # G-dimensional vector
+     no_expression_fraction: bool = False,  # Skip expression fraction filtering if True
+ ) -> jnp.ndarray:
+     """
+     JAX-accelerated marker score computation with hierarchical pooling for 3D spatial data.
+     Computes marker scores independently for each slice and takes the average.
+
+     Args:
+         log_ranks: Flattened log ranks (batch_size * total_neighbors, n_genes)
+         weights: Flattened weights (batch_size, total_neighbors)
+         batch_size: Number of cells in the batch
+         n_slices: Number of slices (1 + 2 * n_adjacent_slices)
+         num_homogeneous_per_slice: Number of homogeneous neighbors per slice
+         global_log_gmean: Global log geometric mean
+         global_expr_frac: Global expression fraction
+         no_expression_fraction: Skip expression fraction filtering if True
+
+     Returns:
+         (batch_size, n_genes) marker scores averaged across slices
+     """
+     n_genes = log_ranks.shape[1]
+
+     # Reshape to separate slices: (batch_size, n_slices, num_homogeneous_per_slice, n_genes)
+     log_ranks_4d = log_ranks.reshape(batch_size, n_slices, num_homogeneous_per_slice, n_genes)
+
+     # Reshape weights: (batch_size, n_slices, num_homogeneous_per_slice)
+     weights_3d = weights.reshape(batch_size, n_slices, num_homogeneous_per_slice)
+
+     # Normalize weights within each slice (sum to 1 along the num_homogeneous_per_slice axis)
+     weights_sum = weights_3d.sum(axis=2, keepdims=True)  # Shape: (batch_size, n_slices, 1)
+     weights_normalized = weights_3d / jnp.where(weights_sum > 0, weights_sum, 1.0)
+
+     # Compute the weighted geometric mean in log space for each slice
+     # Result: (batch_size, n_slices, n_genes)
+     weighted_log_mean = jnp.einsum('bsn,bsng->bsg', weights_normalized, log_ranks_4d)
+
+     # Calculate the marker score for each slice
+     marker_score_per_slice = jnp.exp(weighted_log_mean - global_log_gmean[None, None, :])
+     marker_score_per_slice = jnp.where(marker_score_per_slice < 1.0, 0.0, marker_score_per_slice)
+
+     # Apply the expression fraction filter for each slice (only if not disabled)
+     if not no_expression_fraction:
+         # Compute the expression fraction for each slice;
+         # treat the minimum log rank as non-expressed
+         min_log_rank = log_ranks_4d.min(axis=-1, keepdims=True)
+         is_expressed = (log_ranks_4d != min_log_rank)
+
+         # Mask for valid neighbors within each slice (where weights > 0)
+         valid_mask = weights_3d > 0  # Shape: (batch_size, n_slices, num_homogeneous_per_slice)
+
+         # Apply the mask and compute the mean only over valid neighbors within each slice
+         is_expressed_masked = jnp.where(valid_mask[:, :, :, None], is_expressed, 0)
+         valid_counts = valid_mask.sum(axis=2, keepdims=True)  # Count of valid neighbors per slice
+
+         # Compute the mean only over valid neighbors (avoid division by zero)
+         # Result: (batch_size, n_slices, n_genes);
+         # valid_counts has shape (batch_size, n_slices, 1) and broadcasts across genes
+         expr_frac = jnp.where(
+             valid_counts > 0,
+             is_expressed_masked.astype(jnp.float16).sum(axis=2) / valid_counts,
+             0.0,
+         )
+
+         frac_mask = expr_frac > global_expr_frac[None, None, :]
+         marker_score_per_slice = jnp.where(frac_mask, marker_score_per_slice, 0.0)
+
+     marker_score_per_slice = jnp.exp(marker_score_per_slice ** 1.5) - 1.0
+
+     # Average across slices (instead of max pooling) to reduce noise
+     # Result: (batch_size, n_genes)
+
+     # Get the invalid slice mask before averaging. A slice can be invalid because:
+     # 1. the outermost slices have no neighbors from this slice, or
+     # 2. an adjacent slice lacked enough high-quality cells and was filtered out
+     invalid_slice_mask = (weights_sum == 0).squeeze(-1)  # Shape: (batch_size, n_slices)
+
+     # Set invalid slice scores to NaN so they are ignored by the NaN-aware mean
+     marker_score_per_slice = jnp.where(
+         invalid_slice_mask[:, :, None],  # Broadcast to (batch_size, n_slices, n_genes)
+         jnp.nan,
+         marker_score_per_slice,
+     )
+
+     # Average across slices (axis=1), ignoring NaN values
+     marker_score = jnp.nanmean(marker_score_per_slice, axis=1)
+     # marker_score = jnp.nanmedian(marker_score_per_slice, axis=1)
+
+     # Handle cases where all slices are invalid (all NaN) - set to 0
+     marker_score = jnp.where(jnp.isnan(marker_score), 0.0, marker_score)
+
+     # Return as float16 for memory efficiency
+     return marker_score.astype(jnp.float16)
+
+
+ class MarkerScoreCalculator:
+     """Main class for calculating marker scores."""
+
+     def __init__(self, config):
+         """
+         Initialize with a configuration.
+
+         Args:
+             config: LatentToGeneConfig object
+         """
+         self.config = config
+         self.connectivity_builder = ConnectivityMatrixBuilder(config)
+
+     def load_global_stats(self, mean_frac_path: str) -> tuple[np.ndarray, np.ndarray]:
+         """Load the pre-calculated global geometric mean and expression fraction from parquet."""
+         logger.info("Loading global statistics from parquet...")
+         parquet_path = Path(mean_frac_path)
+
+         if not parquet_path.exists():
+             raise FileNotFoundError(f"Global stats file not found: {parquet_path}")
+
+         # Load the dataframe
+         mean_frac_df = pd.read_parquet(parquet_path)
+
+         # Extract the global log geometric mean and expression fraction
+         global_log_gmean = mean_frac_df['G_Mean'].values.astype(np.float32)
+         global_expr_frac = mean_frac_df['frac'].values.astype(np.float32)
+
+         logger.info(f"Loaded global stats for {len(global_log_gmean)} genes")
+
+         return global_log_gmean, global_expr_frac
+
+     def _display_input_summary(self, adata, cell_types, n_cells, n_genes):
+         """Display a summary of the input data and cell types."""
+         table = Table(title="[bold cyan]Marker Score Input Summary[/bold cyan]", show_header=True, header_style="bold magenta")
+         table.add_column("Property", style="dim")
+         table.add_column("Value", style="green")
+
+         table.add_row("Total Cells", str(n_cells))
+         table.add_row("Total Genes", str(n_genes))
+         table.add_row("Cell Types", str(len(cell_types)))
+         table.add_row("Dataset Type", str(self.config.dataset_type.value))
+         table.add_row("Annotation Key", str(self.config.annotation))
+
+         self.console.print(table)
+
+         # Display the cell type breakdown
+         ct_table = Table(title="[bold cyan]Cell Type Breakdown[/bold cyan]", show_header=True, header_style="bold blue")
+         ct_table.add_column("Cell Type", style="dim")
+         ct_table.add_column("Count", justify="right")
+
+         ct_counts = adata.obs[self.config.annotation].value_counts()
+         for ct in cell_types:
+             count = ct_counts.get(ct, 0)
+             style = "green" if count >= self.config.min_cells_per_type else "red"
+             ct_table.add_row(ct, f"[{style}]{count}[/{style}]")
+
+         self.console.print(ct_table)
+
+     def _load_input_data(
+         self,
+         adata_path: str,
+         rank_memmap_path: str,
+         mean_frac_path: str,
+     ) -> tuple[ad.AnnData, MemMapDense, np.ndarray, np.ndarray, int, int, np.ndarray]:
+         """Load the input data: AnnData, rank memory map, global statistics, and high-quality mask.
+
+         Returns:
+             Tuple of (adata, rank_memmap, global_log_gmean, global_expr_frac, n_cells, n_genes, high_quality_mask)
+         """
+         # Load the concatenated AnnData
+         logger.info(f"Loading concatenated AnnData from {adata_path}")
+         if not Path(adata_path).exists():
+             raise FileNotFoundError(f"Concatenated AnnData not found: {adata_path}")
+         adata = sc.read_h5ad(adata_path)
+
+         # Load the pre-calculated global statistics
+         global_log_gmean, global_expr_frac = self.load_global_stats(mean_frac_path)
+
+         # Open the rank memory map and get its dimensions
+         rank_memmap_path = Path(rank_memmap_path)
+         meta_path = rank_memmap_path.with_suffix('.meta.json')
+         with open(meta_path) as f:
+             meta = json.load(f)
+
+         rank_memmap = MemMapDense(
+             path=rank_memmap_path,
+             shape=tuple(meta['shape']),
+             dtype=np.dtype(meta['dtype']),
+             mode='r',
+             tmp_dir=self.config.memmap_tmp_dir,
+         )
+
+         logger.info(f"Opened rank memory map from {rank_memmap_path}")
+         n_cells = adata.n_obs
+         n_cells_rank = rank_memmap.shape[0]
+         n_genes = rank_memmap.shape[1]
+
+         logger.info(f"AnnData dimensions: {n_cells} cells × {adata.n_vars} genes")
+         logger.info(f"Rank MemMap dimensions: {n_cells_rank} cells × {n_genes} genes")
+
+         # Cell counts should match exactly, since filtering is done before rank memmap creation
+         assert n_cells == n_cells_rank, \
+             f"Cell count mismatch: AnnData has {n_cells} cells, Rank MemMap has {n_cells_rank} cells. " \
+             f"This indicates the filtering was not applied consistently during rank calculation."
+
+         # Load the high-quality mask based on the configuration
+         if self.config.high_quality_neighbor_filter:
+             if 'High_quality' not in adata.obs.columns:
+                 raise ValueError("High_quality column not found in AnnData obs. Please ensure QC was applied during the find_latent_representation step.")
+             high_quality_mask = adata.obs['High_quality'].values.astype(bool)
+             logger.info(f"Loaded high quality mask: {high_quality_mask.sum()}/{len(high_quality_mask)} cells marked as high quality")
+         else:
+             # Create an all-True mask when high-quality filtering is disabled
+             high_quality_mask = np.ones(n_cells, dtype=bool)
+             logger.info("High quality filtering disabled - using all cells")
+
+         return adata, rank_memmap, global_log_gmean, global_expr_frac, n_cells, n_genes, high_quality_mask
+
+     def _prepare_embeddings(
+         self,
+         adata: ad.AnnData,
+     ) -> tuple[np.ndarray | None, np.ndarray, np.ndarray, np.ndarray | None]:
+         """Prepare and normalize embeddings based on the dataset type.
+
+         Returns:
+             Tuple of (coords, emb_niche, emb_indv, slice_ids)
+         """
+         logger.info("Loading shared data structures...")
+
+         coords = None
+         emb_niche = None
+         slice_ids = None
+
+         if self.config.dataset_type in ['spatial2D', 'spatial3D']:
+             # Load spatial coordinates for spatial datasets
+             coords = adata.obsm[self.config.spatial_key]
+
+             # Load slice IDs (required for both spatial2D and spatial3D)
+             assert 'slice_id' in adata.obs.columns
+             slice_ids = adata.obs['slice_id'].values.astype(np.int32)
+
+             # Load niche embeddings if they exist
+             if self.config.latent_representation_niche in adata.obsm:
+                 emb_niche = adata.obsm[self.config.latent_representation_niche]
+
+         # Load cell embeddings for all dataset types
+         emb_indv = adata.obsm[self.config.latent_representation_cell].astype(np.float16)
+
+         # If emb_niche is missing (scRNA-seq, or spatial data without a niche embedding),
+         # create a dummy (N, 1) array of ones. Normalizing a vector of ones yields 1.0,
+         # so the niche cosine similarity is 1.0 everywhere and has no effect.
+         if emb_niche is None:
+             logger.info("No niche embeddings found. Using dummy embeddings (all ones).")
+             emb_niche = np.ones((emb_indv.shape[0], 1), dtype=np.float32)
+
+         # Normalize embeddings
+         logger.info("Normalizing embeddings...")
+
+         # L2-normalize niche embeddings (always present at this point)
+         emb_niche_norm = np.linalg.norm(emb_niche, axis=1, keepdims=True)
+         emb_niche = emb_niche / (emb_niche_norm + 1e-8)
+
+         # L2-normalize individual embeddings
+         emb_indv_norm = np.linalg.norm(emb_indv, axis=1, keepdims=True)
+         emb_indv = emb_indv / (emb_indv_norm + 1e-8)
+
+         return coords, emb_niche, emb_indv, slice_ids
+
+     def _get_cell_types(self, adata: ad.AnnData) -> np.ndarray:
+         """Get the cell types from the annotation key.
+
+         Returns:
+             Array of unique cell types
+         """
+         annotation_key = self.config.annotation
+
+         if annotation_key is not None:
+             assert annotation_key in adata.obs.columns, f"Annotation key '{annotation_key}' not found in adata.obs"
+             # Get the unique cell types, excluding NaN values
+             cell_types = adata.obs[annotation_key].dropna().unique()
+
+             # Check for NaN values and warn about them
+             nan_count = adata.obs[annotation_key].isna().sum()
+             if nan_count > 0:
+                 logger.warning(f"Found {nan_count} cells with NaN annotation in '{annotation_key}', these will be skipped")
+         else:
+             logger.warning("No annotation key provided, processing all cells as one type")
+             cell_types = ["all"]
+             adata.obs[annotation_key] = "all"
+
+         logger.info(f"Processing {len(cell_types)} cell types")
+         return cell_types
+
+     def _initialize_pipeline(
+         self,
+         rank_memmap: MemMapDense,
+         output_memmap: MemMapDense,
+         global_log_gmean: np.ndarray,
+         global_expr_frac: np.ndarray,
+     ):
+         """Initialize the processing pipeline with shared pools and queues."""
+         logger.info("Initializing shared processing pools with direct queue connections...")
+
+         # Create shared queues to connect components, with configured sizes
+         reader_to_computer_queue = queue.Queue(maxsize=self.config.mkscore_compute_workers * self.config.compute_input_queue_size)
+         computer_to_writer_queue = queue.Queue(maxsize=self.config.writer_queue_size)
+
+         self.reader = ParallelRankReader(
+             rank_memmap,
+             num_workers=self.config.rank_read_workers,
+             output_queue=reader_to_computer_queue,  # Direct connection to the computer
+         )
+
+         # Determine the 3D strategy parameters
+         cross_slice_strategy = None
+         n_slices = 1
+         num_homogeneous_per_slice = self.config.homogeneous_neighbors
+
+         if (self.config.dataset_type == DatasetType.SPATIAL_3D and
+                 self.config.cross_slice_marker_score_strategy in [
+                     MarkerScoreCrossSliceStrategy.PER_SLICE_POOL,
+                     MarkerScoreCrossSliceStrategy.HIERARCHICAL_POOL,
+                 ]):
+             cross_slice_strategy = self.config.cross_slice_marker_score_strategy
+             n_slices = 1 + 2 * self.config.n_adjacent_slices
+             # For the pooling strategies, num_homogeneous is per slice
+             num_homogeneous_per_slice = self.config.homogeneous_neighbors
+
+         self.computer = ParallelMarkerScoreComputer(
+             global_log_gmean,
+             global_expr_frac,
+             self.config.homogeneous_neighbors,
+             num_workers=self.config.mkscore_compute_workers,
+             input_queue=reader_to_computer_queue,  # Input from the reader
+             output_queue=computer_to_writer_queue,  # Output to the writer
+             cross_slice_strategy=cross_slice_strategy,
+             n_slices=n_slices,
+             num_homogeneous_per_slice=num_homogeneous_per_slice,
+             no_expression_fraction=self.config.no_expression_fraction,
+         )
+
+         self.writer = ParallelMarkerScoreWriter(
+             output_memmap,
+             num_workers=self.config.mkscore_write_workers,
+             input_queue=computer_to_writer_queue,  # Input from the computer
+         )
+
+         logger.info(f"Processing pools initialized: {self.config.rank_read_workers} readers, "
+                     f"{self.config.mkscore_compute_workers} computers, "
+                     f"{self.config.mkscore_write_workers} writers")
+
+         self.marker_score_queue = MarkerScoreMessageQueue(
+             reader=self.reader,
+             computer=self.computer,
+             writer=self.writer,
+             batch_size=self.config.mkscore_batch_size,
+         )
+
+     def _find_homogeneous_spots(
+         self,
+         adata: ad.AnnData,
+         cell_type: str,
+         annotation_key: str,
+         coords: np.ndarray | None,
+         emb_niche: np.ndarray,
+         emb_indv: np.ndarray,
+         slice_ids: np.ndarray | None,
+         rank_shape: tuple[int, int],
+         high_quality_mask: np.ndarray,
+     ) -> tuple[np.ndarray, np.ndarray, np.ndarray | None, np.ndarray, int] | None:
+         """
+         Prepare batch data for a cell type.
+
+         Returns:
+             Tuple of (neighbor_indices, cell_sims, niche_sims, cell_indices_sorted, n_cells),
+             or None if the cell type should be skipped
+         """
+         # Get the cells of this type, excluding NaN values
+         cell_mask = (adata.obs[annotation_key] == cell_type) & (adata.obs[annotation_key].notna())
+         cell_indices = np.where(cell_mask)[0]
+         n_cells = len(cell_indices)
+
+         # Check the minimum cell count
+         min_cells = self.config.min_cells_per_type
+         if n_cells < min_cells:
+             logger.warning(f"Skipping {cell_type}: only {n_cells} cells (min: {min_cells})")
+             return None
+
+         logger.info(f"Processing {cell_type}: {n_cells} cells")
+
+         # Build the connectivity matrix
+         logger.info("Building connectivity matrix...")
+         neighbor_indices, cell_sims, niche_sims = self.connectivity_builder.build_connectivity_matrix(
+             coords=coords,
+             emb_niche=emb_niche,
+             emb_indv=emb_indv,
+             cell_mask=cell_mask,
+             high_quality_mask=high_quality_mask,
+             slice_ids=slice_ids,
+             k_central=self.config.spatial_neighbors,
+             k_adjacent=self.config.adjacent_slice_spatial_neighbors,
+             n_adjacent_slices=self.config.n_adjacent_slices,
+         )
+         gc.collect()
+
+         # Validate the neighbor indices
+         max_valid_idx = rank_shape[0] - 1
+         assert neighbor_indices.max() <= max_valid_idx, \
+             f"Neighbor indices exceed bounds (max: {neighbor_indices.max()}, limit: {max_valid_idx})"
+
+         # Optimize the row order using the JAX implementation
+         logger.info("Optimizing row order for cache efficiency...")
+         row_order = optimize_row_order_jax(
+             neighbor_indices=neighbor_indices[:, :self.config.homogeneous_neighbors],
+             cell_indices=cell_indices,
+             neighbor_weights=cell_sims[:, :self.config.homogeneous_neighbors],
+         )
+
+         neighbor_indices = neighbor_indices[row_order]
+         cell_sims = cell_sims[row_order]
+         if niche_sims is not None:
+             niche_sims = niche_sims[row_order]
+         cell_indices_sorted = cell_indices[row_order]
+
+         # Save the homogeneous neighbor data to adata
+         has_real_niche_embedding = self.config.latent_representation_niche is not None
+         self._save_homogeneous_data_to_adata(
+             adata=adata,
+             neighbor_indices=neighbor_indices,
+             cell_sims=cell_sims,
+             niche_sims=niche_sims,
+             cell_indices_sorted=cell_indices_sorted,
+             has_real_niche_embedding=has_real_niche_embedding,
+         )
+
+         # Warn about cells that could not find enough homogeneous neighbors
+         homo_neighbor_count = np.count_nonzero(cell_sims > 0, axis=1)
+         low_homo_neighbor_mask = (homo_neighbor_count <= 5)
+         if np.any(low_homo_neighbor_mask):
+             logger.warning(f"Cell type {cell_type}: {low_homo_neighbor_mask.sum()} cells have 5 or fewer homogeneous neighbors")
+
+         return neighbor_indices, cell_sims, niche_sims, cell_indices_sorted, n_cells
+
+     def _save_homogeneous_data_to_adata(
+         self,
+         adata: ad.AnnData,
+         neighbor_indices: np.ndarray,
+         cell_sims: np.ndarray,
+         niche_sims: np.ndarray | None,
+         cell_indices_sorted: np.ndarray,
+         has_real_niche_embedding: bool,
+     ):
+         """
+         Save the homogeneous neighbor data to adata obsm and obs.
+
+         Args:
+             adata: AnnData object to save data to
+             neighbor_indices: (n_cells, k) array of neighbor indices
+             cell_sims: (n_cells, k) array of cell similarity scores
+             niche_sims: (n_cells, k) array of niche similarity scores, or None
+             cell_indices_sorted: (n_cells,) array of cell indices in sorted order
+             has_real_niche_embedding: Whether a real niche embedding was provided
+         """
+         # Initialize the obsm matrices if they don't exist (or if k changed)
+         if 'gsMap_homo_indices' not in adata.obsm.keys() or (adata.obsm['gsMap_homo_indices'].shape[1] != neighbor_indices.shape[1]):
+             adata.obsm['gsMap_homo_indices'] = np.zeros((adata.n_obs, neighbor_indices.shape[1]), dtype=neighbor_indices.dtype)
+         if 'gsMap_homo_cell_sims' not in adata.obsm.keys() or (adata.obsm['gsMap_homo_cell_sims'].shape[1] != cell_sims.shape[1]):
+             adata.obsm['gsMap_homo_cell_sims'] = np.zeros((adata.n_obs, cell_sims.shape[1]), dtype=cell_sims.dtype)
+
+         # Store the neighbor indices and cell similarities for this cell type
+         adata.obsm['gsMap_homo_indices'][cell_indices_sorted] = neighbor_indices
+         adata.obsm['gsMap_homo_cell_sims'][cell_indices_sorted] = cell_sims
+
+         # Only store niche_sims if a niche embedding was provided (not the dummy)
+         if has_real_niche_embedding and niche_sims is not None:
+             if 'gsMap_homo_niche_sims' not in adata.obsm.keys() or (adata.obsm['gsMap_homo_niche_sims'].shape[1] != niche_sims.shape[1]):
+                 adata.obsm['gsMap_homo_niche_sims'] = np.zeros((adata.n_obs, niche_sims.shape[1]), dtype=niche_sims.dtype)
+             adata.obsm['gsMap_homo_niche_sims'][cell_indices_sorted] = niche_sims
+
+         # Initialize the obs columns if they don't exist
+         if 'gsMap_homo_neighbor_count' not in adata.obs.columns:
+             adata.obs['gsMap_homo_neighbor_count'] = 0
+         if 'gsMap_homo_cell_sims_median' not in adata.obs.columns:
+             adata.obs['gsMap_homo_cell_sims_median'] = np.nan
+         if has_real_niche_embedding and 'gsMap_homo_niche_sims_median' not in adata.obs.columns:
+             adata.obs['gsMap_homo_niche_sims_median'] = np.nan
+
+         # Calculate per-cell statistics (only over valid neighbors, where cell_sims > 0)
+         valid_mask = cell_sims > 0  # (n_cells, k) boolean mask
+
+         # Count of valid homogeneous neighbors
+         homo_neighbor_count = valid_mask.sum(axis=1)
+         adata.obs.loc[adata.obs.index[cell_indices_sorted], 'gsMap_homo_neighbor_count'] = homo_neighbor_count
+
+         # Median cell similarity (valid neighbors only)
+         cell_sims_masked = np.where(valid_mask, cell_sims, np.nan)
+         cell_sims_median = np.nanmedian(cell_sims_masked, axis=1)
+         adata.obs.loc[adata.obs.index[cell_indices_sorted], 'gsMap_homo_cell_sims_median'] = cell_sims_median
+
+         # Median niche similarity (valid neighbors only, and only if a niche embedding was provided)
+         if has_real_niche_embedding and niche_sims is not None:
+             niche_sims_masked = np.where(valid_mask, niche_sims, np.nan)
+             niche_sims_median = np.nanmedian(niche_sims_masked, axis=1)
+             adata.obs.loc[adata.obs.index[cell_indices_sorted], 'gsMap_homo_niche_sims_median'] = niche_sims_median
+
+     def _calculate_marker_scores_by_cell_type(
+         self,
+         adata: ad.AnnData,
+         cell_type: str,
+         coords: np.ndarray | None,
+         emb_niche: np.ndarray,
+         emb_indv: np.ndarray,
+         annotation_key: str,
+         slice_ids: np.ndarray | None = None,
+         high_quality_mask: np.ndarray | None = None,
+     ):
+         """Process a single cell type with the shared pools."""
+         # Find homogeneous spots; skip this cell type if it has too few cells
+         result = self._find_homogeneous_spots(
+             adata, cell_type, annotation_key, coords, emb_niche,
+             emb_indv, slice_ids, self.reader.shape, high_quality_mask
+         )
+         if result is None:
+             return
+         neighbor_indices, cell_sims, niche_sims, cell_indices_sorted, n_cells = result
+
+         # Reset the message queue for this cell type
+         self.marker_score_queue.reset_for_cell_type(cell_type, n_cells)
+
+         # Run the marker score queue to compute the marker scores
+         self.marker_score_queue.start(
+             neighbor_indices=neighbor_indices,
+             neighbor_weights=cell_sims,
+             cell_indices_sorted=cell_indices_sorted,
+             enable_profiling=self.config.enable_profiling,
+         )
+
+     def calculate_marker_scores(
+         self,
+         adata_path: str,
+         rank_memmap_path: str,
+         mean_frac_path: str,
+         output_path: str | Path | None = None,
+     ) -> str | Path:
+         """
+         Main execution function for marker score calculation.
+
+         Args:
+             adata_path: Path to the concatenated latent AnnData
+             rank_memmap_path: Path to the rank memory map
+             mean_frac_path: Path to the mean expression fraction parquet
+             output_path: Optional output path for marker scores
+
+         Returns:
+             Path to the output marker score memory map file
+         """
+         logger.info("Starting marker score calculation...")
+         self.console = Console()
+
+         self.console.print(Panel.fit(
+             "[bold cyan]Marker Score Calculation Stage[/bold cyan]",
+             subtitle="gsMap Stage 2",
+             border_style="cyan",
+         ))
+
+         # Use the config path if none is specified
+         if output_path is None:
+             output_path = Path(self.config.marker_scores_memmap_path)
+         else:
+             output_path = Path(output_path)
+
+         # Load all input data
+         adata, rank_memmap, global_log_gmean, global_expr_frac, n_cells, n_genes, high_quality_mask = self._load_input_data(
+             adata_path, rank_memmap_path, mean_frac_path
+         )
+
+         # Initialize the output memory map
+         output_memmap = MemMapDense(
+             output_path,
+             shape=(n_cells, n_genes),
+             dtype=np.float16,  # Use float16 to save memory
+             mode='w',
+             num_write_workers=self.config.mkscore_write_workers,
+             tmp_dir=self.config.memmap_tmp_dir,
+         )
+
+         # Get the cell types to process
+         cell_types = self._get_cell_types(adata)
+         annotation_key = self.config.annotation
+
+         # Display the summary
+         self._display_input_summary(adata, cell_types, n_cells, n_genes)
+
+         # Prepare the embeddings
+         coords, emb_niche, emb_indv, slice_ids = self._prepare_embeddings(adata)
+
+         # Initialize the processing pipeline
+         self._initialize_pipeline(rank_memmap, output_memmap, global_log_gmean, global_expr_frac)
+
+         # Process each cell type
+         for cell_type in cell_types:
+             self._calculate_marker_scores_by_cell_type(
+                 adata,
+                 cell_type,
+                 coords,
+                 emb_niche,
+                 emb_indv,
+                 annotation_key,
+                 slice_ids,
+                 high_quality_mask,
+             )
+
+         # Save the updated AnnData with the neighbor matrices
+         logger.info(f"Saving updated AnnData with homogeneous spot matrices to {adata_path}")
+         adata.write_h5ad(adata_path)
+
+         # Close all shared pools after all cell types are processed
+         logger.info("Closing shared processing pools...")
+         self.reader.close()
+         self.computer.close()
+         self.writer.close()
+
+         # Close the memory maps
+         rank_memmap.close()
+         output_memmap.close()
+
+         self.console.print(Panel.fit(
+             "[bold green]✓ Marker score calculation complete![/bold green]",
+             border_style="green"
+         ))
+
+         logger.info(f"Results saved to {output_path}")
+
+         return str(output_path)
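
For orientation, a minimal sketch of the core transform in compute_marker_scores_jax above, run on toy data. The shapes and values here are illustrative, not from the package:

    import jax.numpy as jnp

    # Toy data: 2 cells, 3 neighbors each, 4 genes (illustrative values only)
    log_ranks = jnp.log(jnp.arange(1.0, 25.0).reshape(6, 4))  # (B*N) x G log-rank matrix
    weights = jnp.array([[0.5, 0.3, 0.2],
                         [1.0, 1.0, 0.0]])                    # B x N neighbor weights
    global_log_gmean = log_ranks.mean(axis=0)                 # stands in for the parquet 'G_Mean' column

    log_ranks_3d = log_ranks.reshape(2, 3, 4)
    w = weights / weights.sum(axis=1, keepdims=True)               # normalize weights per cell
    weighted_log_mean = jnp.einsum('bn,bng->bg', w, log_ranks_3d)  # weighted geometric mean in log space
    score = jnp.exp(weighted_log_mean - global_log_gmean)          # ratio to the global geometric mean
    score = jnp.where(score < 1.0, 0.0, score)                     # zero out genes at or below background
    score = jnp.exp(score ** 1.5) - 1.0                            # same sharpening transform as the function
    print(score.shape)  # (2, 4)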
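
A shape-level sketch of the per-slice pooling used by compute_marker_scores_3d_hierarchical_pool_jax above; the dimensions and values are again illustrative only:

    import jax.numpy as jnp

    B, S, K, G = 2, 3, 4, 5                            # cells, slices, neighbors per slice, genes
    log_ranks = jnp.ones((B * S * K, G))               # flat (B*N) x G input, N = S*K
    weights = jnp.ones((B, S * K)).at[0, :K].set(0.0)  # slice 0 of cell 0 has no valid neighbors

    log_ranks_4d = log_ranks.reshape(B, S, K, G)
    weights_3d = weights.reshape(B, S, K)
    weights_sum = weights_3d.sum(axis=2, keepdims=True)
    w = weights_3d / jnp.where(weights_sum > 0, weights_sum, 1.0)   # per-slice normalization
    per_slice = jnp.einsum('bsk,bskg->bsg', w, log_ranks_4d)        # one log-space mean per slice
    invalid = (weights_sum == 0).squeeze(-1)                        # mark slices with no neighbors
    per_slice = jnp.where(invalid[:, :, None], jnp.nan, per_slice)  # so the NaN-aware mean skips them
    pooled = jnp.nanmean(per_slice, axis=1)                         # (B, G) average across valid slices
    print(pooled.shape)  # (2, 5)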
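
Finally, a hedged sketch of how this stage appears to be driven end to end. The config object and the file paths below are placeholders inferred from the code above, not a documented API:

    from gsMap.latent2gene.marker_scores import MarkerScoreCalculator

    # `config` is a LatentToGeneConfig (see gsMap/config/latent2gene_config.py);
    # the paths below are hypothetical.
    calculator = MarkerScoreCalculator(config)
    output = calculator.calculate_marker_scores(
        adata_path="workdir/latent/concatenated.h5ad",    # concatenated latent AnnData
        rank_memmap_path="workdir/rank/ranks.mmap",       # rank matrix with a .meta.json sidecar
        mean_frac_path="workdir/rank/mean_frac.parquet",  # columns 'G_Mean' and 'frac'
    )
    print(output)  # path to the (n_cells x n_genes) float16 marker-score memmap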