gsMap3D 0.1.0a1 (gsmap3d-0.1.0a1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/latent2gene/memmap_io.py
@@ -0,0 +1,766 @@
"""
Memory-mapped I/O utilities for efficient large-scale data handling
Replaces Zarr-backed storage with NumPy memory maps for better performance
"""

import json
import logging
import queue
import shutil
import threading
import time
import traceback
import uuid
from dataclasses import dataclass
from pathlib import Path

import numpy as np

logger = logging.getLogger(__name__)


class MemMapDense:
    """Dense matrix storage using NumPy memory maps with async multi-threaded writing"""

    def __init__(
        self,
        path: str | Path,
        shape: tuple[int, int],
        dtype=np.float16,
        mode: str = 'w',
        num_write_workers: int = 4,
        flush_interval: float = 30,
        tmp_dir: str | Path | None = None,
    ):
        """
        Initialize a memory-mapped dense matrix.

        Args:
            path: Path to the memory-mapped file (without extension)
            shape: Shape of the matrix (n_rows, n_cols)
            dtype: Data type of the matrix
            mode: 'w' for write (create/overwrite), 'r' for read, 'r+' for read/write
            num_write_workers: Number of worker threads for async writing
            tmp_dir: Optional temporary directory for faster I/O on slow filesystems.
                If provided, files will be created/copied to tmp_dir for operations
                and synced back to the original path on close.
        """
        self.original_path = Path(path)
        self.shape = shape
        self.dtype = dtype
        self.mode = mode
        self.num_write_workers = num_write_workers
        self.flush_interval = flush_interval
        self.tmp_dir = Path(tmp_dir) if tmp_dir else None
        self.using_tmp = False
        self.tmp_path = None

        # Set up paths based on whether tmp_dir is provided
        if self.tmp_dir:
            self._setup_tmp_paths()
        else:
            self.path = self.original_path

        # File paths
        self.data_path = self.path.with_suffix('.dat')
        self.meta_path = self.path.with_suffix('.meta.json')

        # Initialize memory map
        if mode == 'w':
            self._create_memmap()
        elif mode == 'r':
            self._open_memmap_readonly()
        elif mode == 'r+':
            self._open_memmap_readwrite()
        else:
            raise ValueError(f"Invalid mode: {mode}. Must be 'w', 'r', or 'r+'")

        # Async writing setup (only for write modes)
        self.write_queue = queue.Queue(maxsize=100)
        self.writer_threads = []
        self.stop_writer = threading.Event()

        if mode in ('w', 'r+'):
            self._start_writer_threads()

    def _setup_tmp_paths(self):
        """Set up temporary paths for memory-mapped files"""
        # Create a unique subdirectory in tmp_dir to avoid conflicts
        unique_id = str(uuid.uuid4())[:8]
        self.tmp_subdir = self.tmp_dir / f"memmap_{unique_id}"
        self.tmp_subdir.mkdir(parents=True, exist_ok=True)

        # Create tmp path with same structure as original
        self.tmp_path = self.tmp_subdir / self.original_path.name
        self.path = self.tmp_path
        self.using_tmp = True

        logger.info(f"Using temporary directory for memmap: {self.tmp_subdir}")

        # If reading, copy existing files to tmp directory
        if self.mode in ('r', 'r+'):
            original_data_path = self.original_path.with_suffix('.dat')
            original_meta_path = self.original_path.with_suffix('.meta.json')

            if original_data_path.exists() and original_meta_path.exists():
                tmp_data_path = self.tmp_path.with_suffix('.dat')
                tmp_meta_path = self.tmp_path.with_suffix('.meta.json')

                logger.info("Copying memmap files to temporary directory for faster access...")
                shutil.copy2(original_data_path, tmp_data_path)
                shutil.copy2(original_meta_path, tmp_meta_path)
                logger.info(f"Memmap files copied to {self.tmp_subdir}")

    def _sync_tmp_to_original(self):
        """Sync temporary files back to original location"""
        if not self.using_tmp:
            return

        tmp_data_path = self.tmp_path.with_suffix('.dat')
        tmp_meta_path = self.tmp_path.with_suffix('.meta.json')
        original_data_path = self.original_path.with_suffix('.dat')
        original_meta_path = self.original_path.with_suffix('.meta.json')

        if tmp_data_path.exists():
            logger.info("Syncing memmap data from tmp to original location...")
            shutil.move(str(tmp_data_path), str(original_data_path))

        if tmp_meta_path.exists():
            shutil.move(str(tmp_meta_path), str(original_meta_path))

        logger.info(f"Memmap files synced to {self.original_path}")

    def _cleanup_tmp(self):
        """Clean up temporary directory"""
        if self.using_tmp and self.tmp_subdir and self.tmp_subdir.exists():
            try:
                shutil.rmtree(self.tmp_subdir)
                logger.debug(f"Cleaned up temporary directory: {self.tmp_subdir}")
            except Exception as e:
                logger.warning(f"Could not clean up temporary directory {self.tmp_subdir}: {e}")

    def _create_memmap(self):
        """Create a new memory-mapped file"""
        # Check if already exists and is complete
        if self.meta_path.exists():
            try:
                with open(self.meta_path) as f:
                    meta = json.load(f)
                if meta.get('complete', False):
                    raise ValueError(
                        f"MemMapDense at {self.path} already exists and is marked as complete. "
                        f"Please delete it manually if you want to overwrite: rm {self.data_path} {self.meta_path}"
                    )
                else:
                    logger.warning(f"MemMapDense at {self.path} exists but is incomplete. Recreating.")
            except (json.JSONDecodeError, KeyError):
                logger.warning(f"Invalid metadata at {self.meta_path}. Recreating.")

        # Create new memory map
        self.memmap = np.memmap(
            self.data_path,
            dtype=self.dtype,
            mode='w+',
            shape=self.shape
        )

        # # Initialize to zeros
        # self.memmap[:] = 0
        # self.memmap.flush()

        # Write metadata
        meta = {
            'shape': self.shape,
            'dtype': np.dtype(self.dtype).name,  # Use dtype.name for proper serialization
            'complete': False,
            'created_at': time.time()
        }
        with open(self.meta_path, 'w') as f:
            json.dump(meta, f, indent=2)

        logger.info(f"Created MemMapDense at {self.data_path} with shape {self.shape}")

    def _open_memmap_readonly(self):
        """Open an existing memory-mapped file in read-only mode"""
        if not self.meta_path.exists():
            raise FileNotFoundError(f"Metadata file not found: {self.meta_path}")

        # Read metadata
        with open(self.meta_path) as f:
            meta = json.load(f)

        if not meta.get('complete', False):
            raise ValueError(f"MemMapDense at {self.path} is incomplete")

        # Validate shape and dtype
        if tuple(meta['shape']) != self.shape:
            raise ValueError(
                f"Shape mismatch: expected {self.shape}, got {tuple(meta['shape'])}"
            )

        # Open memory map
        self.memmap = np.memmap(
            self.data_path,
            dtype=self.dtype,
            mode='r',
            shape=self.shape
        )

        logger.info(f"Opened MemMapDense at {self.data_path} in read-only mode")

    def _open_memmap_readwrite(self):
        """Open an existing memory-mapped file in read-write mode"""
        if not self.meta_path.exists():
            raise FileNotFoundError(f"Metadata file not found: {self.meta_path}")

        # Read metadata
        with open(self.meta_path) as f:
            meta = json.load(f)

        # Open memory map
        self.memmap = np.memmap(
            self.data_path,
            dtype=self.dtype,
            mode='r+',
            shape=tuple(meta['shape'])
        )

        logger.info(f"Opened MemMapDense at {self.data_path} in read-write mode")

    def _start_writer_threads(self):
        """Start multiple background writer threads sharing the same memmap object"""
        def writer_worker(worker_id):
            last_flush_time = time.time()  # Track last flush time for worker 0

            while not self.stop_writer.is_set():
                try:
                    item = self.write_queue.get(timeout=1)
                    if item is None:
                        break
                    data, row_indices, col_slice = item

                    # Write data with thread safety using shared memmap
                    if isinstance(row_indices, slice):
                        self.memmap[row_indices, col_slice] = data
                    elif isinstance(row_indices, int | np.integer):
                        start_row = row_indices
                        end_row = start_row + data.shape[0]
                        self.memmap[start_row:end_row, col_slice] = data
                    else:
                        # Handle array of indices
                        self.memmap[row_indices, col_slice] = data

                    # Periodic flush by worker 0, every flush_interval seconds
                    if worker_id == 0:
                        current_time = time.time()
                        if current_time - last_flush_time >= self.flush_interval:
                            self.memmap.flush()
                            last_flush_time = time.time()
                            logger.debug(f"Worker 0 flushed memmap at {last_flush_time:.2f}")

                    self.write_queue.task_done()
                except queue.Empty:
                    continue
                except Exception as e:
                    logger.error(f"Writer thread {worker_id} error: {e}")
                    raise

        # Start multiple writer threads
        for i in range(self.num_write_workers):
            thread = threading.Thread(target=writer_worker, args=(i,), daemon=True)
            thread.start()
            self.writer_threads.append(thread)
        logger.info(f"Started {self.num_write_workers} writer threads for MemMapDense")

    def write_batch(self, data: np.ndarray, row_indices: int | slice | np.ndarray, col_slice=slice(None)):
        """Queue batch for async writing

        Args:
            data: Data to write
            row_indices: Either a single row index, slice, or array of row indices
            col_slice: Column slice (default: all columns)
        """
        if self.mode not in ('w', 'r+'):
            logger.warning("Cannot write to read-only MemMapDense")
            return

        self.write_queue.put((data, row_indices, col_slice))

    def read_batch(self, row_indices: int | slice | np.ndarray, col_slice=slice(None)) -> np.ndarray:
        """Read batch of data

        Args:
            row_indices: Row indices to read
            col_slice: Column slice (default: all columns)

        Returns:
            NumPy array with the requested data
        """
        if isinstance(row_indices, int | np.integer):
            return self.memmap[row_indices:row_indices+1, col_slice].copy()
        else:
            return self.memmap[row_indices, col_slice].copy()

    def __getitem__(self, key):
        """Direct array access for compatibility"""
        return self.memmap[key]

    def __setitem__(self, key, value):
        """Direct array access for compatibility"""
        if self.mode not in ('w', 'r+'):
            raise ValueError("Cannot write to read-only MemMapDense")
        self.memmap[key] = value

    def mark_complete(self):
        """Mark the memory map as complete"""
        if self.mode in ('w', 'r+'):
            logger.info("Marking memmap as complete")
            # Ensure all writes are flushed
            if self.writer_threads and not self.write_queue.empty():
                logger.info("Waiting for remaining writes before marking complete...")
                self.write_queue.join()

            # Flush memory map to disk
            logger.info("Flushing memmap to disk...")
            self.memmap.flush()
            logger.info("Memmap flush complete")

            # Update metadata
            with open(self.meta_path) as f:
                meta = json.load(f)
            meta['complete'] = True
            meta['completed_at'] = time.time()
            # Ensure dtype is properly serialized
            if 'dtype' in meta and not isinstance(meta['dtype'], str):
                meta['dtype'] = np.dtype(self.dtype).name
            with open(self.meta_path, 'w') as f:
                json.dump(meta, f, indent=2)

            logger.info(f"Marked MemMapDense at {self.path} as complete")

    @classmethod
    def check_complete(cls, memmap_path: str | Path, meta_path: str | Path | None = None) -> tuple[bool, dict | None]:
        """
        Check if a memory map file is complete without opening it.

        Args:
            memmap_path: Path to the memory-mapped file (without extension)
            meta_path: Optional path to metadata file. If not provided, will be derived from memmap_path

        Returns:
            Tuple of (is_complete, metadata_dict). metadata_dict is None if file doesn't exist or can't be read.
        """
        memmap_path = Path(memmap_path)

        if meta_path is None:
            # Derive metadata path from memmap path
            if memmap_path.suffix == '.dat':
                meta_path = memmap_path.with_suffix('.meta.json')
            elif memmap_path.suffix == '.meta.json':
                meta_path = memmap_path
            else:
                # Assume no extension, add .meta.json
                meta_path = memmap_path.with_suffix('.meta.json')
        else:
            meta_path = Path(meta_path)

        if not meta_path.exists():
            return False, None

        try:
            with open(meta_path) as f:
                meta = json.load(f)
            return meta.get('complete', False), meta
        except (OSError, json.JSONDecodeError) as e:
            logger.warning(f"Could not read metadata from {meta_path}: {e}")
            return False, None

    def close(self):
        """Clean up resources"""
        logger.info("MemMapDense.close() called")
        if self.writer_threads:
            logger.info("Closing MemMapDense: waiting for queued writes...")
            self.write_queue.join()
            logger.info("All queued writes have been processed")
            self.stop_writer.set()
            logger.info("Stop signal set for writer threads")

            # Send stop signal to all threads
            for _ in self.writer_threads:
                self.write_queue.put(None)
            logger.info("Stop sentinels queued for writer threads")

            # Wait for all threads to finish
            for thread in self.writer_threads:
                thread.join(timeout=5.0)

        # Final flush
        if self.mode in ('w', 'r+'):
            self.mark_complete()

        # Sync tmp files back to original location if using tmp
        if self.using_tmp:
            self._sync_tmp_to_original()

            self._cleanup_tmp()

    def __enter__(self):
        return self

    @property
    def attrs(self):
        """Compatibility property for accessing metadata"""
        if hasattr(self, '_attrs'):
            return self._attrs

        if self.meta_path.exists():
            with open(self.meta_path) as f:
                self._attrs = json.load(f)
        else:
            self._attrs = {}
        return self._attrs
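For orientation, a minimal usage sketch of the MemMapDense API defined above (illustrative only, not part of the packaged file; the "scores" path, shape, and dtype are assumptions made for the example):

# --- illustrative sketch, not part of memmap_io.py ---
# Assumes the definitions above (MemMapDense, numpy imported as np).
n_cells, n_genes = 1000, 200

# Write mode: batches are queued and written by background threads.
store = MemMapDense("scores", shape=(n_cells, n_genes), dtype=np.float16, mode='w')
store.write_batch(np.random.rand(100, n_genes).astype(np.float16), row_indices=0)
store.close()  # drains the queue, flushes, and marks the store complete

# Read mode: only stores whose metadata is marked complete can be opened.
is_complete, meta = MemMapDense.check_complete("scores")
if is_complete:
    scores = MemMapDense("scores", shape=(n_cells, n_genes), dtype=np.float16, mode='r')
    first_rows = scores.read_batch(slice(0, 10))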
@dataclass
class ComponentThroughput:
    """Track throughput for individual pipeline components"""
    total_batches: int = 0
    total_time: float = 0.0
    last_batch_time: float = 0.0

    def record_batch(self, elapsed_time: float):
        """Record a batch completion"""
        self.total_batches += 1
        self.total_time += elapsed_time
        self.last_batch_time = elapsed_time

    @property
    def average_time(self) -> float:
        """Average time per batch"""
        if self.total_batches > 0:
            return self.total_time / self.total_batches
        return 0.0

    @property
    def throughput(self) -> float:
        if self.average_time > 0:
            return 1.0 / self.average_time
        return 0.0
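A small worked example of the accounting ComponentThroughput does (again illustrative, not part of the file): throughput is simply the reciprocal of the mean recorded batch time.

# --- illustrative sketch, not part of memmap_io.py ---
ct = ComponentThroughput()
ct.record_batch(0.25)
ct.record_batch(0.75)
assert ct.average_time == 0.5   # (0.25 + 0.75) / 2 seconds per batch
assert ct.throughput == 2.0     # 1 / 0.5 = 2 batches per second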
class ParallelRankReader:
    """Multi-threaded reader for log-rank data from memory-mapped storage"""

    def __init__(
        self,
        rank_memmap: MemMapDense | str,
        num_workers: int = 4,
        output_queue: queue.Queue = None,
        cache_size_mb: int = 1000
    ):
        # Store shared memmap object if provided
        if isinstance(rank_memmap, MemMapDense):
            self.shared_memmap = rank_memmap.memmap
            self.memmap_path = rank_memmap.path
            self.shape = rank_memmap.shape
            self.dtype = rank_memmap.dtype
        else:
            # Fallback for string path: open a shared read-only memmap here
            self.memmap_path = Path(rank_memmap)
            meta_path = self.memmap_path.with_suffix('.meta.json')
            data_path = self.memmap_path.with_suffix('.dat')
            with open(meta_path) as f:
                meta = json.load(f)
            self.shape = tuple(meta['shape'])
            self.dtype = np.dtype(meta['dtype'])

            # Open the single shared memmap
            self.shared_memmap = np.memmap(
                data_path,
                dtype=self.dtype,
                mode='r',
                shape=self.shape
            )
            logger.info(f"Opened shared memmap for reading at {data_path}")

        self.num_workers = num_workers

        # Queues for communication
        self.read_queue = queue.Queue()
        # Use provided output queue or create own
        self.result_queue = output_queue if output_queue else queue.Queue(maxsize=self.num_workers * 4)

        # Throughput tracking
        self.throughput = ComponentThroughput()
        self.throughput_lock = threading.Lock()

        # Exception handling
        self.exception_queue = queue.Queue()
        self.has_error = threading.Event()

        # Start worker threads
        self.workers = []
        self.stop_workers = threading.Event()
        self._start_workers()

    def _start_workers(self):
        """Start worker threads"""
        for i in range(self.num_workers):
            worker = threading.Thread(
                target=self._worker,
                args=(i,),
                daemon=True
            )
            worker.start()
            self.workers.append(worker)

    def _worker(self, worker_id: int):
        """Worker thread for reading batches from shared memory map"""
        logger.debug(f"Reader worker {worker_id} started using shared memmap")

        # No need to open a new memmap; use self.shared_memmap directly.
        # Numpy releases the GIL during memmap array access, enabling parallelism.

        while not self.stop_workers.is_set():
            try:
                # Get batch request
                item = self.read_queue.get()
                if item is None:
                    break

                batch_id, neighbor_indices, batch_metadata = item

                # Track timing
                start_time = time.time()

                # Flatten and deduplicate indices for efficient reading
                flat_indices = np.unique(neighbor_indices.flatten())

                # Validate indices are within bounds
                max_idx = self.shape[0] - 1
                assert flat_indices.max() <= max_idx, \
                    f"Worker {worker_id}: Indices exceed bounds (max: {flat_indices.max()}, limit: {max_idx})"

                # Read from shared memory map (thread-safe, GIL released)
                rank_data = self.shared_memmap[flat_indices]

                # Ensure we have a numpy array
                if not isinstance(rank_data, np.ndarray):
                    rank_data = np.array(rank_data)

                # Create mapping for reconstruction
                idx_map = {idx: i for i, idx in enumerate(flat_indices)}

                # Map neighbor indices to rank_data indices
                flat_neighbors = neighbor_indices.flatten()
                rank_indices = np.array([idx_map[neighbor_idx] for neighbor_idx in flat_neighbors])

                # Track throughput
                elapsed = time.time() - start_time
                with self.throughput_lock:
                    self.throughput.record_batch(elapsed)

                # Put result with metadata for computer
                self.result_queue.put((batch_id, rank_data, rank_indices, neighbor_indices.shape, batch_metadata))
                self.read_queue.task_done()

            except queue.Empty:
                continue
            except Exception as e:
                error_trace = traceback.format_exc()
                logger.error(f"Reader worker {worker_id} error: {e}\nTraceback:\n{error_trace}")
                self.exception_queue.put((worker_id, e, error_trace))
                self.has_error.set()
                self.stop_workers.set()  # Signal all workers to stop
                break

    def submit_batch(self, batch_id: int, neighbor_indices: np.ndarray, batch_metadata: dict = None):
        """Submit batch for reading with metadata"""
        self.read_queue.put((batch_id, neighbor_indices, batch_metadata or {}))

    def get_result(self):
        """Get next completed batch"""
        return self.result_queue.get()

    def get_queue_sizes(self):
        """Get current queue sizes for monitoring"""
        return self.read_queue.qsize(), self.result_queue.qsize()

    def check_errors(self):
        """Check if any worker encountered an error"""
        if self.has_error.is_set():
            try:
                worker_id, exception, error_trace = self.exception_queue.get_nowait()
                raise RuntimeError(f"Reader worker {worker_id} failed: {exception}\nOriginal traceback:\n{error_trace}") from exception
            except queue.Empty:
                raise RuntimeError("Reader worker failed with unknown error")

    def reset_for_cell_type(self, cell_type: str):
        """Reset throughput tracking for new cell type"""
        with self.throughput_lock:
            self.throughput = ComponentThroughput()
        logger.debug(f"Reset reader throughput for {cell_type}")

    def close(self):
        """Clean up resources"""
        self.stop_workers.set()
        for _ in range(self.num_workers):
            self.read_queue.put(None)
        for worker in self.workers:
            worker.join(timeout=5)

        # Do not close shared_memmap here if it was passed in (owned by caller)
        # If we opened it ourselves (str path), we could close it, but memmap doesn't strictly require close()
        pass
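A hedged sketch of how ParallelRankReader is driven, using only the methods defined above; the rank-store path, neighbor layout, and metadata keys are assumptions made for the example.

# --- illustrative sketch, not part of memmap_io.py ---
rank_store = MemMapDense("ranks", shape=(1000, 200), dtype=np.float16, mode='r')
reader = ParallelRankReader(rank_store, num_workers=4)

# Each request is (batch_id, neighbor index matrix, metadata dict).
neighbor_indices = np.array([[0, 1, 2], [3, 4, 5]])
reader.submit_batch(0, neighbor_indices, {"cell_type": "example"})

batch_id, rank_data, rank_indices, nbr_shape, meta = reader.get_result()
reader.check_errors()  # re-raises any worker exception in the caller

# rank_data holds the deduplicated rows; rank_indices maps each flattened
# neighbor back to its row, so this rebuilds a (2, 3, n_genes) block.
dense = rank_data[rank_indices].reshape(*nbr_shape, -1)
reader.close()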
class ParallelMarkerScoreWriter:
    """Multi-threaded writer pool for marker scores using shared memmap"""

    def __init__(
        self,
        output_memmap: MemMapDense,
        num_workers: int = 4,
        input_queue: queue.Queue = None
    ):
        """
        Initialize writer pool

        Args:
            output_memmap: Output memory map wrapper object
            num_workers: Number of writer threads
            input_queue: Optional input queue (from computer)
        """
        # Store shared memmap object
        self.shared_memmap = output_memmap.memmap
        self.memmap_path = output_memmap.path
        self.shape = output_memmap.shape
        self.dtype = output_memmap.dtype
        self.num_workers = num_workers

        # Queue for write requests
        self.write_queue = input_queue if input_queue else queue.Queue(maxsize=100)
        self.completed_count = 0
        self.completed_lock = threading.Lock()
        self.active_cell_type = None  # Track current cell type being processed

        # Throughput tracking
        self.throughput = ComponentThroughput()
        self.throughput_lock = threading.Lock()

        # Exception handling
        self.exception_queue = queue.Queue()
        self.has_error = threading.Event()

        # Start worker threads
        self.workers = []
        self.stop_workers = threading.Event()
        self._start_workers()

    def _start_workers(self):
        """Start writer worker threads"""
        for i in range(self.num_workers):
            worker = threading.Thread(
                target=self._writer_worker,
                args=(i,),
                daemon=True
            )
            worker.start()
            self.workers.append(worker)
        logger.debug(f"Started {self.num_workers} writer threads with shared memmap")

    def _writer_worker(self, worker_id: int):
        """Writer worker thread using shared memmap"""
        logger.debug(f"Writer worker {worker_id} started")

        while not self.stop_workers.is_set():
            try:
                # Get write request
                item = self.write_queue.get(timeout=1)
                if item is None:
                    break

                batch_idx, marker_scores, cell_indices = item

                # Track timing
                start_time = time.time()

                # Write directly to shared memory map
                # Thread-safe because workers process disjoint batches (indices)
                # cell_indices should be the absolute indices in the full matrix
                self.shared_memmap[cell_indices] = marker_scores

                # Track throughput
                elapsed = time.time() - start_time
                with self.throughput_lock:
                    self.throughput.record_batch(elapsed)

                # Update completed count
                with self.completed_lock:
                    self.completed_count += 1

                self.write_queue.task_done()

            except queue.Empty:
                continue
            except Exception as e:
                error_trace = traceback.format_exc()
                logger.error(f"Writer worker {worker_id} error: {e}\nTraceback:\n{error_trace}")
                self.exception_queue.put((worker_id, e, error_trace))
                self.has_error.set()
                self.stop_workers.set()  # Signal all workers to stop
                break

        # We do not close or delete the shared memmap here.
        # Flushing is handled by the main thread or MemMapDense wrapper.
        logger.debug(f"Writer worker {worker_id} stopping")

    def reset_for_cell_type(self, cell_type: str):
        """Reset for processing a new cell type"""
        self.active_cell_type = cell_type
        with self.completed_lock:
            self.completed_count = 0
        with self.throughput_lock:
            self.throughput = ComponentThroughput()
        logger.debug(f"Reset writer throughput for {cell_type}")

    def get_completed_count(self):
        """Get number of completed writes"""
        with self.completed_lock:
            return self.completed_count

    def get_queue_size(self):
        """Get write queue size"""
        return self.write_queue.qsize()

    def check_errors(self):
        """Check if any worker encountered an error"""
        if self.has_error.is_set():
            try:
                worker_id, exception, error_trace = self.exception_queue.get_nowait()
                raise RuntimeError(f"Writer worker {worker_id} failed: {exception}\nOriginal traceback:\n{error_trace}") from exception
            except queue.Empty:
                raise RuntimeError("Writer worker failed with unknown error")

    def close(self):
        """Close writer pool"""
        logger.info("Closing writer pool...")

        # Wait for queue to empty
        if not self.write_queue.empty():
            logger.info("Waiting for remaining writes...")
            self.write_queue.join()

        # Stop workers
        self.stop_workers.set()
        for _ in range(self.num_workers):
            self.write_queue.put(None)

        # Wait for workers to finish
        for worker in self.workers:
            worker.join(timeout=5)

        logger.info("Writer pool closed")
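Finally, a hedged sketch of wiring an output store to ParallelMarkerScoreWriter; the output path, batch layout, and the direct write_queue.put call are assumptions made for the example (the docstring above suggests the queue is normally fed by an upstream compute stage).

# --- illustrative sketch, not part of memmap_io.py ---
out = MemMapDense("marker_scores", shape=(1000, 200), dtype=np.float16, mode='w')
writer = ParallelMarkerScoreWriter(out, num_workers=2)

# Items are (batch_idx, marker_scores, absolute cell/row indices).
scores = np.random.rand(50, 200).astype(np.float16)
writer.write_queue.put((0, scores, np.arange(50)))

writer.check_errors()
writer.close()  # waits for queued writes, then stops the worker threads
out.close()     # final flush; marks the store complete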