gsMap3D 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/latent2gene/memmap_io.py
@@ -0,0 +1,766 @@
+ """
+ Memory-mapped I/O utilities for efficient large-scale data handling.
+ Replaces Zarr-backed storage with NumPy memory maps for better performance.
+ """
+
+ import json
+ import logging
+ import queue
+ import shutil
+ import threading
+ import time
+ import traceback
+ import uuid
+ from dataclasses import dataclass
+ from pathlib import Path
+
+ import numpy as np
+
+ logger = logging.getLogger(__name__)
+
+
+ class MemMapDense:
+     """Dense matrix storage using NumPy memory maps with async multi-threaded writing"""
+
+     def __init__(
+         self,
+         path: str | Path,
+         shape: tuple[int, int],
+         dtype=np.float16,
+         mode: str = 'w',
+         num_write_workers: int = 4,
+         flush_interval: float = 30,
+         tmp_dir: str | Path | None = None,
+     ):
+         """
+         Initialize a memory-mapped dense matrix.
+
+         Args:
+             path: Path to the memory-mapped file (without extension)
+             shape: Shape of the matrix (n_rows, n_cols)
+             dtype: Data type of the matrix
+             mode: 'w' for write (create/overwrite), 'r' for read, 'r+' for read/write
+             num_write_workers: Number of worker threads for async writing
+             flush_interval: Minimum seconds between periodic flushes of the memmap
+             tmp_dir: Optional temporary directory for faster I/O on slow filesystems.
+                 If provided, files will be created/copied in tmp_dir for operations
+                 and synced back to the original path on close.
+         """
+         self.original_path = Path(path)
+         self.shape = shape
+         self.dtype = dtype
+         self.mode = mode
+         self.num_write_workers = num_write_workers
+         self.flush_interval = flush_interval
+         self.tmp_dir = Path(tmp_dir) if tmp_dir else None
+         self.using_tmp = False
+         self.tmp_path = None
+
+         # Set up paths based on whether tmp_dir is provided
+         if self.tmp_dir:
+             self._setup_tmp_paths()
+         else:
+             self.path = self.original_path
+
+         # File paths
+         self.data_path = self.path.with_suffix('.dat')
+         self.meta_path = self.path.with_suffix('.meta.json')
+
+         # Initialize memory map
+         if mode == 'w':
+             self._create_memmap()
+         elif mode == 'r':
+             self._open_memmap_readonly()
+         elif mode == 'r+':
+             self._open_memmap_readwrite()
+         else:
+             raise ValueError(f"Invalid mode: {mode}. Must be 'w', 'r', or 'r+'")
+
+         # Async writing setup (only for write modes)
+         self.write_queue = queue.Queue(maxsize=100)
+         self.writer_threads = []
+         self.stop_writer = threading.Event()
+
+         if mode in ('w', 'r+'):
+             self._start_writer_threads()
+
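For orientation, a minimal sketch of the write path; the path, shape, and staging directory below are illustrative, not taken from the package:

    import numpy as np
    from gsMap.latent2gene.memmap_io import MemMapDense

    # hypothetical output: 10,000 cells x 2,000 genes, staged on a fast tmpfs
    scores = MemMapDense("marker_scores", shape=(10_000, 2_000),
                         dtype=np.float16, mode='w', tmp_dir="/dev/shm")
    scores.write_batch(np.ones((500, 2_000), dtype=np.float16), row_indices=0)
    scores.close()  # drains the queue, marks the store complete, syncs tmp files back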
+     def _setup_tmp_paths(self):
+         """Set up temporary paths for memory-mapped files"""
+         # Create a unique subdirectory in tmp_dir to avoid conflicts
+         unique_id = str(uuid.uuid4())[:8]
+         self.tmp_subdir = self.tmp_dir / f"memmap_{unique_id}"
+         self.tmp_subdir.mkdir(parents=True, exist_ok=True)
+
+         # Create tmp path with the same name as the original
+         self.tmp_path = self.tmp_subdir / self.original_path.name
+         self.path = self.tmp_path
+         self.using_tmp = True
+
+         logger.info(f"Using temporary directory for memmap: {self.tmp_subdir}")
+
+         # If reading, copy existing files to the tmp directory
+         if self.mode in ('r', 'r+'):
+             original_data_path = self.original_path.with_suffix('.dat')
+             original_meta_path = self.original_path.with_suffix('.meta.json')
+
+             if original_data_path.exists() and original_meta_path.exists():
+                 tmp_data_path = self.tmp_path.with_suffix('.dat')
+                 tmp_meta_path = self.tmp_path.with_suffix('.meta.json')
+
+                 logger.info("Copying memmap files to temporary directory for faster access...")
+                 shutil.copy2(original_data_path, tmp_data_path)
+                 shutil.copy2(original_meta_path, tmp_meta_path)
+                 logger.info(f"Memmap files copied to {self.tmp_subdir}")
+
+     def _sync_tmp_to_original(self):
+         """Sync temporary files back to the original location"""
+         if not self.using_tmp:
+             return
+
+         tmp_data_path = self.tmp_path.with_suffix('.dat')
+         tmp_meta_path = self.tmp_path.with_suffix('.meta.json')
+         original_data_path = self.original_path.with_suffix('.dat')
+         original_meta_path = self.original_path.with_suffix('.meta.json')
+
+         if tmp_data_path.exists():
+             logger.info("Syncing memmap data from tmp to original location...")
+             shutil.move(str(tmp_data_path), str(original_data_path))
+
+         if tmp_meta_path.exists():
+             shutil.move(str(tmp_meta_path), str(original_meta_path))
+
+         logger.info(f"Memmap files synced to {self.original_path}")
+
+     def _cleanup_tmp(self):
+         """Clean up the temporary directory"""
+         if self.using_tmp and self.tmp_subdir and self.tmp_subdir.exists():
+             try:
+                 shutil.rmtree(self.tmp_subdir)
+                 logger.debug(f"Cleaned up temporary directory: {self.tmp_subdir}")
+             except Exception as e:
+                 logger.warning(f"Could not clean up temporary directory {self.tmp_subdir}: {e}")
+
+     def _create_memmap(self):
+         """Create a new memory-mapped file"""
+         # Check whether one already exists and is complete
+         if self.meta_path.exists():
+             try:
+                 with open(self.meta_path) as f:
+                     meta = json.load(f)
+                 if meta.get('complete', False):
+                     raise ValueError(
+                         f"MemMapDense at {self.path} already exists and is marked as complete. "
+                         f"Please delete it manually if you want to overwrite: rm {self.data_path} {self.meta_path}"
+                     )
+                 else:
+                     logger.warning(f"MemMapDense at {self.path} exists but is incomplete. Recreating.")
+             except (json.JSONDecodeError, KeyError):
+                 logger.warning(f"Invalid metadata at {self.meta_path}. Recreating.")
+
+         # Create new memory map
+         self.memmap = np.memmap(
+             self.data_path,
+             dtype=self.dtype,
+             mode='w+',
+             shape=self.shape
+         )
+
+         # # Initialize to zeros
+         # self.memmap[:] = 0
+         # self.memmap.flush()
+
+         # Write metadata
+         meta = {
+             'shape': self.shape,
+             'dtype': np.dtype(self.dtype).name,  # Use dtype.name for proper serialization
+             'complete': False,
+             'created_at': time.time()
+         }
+         with open(self.meta_path, 'w') as f:
+             json.dump(meta, f, indent=2)
+
+         logger.info(f"Created MemMapDense at {self.data_path} with shape {self.shape}")
+
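The .meta.json sidecar makes a store self-describing, so its integrity can be checked without opening the matrix. A minimal sketch (paths hypothetical):

    import json
    from pathlib import Path
    import numpy as np

    base = Path("marker_scores")  # hypothetical base path, no extension
    meta = json.loads(base.with_suffix(".meta.json").read_text())
    expected = int(np.prod(meta["shape"])) * np.dtype(meta["dtype"]).itemsize
    assert base.with_suffix(".dat").stat().st_size == expected, "truncated memmap"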
+     def _open_memmap_readonly(self):
+         """Open an existing memory-mapped file in read-only mode"""
+         if not self.meta_path.exists():
+             raise FileNotFoundError(f"Metadata file not found: {self.meta_path}")
+
+         # Read metadata
+         with open(self.meta_path) as f:
+             meta = json.load(f)
+
+         if not meta.get('complete', False):
+             raise ValueError(f"MemMapDense at {self.path} is incomplete")
+
+         # Validate shape against the metadata
+         if tuple(meta['shape']) != self.shape:
+             raise ValueError(
+                 f"Shape mismatch: expected {self.shape}, got {tuple(meta['shape'])}"
+             )
+
+         # Open memory map
+         self.memmap = np.memmap(
+             self.data_path,
+             dtype=self.dtype,
+             mode='r',
+             shape=self.shape
+         )
+
+         logger.info(f"Opened MemMapDense at {self.data_path} in read-only mode")
+
+     def _open_memmap_readwrite(self):
+         """Open an existing memory-mapped file in read-write mode"""
+         if not self.meta_path.exists():
+             raise FileNotFoundError(f"Metadata file not found: {self.meta_path}")
+
+         # Read metadata; the shape recorded on disk takes precedence here
+         with open(self.meta_path) as f:
+             meta = json.load(f)
+
+         # Open memory map
+         self.memmap = np.memmap(
+             self.data_path,
+             dtype=self.dtype,
+             mode='r+',
+             shape=tuple(meta['shape'])
+         )
+
+         logger.info(f"Opened MemMapDense at {self.data_path} in read-write mode")
+
+     def _start_writer_threads(self):
+         """Start multiple background writer threads sharing the same memmap object"""
+         def writer_worker(worker_id):
+             last_flush_time = time.time()  # Track last flush time (only worker 0 flushes)
+
+             while not self.stop_writer.is_set():
+                 try:
+                     item = self.write_queue.get(timeout=1)
+                     if item is None:
+                         break
+                     data, row_indices, col_slice = item
+
+                     # Write data through the shared memmap
+                     if isinstance(row_indices, slice):
+                         self.memmap[row_indices, col_slice] = data
+                     elif isinstance(row_indices, int | np.integer):
+                         start_row = row_indices
+                         end_row = start_row + data.shape[0]
+                         self.memmap[start_row:end_row, col_slice] = data
+                     else:
+                         # Handle an array of indices
+                         self.memmap[row_indices, col_slice] = data
+
+                     # Periodic flush every flush_interval seconds, performed by worker 0 only
+                     if worker_id == 0:
+                         current_time = time.time()
+                         if current_time - last_flush_time >= self.flush_interval:
+                             self.memmap.flush()
+                             last_flush_time = time.time()
+                             logger.debug(f"Worker 0 flushed memmap at {last_flush_time:.2f}")
+
+                     self.write_queue.task_done()
+                 except queue.Empty:
+                     continue
+                 except Exception as e:
+                     logger.error(f"Writer thread {worker_id} error: {e}")
+                     raise
+
+         # Start multiple writer threads
+         for i in range(self.num_write_workers):
+             thread = threading.Thread(target=writer_worker, args=(i,), daemon=True)
+             thread.start()
+             self.writer_threads.append(thread)
+         logger.info(f"Started {self.num_write_workers} writer threads for MemMapDense")
+
+     def write_batch(self, data: np.ndarray, row_indices: int | slice | np.ndarray, col_slice=slice(None)):
+         """Queue a batch for async writing
+
+         Args:
+             data: Data to write
+             row_indices: A starting row index (rows row_indices:row_indices + len(data)
+                 are written), a slice, or an array of row indices
+             col_slice: Column slice (default: all columns)
+         """
+         if self.mode not in ('w', 'r+'):
+             logger.warning("Cannot write to read-only MemMapDense")
+             return
+
+         self.write_queue.put((data, row_indices, col_slice))
+
+     def read_batch(self, row_indices: int | slice | np.ndarray, col_slice=slice(None)) -> np.ndarray:
+         """Read a batch of data
+
+         Args:
+             row_indices: Row indices to read
+             col_slice: Column slice (default: all columns)
+
+         Returns:
+             NumPy array with a copy of the requested data
+         """
+         if isinstance(row_indices, int | np.integer):
+             return self.memmap[row_indices:row_indices + 1, col_slice].copy()
+         else:
+             return self.memmap[row_indices, col_slice].copy()
+
+     def __getitem__(self, key):
+         """Direct array access for compatibility"""
+         return self.memmap[key]
+
+     def __setitem__(self, key, value):
+         """Direct array access for compatibility"""
+         if self.mode not in ('w', 'r+'):
+             raise ValueError("Cannot write to read-only MemMapDense")
+         self.memmap[key] = value
+
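Reading back follows the same pattern; a sketch under the same hypothetical paths:

    scores = MemMapDense("marker_scores", shape=(10_000, 2_000),
                         dtype=np.float16, mode='r')
    first_cell = scores.read_batch(0)         # copy of row 0, shape (1, 2000)
    block = scores.read_batch(slice(0, 256))  # copy of rows 0..255
    window = scores[100:200, :50]             # direct memmap view, no copy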
+     def mark_complete(self):
+         """Mark the memory map as complete"""
+         if self.mode in ('w', 'r+'):
+             logger.info("Marking memmap as complete")
+             # Ensure all queued writes have been applied
+             if self.writer_threads and not self.write_queue.empty():
+                 logger.info("Waiting for remaining writes before marking complete...")
+                 self.write_queue.join()
+
+             # Flush memory map to disk
+             logger.info("Flushing memmap to disk...")
+             self.memmap.flush()
+             logger.info("Memmap flush complete")
+
+             # Update metadata
+             with open(self.meta_path) as f:
+                 meta = json.load(f)
+             meta['complete'] = True
+             meta['completed_at'] = time.time()
+             # Ensure dtype is properly serialized
+             if 'dtype' in meta and not isinstance(meta['dtype'], str):
+                 meta['dtype'] = np.dtype(self.dtype).name
+             with open(self.meta_path, 'w') as f:
+                 json.dump(meta, f, indent=2)
+
+             logger.info(f"Marked MemMapDense at {self.path} as complete")
+
+     @classmethod
+     def check_complete(cls, memmap_path: str | Path, meta_path: str | Path | None = None) -> tuple[bool, dict | None]:
+         """
+         Check whether a memory map file is complete without opening it.
+
+         Args:
+             memmap_path: Path to the memory-mapped file (without extension)
+             meta_path: Optional path to the metadata file. If not provided, it is derived from memmap_path
+
+         Returns:
+             Tuple of (is_complete, metadata_dict). metadata_dict is None if the file doesn't exist or can't be read.
+         """
+         memmap_path = Path(memmap_path)
+
+         if meta_path is None:
+             # Derive the metadata path from the memmap path
+             if memmap_path.suffix == '.dat':
+                 meta_path = memmap_path.with_suffix('.meta.json')
+             elif memmap_path.name.endswith('.meta.json'):
+                 # Path.suffix only yields the last component ('.json'),
+                 # so match the full '.meta.json' ending instead
+                 meta_path = memmap_path
+             else:
+                 # Assume no extension; add .meta.json
+                 meta_path = memmap_path.with_suffix('.meta.json')
+         else:
+             meta_path = Path(meta_path)
+
+         if not meta_path.exists():
+             return False, None
+
+         try:
+             with open(meta_path) as f:
+                 meta = json.load(f)
+             return meta.get('complete', False), meta
+         except (OSError, json.JSONDecodeError) as e:
+             logger.warning(f"Could not read metadata from {meta_path}: {e}")
+             return False, None
+
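A typical resume pattern built on this, again with hypothetical paths:

    is_complete, meta = MemMapDense.check_complete("marker_scores")
    if is_complete:
        scores = MemMapDense("marker_scores", shape=tuple(meta["shape"]),
                             dtype=np.dtype(meta["dtype"]), mode='r')
    else:
        pass  # recompute and rewrite the store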
+     def close(self):
+         """Clean up resources"""
+         logger.info("MemMapDense.close() called")
+         if self.writer_threads:
+             logger.info("Closing MemMapDense: waiting for queued writes...")
+             self.write_queue.join()
+             logger.info("All queued writes have been processed")
+             self.stop_writer.set()
+             logger.info("Stop signal set for writer threads")
+
+             # Send a stop sentinel to every thread
+             for _ in self.writer_threads:
+                 self.write_queue.put(None)
+             logger.info("Stop sentinels queued for writer threads")
+
+             # Wait for all threads to finish
+             for thread in self.writer_threads:
+                 thread.join(timeout=5.0)
+
+         # Final flush
+         if self.mode in ('w', 'r+'):
+             self.mark_complete()
+
+         # Sync tmp files back to the original location if using tmp
+         if self.using_tmp:
+             self._sync_tmp_to_original()
+
+         self._cleanup_tmp()
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_value, tb):
+         # Complete the context-manager protocol begun by __enter__
+         self.close()
+         return False
+
+     @property
+     def attrs(self):
+         """Compatibility property for accessing metadata"""
+         if hasattr(self, '_attrs'):
+             return self._attrs
+
+         if self.meta_path.exists():
+             with open(self.meta_path) as f:
+                 self._attrs = json.load(f)
+         else:
+             self._attrs = {}
+         return self._attrs
+
+
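With the context-manager protocol in place, the close/flush sequence can be scoped with `with`; an illustrative sketch:

    import numpy as np
    with MemMapDense("latent_scores", shape=(10_000, 2_000),
                     dtype=np.float16, mode='w') as scores:
        for start in range(0, 10_000, 500):
            scores.write_batch(np.ones((500, 2_000), dtype=np.float16), start)
    # on exit: queue drained, memmap flushed, metadata marked complete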
+ @dataclass
+ class ComponentThroughput:
+     """Track throughput for individual pipeline components"""
+     total_batches: int = 0
+     total_time: float = 0.0
+     last_batch_time: float = 0.0
+
+     def record_batch(self, elapsed_time: float):
+         """Record a batch completion"""
+         self.total_batches += 1
+         self.total_time += elapsed_time
+         self.last_batch_time = elapsed_time
+
+     @property
+     def average_time(self) -> float:
+         """Average time per batch in seconds"""
+         if self.total_batches > 0:
+             return self.total_time / self.total_batches
+         return 0.0
+
+     @property
+     def throughput(self) -> float:
+         """Batches per second, derived from the running average"""
+         if self.average_time > 0:
+             return 1.0 / self.average_time
+         return 0.0
+
+
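ComponentThroughput is a plain accumulator; for example:

    t = ComponentThroughput()
    t.record_batch(0.25)
    t.record_batch(0.35)
    assert abs(t.average_time - 0.3) < 1e-9    # seconds per batch
    assert abs(t.throughput - 1 / 0.3) < 1e-9  # batches per second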
+ class ParallelRankReader:
+     """Multi-threaded reader for log-rank data from memory-mapped storage"""
+
+     def __init__(
+         self,
+         rank_memmap: MemMapDense | str,
+         num_workers: int = 4,
+         output_queue: queue.Queue | None = None,
+         cache_size_mb: int = 1000  # reserved; not used by the current implementation
+     ):
+         # Use the shared memmap object if one is provided
+         if isinstance(rank_memmap, MemMapDense):
+             self.shared_memmap = rank_memmap.memmap
+             self.memmap_path = rank_memmap.path
+             self.shape = rank_memmap.shape
+             self.dtype = rank_memmap.dtype
+         else:
+             # Fallback for a string path: open a shared read-only memmap here
+             self.memmap_path = Path(rank_memmap)
+             meta_path = self.memmap_path.with_suffix('.meta.json')
+             data_path = self.memmap_path.with_suffix('.dat')
+             with open(meta_path) as f:
+                 meta = json.load(f)
+             self.shape = tuple(meta['shape'])
+             self.dtype = np.dtype(meta['dtype'])
+
+             # Open the single shared memmap
+             self.shared_memmap = np.memmap(
+                 data_path,
+                 dtype=self.dtype,
+                 mode='r',
+                 shape=self.shape
+             )
+             logger.info(f"Opened shared memmap for reading at {data_path}")
+
+         self.num_workers = num_workers
+
+         # Queues for communication
+         self.read_queue = queue.Queue()
+         # Use the provided output queue or create our own
+         self.result_queue = output_queue if output_queue else queue.Queue(maxsize=self.num_workers * 4)
+
+         # Throughput tracking
+         self.throughput = ComponentThroughput()
+         self.throughput_lock = threading.Lock()
+
+         # Exception handling
+         self.exception_queue = queue.Queue()
+         self.has_error = threading.Event()
+
+         # Start worker threads
+         self.workers = []
+         self.stop_workers = threading.Event()
+         self._start_workers()
+
+     def _start_workers(self):
+         """Start worker threads"""
+         for i in range(self.num_workers):
+             worker = threading.Thread(
+                 target=self._worker,
+                 args=(i,),
+                 daemon=True
+             )
+             worker.start()
+             self.workers.append(worker)
+
+     def _worker(self, worker_id: int):
+         """Worker thread for reading batches from the shared memory map"""
+         logger.debug(f"Reader worker {worker_id} started using shared memmap")
+
+         # No need to open a new memmap; use self.shared_memmap directly.
+         # NumPy releases the GIL during memmap array access, enabling parallelism.
+
+         while not self.stop_workers.is_set():
+             try:
+                 # Get a batch request
+                 item = self.read_queue.get()
+                 if item is None:
+                     break
+
+                 batch_id, neighbor_indices, batch_metadata = item
+
+                 # Track timing
+                 start_time = time.time()
+
+                 # Flatten and deduplicate indices for efficient reading
+                 flat_indices = np.unique(neighbor_indices.flatten())
+
+                 # Validate that indices are within bounds
+                 max_idx = self.shape[0] - 1
+                 assert flat_indices.max() <= max_idx, \
+                     f"Worker {worker_id}: Indices exceed bounds (max: {flat_indices.max()}, limit: {max_idx})"
+
+                 # Read from the shared memory map (thread-safe, GIL released)
+                 rank_data = self.shared_memmap[flat_indices]
+
+                 # Ensure we have a NumPy array
+                 if not isinstance(rank_data, np.ndarray):
+                     rank_data = np.array(rank_data)
+
+                 # Create a mapping for reconstruction
+                 idx_map = {idx: i for i, idx in enumerate(flat_indices)}
+
+                 # Map neighbor indices to rank_data row positions
+                 flat_neighbors = neighbor_indices.flatten()
+                 rank_indices = np.array([idx_map[neighbor_idx] for neighbor_idx in flat_neighbors])
+
+                 # Track throughput
+                 elapsed = time.time() - start_time
+                 with self.throughput_lock:
+                     self.throughput.record_batch(elapsed)
+
+                 # Put the result, with metadata, on the queue for the compute stage
+                 self.result_queue.put((batch_id, rank_data, rank_indices, neighbor_indices.shape, batch_metadata))
+                 self.read_queue.task_done()
+
+             except queue.Empty:
+                 continue
+             except Exception as e:
+                 error_trace = traceback.format_exc()
+                 logger.error(f"Reader worker {worker_id} error: {e}\nTraceback:\n{error_trace}")
+                 self.exception_queue.put((worker_id, e, error_trace))
+                 self.has_error.set()
+                 self.stop_workers.set()  # Signal all workers to stop
+                 break
+
+     def submit_batch(self, batch_id: int, neighbor_indices: np.ndarray, batch_metadata: dict | None = None):
+         """Submit a batch for reading, with optional metadata"""
+         self.read_queue.put((batch_id, neighbor_indices, batch_metadata or {}))
+
+     def get_result(self):
+         """Get the next completed batch"""
+         return self.result_queue.get()
+
+     def get_queue_sizes(self):
+         """Get current queue sizes for monitoring"""
+         return self.read_queue.qsize(), self.result_queue.qsize()
+
+     def check_errors(self):
+         """Raise if any worker encountered an error"""
+         if self.has_error.is_set():
+             try:
+                 worker_id, exception, error_trace = self.exception_queue.get_nowait()
+                 raise RuntimeError(f"Reader worker {worker_id} failed: {exception}\nOriginal traceback:\n{error_trace}") from exception
+             except queue.Empty:
+                 raise RuntimeError("Reader worker failed with unknown error")
+
+     def reset_for_cell_type(self, cell_type: str):
+         """Reset throughput tracking for a new cell type"""
+         with self.throughput_lock:
+             self.throughput = ComponentThroughput()
+         logger.debug(f"Reset reader throughput for {cell_type}")
+
+     def close(self):
+         """Clean up resources"""
+         self.stop_workers.set()
+         for _ in range(self.num_workers):
+             self.read_queue.put(None)
+         for worker in self.workers:
+             worker.join(timeout=5)
+
+         # Do not close shared_memmap here: if it was passed in, it is owned by the caller;
+         # if we opened it ourselves from a string path, np.memmap needs no explicit close()
+
+
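Driving the reader, as a sketch (the store path and index array are hypothetical):

    import numpy as np
    rank_store = MemMapDense("log_ranks", shape=(50_000, 2_000),
                             dtype=np.float16, mode='r')
    reader = ParallelRankReader(rank_store, num_workers=4)

    neighbor_indices = np.random.randint(0, 50_000, size=(256, 30))  # 256 cells, 30 neighbors each
    reader.submit_batch(0, neighbor_indices, {"cell_type": "neuron"})

    batch_id, rank_data, rank_indices, nbr_shape, meta = reader.get_result()
    per_neighbor = rank_data[rank_indices].reshape(*nbr_shape, -1)  # (256, 30, 2000)
    reader.check_errors()
    reader.close()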
+ class ParallelMarkerScoreWriter:
+     """Multi-threaded writer pool for marker scores using a shared memmap"""
+
+     def __init__(
+         self,
+         output_memmap: MemMapDense,
+         num_workers: int = 4,
+         input_queue: queue.Queue | None = None
+     ):
+         """
+         Initialize the writer pool
+
+         Args:
+             output_memmap: Output memory map wrapper object
+             num_workers: Number of writer threads
+             input_queue: Optional input queue (from the compute stage)
+         """
+         # Store the shared memmap object
+         self.shared_memmap = output_memmap.memmap
+         self.memmap_path = output_memmap.path
+         self.shape = output_memmap.shape
+         self.dtype = output_memmap.dtype
+         self.num_workers = num_workers
+
+         # Queue for write requests
+         self.write_queue = input_queue if input_queue else queue.Queue(maxsize=100)
+         self.completed_count = 0
+         self.completed_lock = threading.Lock()
+         self.active_cell_type = None  # Track the cell type currently being processed
+
+         # Throughput tracking
+         self.throughput = ComponentThroughput()
+         self.throughput_lock = threading.Lock()
+
+         # Exception handling
+         self.exception_queue = queue.Queue()
+         self.has_error = threading.Event()
+
+         # Start worker threads
+         self.workers = []
+         self.stop_workers = threading.Event()
+         self._start_workers()
+
+     def _start_workers(self):
+         """Start writer worker threads"""
+         for i in range(self.num_workers):
+             worker = threading.Thread(
+                 target=self._writer_worker,
+                 args=(i,),
+                 daemon=True
+             )
+             worker.start()
+             self.workers.append(worker)
+         logger.debug(f"Started {self.num_workers} writer threads with shared memmap")
+
+     def _writer_worker(self, worker_id: int):
+         """Writer worker thread using the shared memmap"""
+         logger.debug(f"Writer worker {worker_id} started")
+
+         while not self.stop_workers.is_set():
+             try:
+                 # Get a write request
+                 item = self.write_queue.get(timeout=1)
+                 if item is None:
+                     break
+
+                 batch_idx, marker_scores, cell_indices = item
+
+                 # Track timing
+                 start_time = time.time()
+
+                 # Write directly to the shared memory map.
+                 # Thread-safe because workers process disjoint batches (indices);
+                 # cell_indices are absolute row indices in the full matrix
+                 self.shared_memmap[cell_indices] = marker_scores
+
+                 # Track throughput
+                 elapsed = time.time() - start_time
+                 with self.throughput_lock:
+                     self.throughput.record_batch(elapsed)
+
+                 # Update completed count
+                 with self.completed_lock:
+                     self.completed_count += 1
+
+                 self.write_queue.task_done()
+
+             except queue.Empty:
+                 continue
+             except Exception as e:
+                 error_trace = traceback.format_exc()
+                 logger.error(f"Writer worker {worker_id} error: {e}\nTraceback:\n{error_trace}")
+                 self.exception_queue.put((worker_id, e, error_trace))
+                 self.has_error.set()
+                 self.stop_workers.set()  # Signal all workers to stop
+                 break
+
+         # We do not close or delete the shared memmap here;
+         # flushing is handled by the main thread or the MemMapDense wrapper.
+         logger.debug(f"Writer worker {worker_id} stopping")
+
+     def reset_for_cell_type(self, cell_type: str):
+         """Reset state for processing a new cell type"""
+         self.active_cell_type = cell_type
+         with self.completed_lock:
+             self.completed_count = 0
+         with self.throughput_lock:
+             self.throughput = ComponentThroughput()
+         logger.debug(f"Reset writer throughput for {cell_type}")
+
+     def get_completed_count(self):
+         """Get the number of completed writes"""
+         with self.completed_lock:
+             return self.completed_count
+
+     def get_queue_size(self):
+         """Get the write queue size"""
+         return self.write_queue.qsize()
+
+     def check_errors(self):
+         """Raise if any worker encountered an error"""
+         if self.has_error.is_set():
+             try:
+                 worker_id, exception, error_trace = self.exception_queue.get_nowait()
+                 raise RuntimeError(f"Writer worker {worker_id} failed: {exception}\nOriginal traceback:\n{error_trace}") from exception
+             except queue.Empty:
+                 raise RuntimeError("Writer worker failed with unknown error")
+
+     def close(self):
+         """Close the writer pool"""
+         logger.info("Closing writer pool...")
+
+         # Wait for the queue to drain
+         if not self.write_queue.empty():
+             logger.info("Waiting for remaining writes...")
+             self.write_queue.join()
+
+         # Stop workers
+         self.stop_workers.set()
+         for _ in range(self.num_workers):
+             self.write_queue.put(None)
+
+         # Wait for workers to finish
+         for worker in self.workers:
+             worker.join(timeout=5)
+
+         logger.info("Writer pool closed")