gsMap3D 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,382 @@
|
|
|
1
|
+
"""
|
|
2
|
+
JAX-optimized implementation of spatial LDSC.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import time
|
|
7
|
+
from functools import partial
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
import anndata as ad
|
|
11
|
+
import jax
|
|
12
|
+
import jax.numpy as jnp
|
|
13
|
+
from jax import jit, vmap
|
|
14
|
+
|
|
15
|
+
from gsMap.config import SpatialLDSCConfig
|
|
16
|
+
|
|
17
|
+
from .io import (
|
|
18
|
+
FeatherAnnData,
|
|
19
|
+
generate_expected_output_filename,
|
|
20
|
+
load_common_resources,
|
|
21
|
+
load_marker_scores_memmap_format,
|
|
22
|
+
log_existing_result_statistics,
|
|
23
|
+
)
|
|
24
|
+
from .ldscore_quick_mode import SpatialLDSCProcessor
|
|
25
|
+
|
|
26
|
+
# Module-level logger, registered under a fixed name in the "gsMap" hierarchy.
logger = logging.getLogger("gsMap.spatial_ldsc_jax")

# Run JAX in 32-bit mode: float32 is used for speed and a smaller memory
# footprint. NOTE: this is a process-wide setting applied at import time.
jax.config.update("jax_enable_x64", False)

# Optional platform pinning - uncomment exactly one line to override the
# backend JAX would pick by default.
# jax.config.update('jax_platform_name', 'cpu')  # Force CPU usage
# jax.config.update('jax_platform_name', 'gpu')  # Force GPU usage

# Optional memory tuning for resource-constrained environments.
# NOTE(review): these lines need `import os` before they can be enabled.
# os.environ.setdefault('XLA_PYTHON_CLIENT_PREALLOCATE', 'false')
# os.environ.setdefault('XLA_PYTHON_CLIENT_MEM_FRACTION', '0.5')

# ============================================================================
# Core computational functions
# ============================================================================
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@jax.profiler.annotate_function
@partial(jit, static_argnums=(0, 1))
def process_chunk_jit(n_blocks: int,
                      batch_size: int,
                      spatial_ld: jnp.ndarray,
                      baseline_ld_sum: jnp.ndarray,
                      chisq: jnp.ndarray,
                      N: jnp.ndarray,
                      baseline_ann: jnp.ndarray,
                      w_ld: jnp.ndarray,
                      Nbar: float) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Process an entire chunk of spots with JIT compilation and batch processing.

    For every spot (one column of ``spatial_ld``) a weighted least-squares
    regression of ``chisq`` on ``[spot_ld, baseline_ann]`` is solved, and a
    delete-one block jackknife over ``n_blocks`` contiguous SNP blocks gives
    the estimate and its standard error. Spots are vmapped in batches of
    ``batch_size`` to reduce memory usage.

    Args:
        n_blocks: Number of jackknife blocks (static under JIT).
        batch_size: Number of spots per vmap batch (static under JIT).
            ``0`` or any value ``>= n_spots`` processes all spots at once.
        spatial_ld: (n_snps, n_spots) array; spots are mapped over axis 1.
        baseline_ld_sum: Per-SNP values added to each spot column to form the
            total LD used for weighting; expected shape (n_snps,).
        chisq: Per-SNP chi-squared statistics; expected shape (n_snps,).
        N: Per-SNP sample sizes; expected shape (n_snps,).
        baseline_ann: (n_snps, n_baseline_features) baseline regressors.
        w_ld: Per-SNP regression-weight LD scores; expected shape (n_snps,).
        Nbar: Scalar the returned coefficient and SE are divided by.

    Returns:
        Tuple ``(betas, ses)``, each of shape (n_spots,): the jackknife
        estimate and standard error of the first (spatial) coefficient,
        both divided by ``Nbar``.
    """
    def process_single_spot(spot_ld):
        """Run the weighted block-jackknife regression for a single spot."""
        # Compute initial weights
        with jax.profiler.StepTraceAnnotation("weight_computation"):
            x_tot = spot_ld + baseline_ld_sum

            # Aggregate heritability-like scale used only for the weights
            hsq = 10000.0 * (jnp.mean(chisq) - 1.0) / jnp.mean(x_tot * N)
            hsq = jnp.clip(hsq, 0.0, 1.0)

            # Floor both LD terms at 1.0 before building the weights
            ld_clip = jnp.maximum(x_tot, 1.0)
            w_ld_clip = jnp.maximum(w_ld, 1.0)
            c = hsq * N / 10000.0
            weights = jnp.sqrt(1.0 / (2 * jnp.square(1.0 + c * ld_clip) * w_ld_clip))

            # Scale weights so they sum to one over all SNPs
            weights = weights.reshape(-1, 1)
            weights_scaled = weights / jnp.sum(weights)

        # Apply weights and combine features: column 0 is the spatial term
        with jax.profiler.StepTraceAnnotation("feature_preparation"):
            x_focal = jnp.concatenate([
                (spot_ld.reshape(-1, 1) * weights_scaled),
                (baseline_ann * weights_scaled)
            ], axis=1)
            y_weighted = chisq.reshape(-1, 1) * weights_scaled

            # Truncate to a block-aligned number of SNPs before reshaping, so
            # that n_snps need not be an exact multiple of n_blocks (this
            # mirrors process_chunk_batched_jit).
            n_snps, n_features = x_focal.shape
            block_size = n_snps // n_blocks
            n_snps_used = block_size * n_blocks

            x_blocks = x_focal[:n_snps_used].reshape(n_blocks, block_size, n_features)
            # Explicit 2-D reshape instead of a bare squeeze(), which would
            # also drop a size-1 block or block-size axis.
            y_blocks = y_weighted[:n_snps_used].reshape(n_blocks, block_size)

        # Compute per-block X'y and X'X
        with jax.profiler.StepTraceAnnotation("block_computation"):
            xty_blocks = jnp.einsum('nbp,nb->np', x_blocks, y_blocks)
            xtx_blocks = jnp.einsum('nbp,nbq->npq', x_blocks, x_blocks)

        # Jackknife regression
        with jax.profiler.StepTraceAnnotation("jackknife_regression"):
            xty_total = jnp.sum(xty_blocks, axis=0)
            xtx_total = jnp.sum(xtx_blocks, axis=0)
            est = jnp.linalg.solve(xtx_total, xty_total)

            # Delete-one estimates using one batched solve over blocks
            xty_del = xty_total - xty_blocks
            xtx_del = xtx_total - xtx_blocks
            delete_ests = jnp.linalg.solve(xtx_del, xty_del[..., None]).squeeze(-1)

            # Pseudovalues and standard error. Only the per-coefficient
            # variances are needed, so compute them directly rather than
            # building the full covariance matrix and reading its diagonal.
            pseudovalues = n_blocks * est - (n_blocks - 1) * delete_ests
            jknife_est = jnp.mean(pseudovalues, axis=0)
            jknife_var = jnp.var(pseudovalues, axis=0, ddof=1) / n_blocks
            jknife_se = jnp.sqrt(jknife_var)

        # Return spatial coefficient (first element)
        return jknife_est[0] / Nbar, jknife_se[0] / Nbar

    # Process in batches to reduce memory usage
    n_spots = spatial_ld.shape[1]

    if batch_size == 0 or batch_size >= n_spots:
        # Process all spots at once (batch_size=0 means no batching)
        with jax.profiler.StepTraceAnnotation("vmap_all_spots"):
            betas, ses = vmap(process_single_spot, in_axes=1, out_axes=0)(spatial_ld)
    else:
        # Process in smaller batches; the Python loop runs over static
        # shapes, so it is unrolled while the function is traced.
        betas_list = []
        ses_list = []

        with jax.profiler.StepTraceAnnotation("batch_processing"):
            for start_idx in range(0, n_spots, batch_size):
                end_idx = min(start_idx + batch_size, n_spots)
                batch_ld = spatial_ld[:, start_idx:end_idx]

                with jax.profiler.StepTraceAnnotation(f"vmap_batch_{start_idx}_{end_idx}"):
                    batch_betas, batch_ses = vmap(process_single_spot, in_axes=1, out_axes=0)(batch_ld)
                    betas_list.append(batch_betas)
                    ses_list.append(batch_ses)

        with jax.profiler.StepTraceAnnotation("concatenate_results"):
            betas = jnp.concatenate(betas_list)
            ses = jnp.concatenate(ses_list)

    return betas, ses
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
@partial(jit, static_argnums=(0,))
def process_chunk_batched_jit(n_blocks: int,
                              spatial_ld: jnp.ndarray,
                              baseline_ld_sum: jnp.ndarray,
                              chisq: jnp.ndarray,
                              N: jnp.ndarray,
                              baseline_ann: jnp.ndarray,
                              w_ld: jnp.ndarray,
                              Nbar: float) -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Process an entire chunk of spots with JIT compilation and BATCHED matrix operations.

    OPTIMIZATION: Uses batched matrix operations instead of vmap to improve GPU utilization.
    All spots are processed simultaneously using efficient matrix operations.

    Args:
        n_blocks: Number of jackknife blocks (static under JIT)
        spatial_ld: (n_snps, n_spots) array of spatial LD scores
        baseline_ld_sum: (n_snps,) baseline LD scores summed
        chisq: (n_snps,) chi-squared statistics
        N: (n_snps,) sample sizes
        baseline_ann: (n_snps, n_baseline_features) baseline annotations
        w_ld: (n_snps,) regression weights
        Nbar: Average sample size

    Returns:
        betas: (n_spots,) regression coefficients of the spatial term, divided by Nbar
        ses: (n_spots,) jackknife standard errors of that term, divided by Nbar
    """
    n_snps, n_spots = spatial_ld.shape

    # Total LD per SNP and spot: (n_snps, n_spots)
    x_tot = spatial_ld + baseline_ld_sum.reshape(-1, 1)

    # Per-spot scale used only for the weights: (n_spots,)
    # hsq = 10000 * (mean(chisq) - 1) / mean(x_tot * N)
    N_expanded = N.reshape(-1, 1)  # (n_snps, 1)
    mean_x_tot_N = jnp.mean(x_tot * N_expanded, axis=0)  # (n_spots,)
    hsq = 10000.0 * (jnp.mean(chisq) - 1.0) / mean_x_tot_N  # (n_spots,)
    hsq = jnp.clip(hsq, 0.0, 1.0)

    # Weights for all spots: (n_snps, n_spots); both LD terms floored at 1.0
    ld_clip = jnp.maximum(x_tot, 1.0)
    w_ld_clip = jnp.maximum(w_ld.reshape(-1, 1), 1.0)
    c = (hsq.reshape(1, -1) * N_expanded) / 10000.0  # (n_snps, n_spots)
    weights = jnp.sqrt(1.0 / (2 * jnp.square(1.0 + c * ld_clip) * w_ld_clip))

    # Normalize weights per spot so each column sums to one
    weights_sum = jnp.sum(weights, axis=0, keepdims=True)  # (1, n_spots)
    weights_scaled = weights / weights_sum  # (n_snps, n_spots)

    # Weighted design for all spots: (n_snps, n_spots, 1 + n_baseline_features)
    # Feature 0 is the spatial term, the remainder are the baseline columns.
    spatial_weighted = (spatial_ld * weights_scaled)[..., None]  # (n_snps, n_spots, 1)
    baseline_weighted = baseline_ann[:, None, :] * weights_scaled[..., None]  # (n_snps, n_spots, n_baseline)
    x_focal = jnp.concatenate([spatial_weighted, baseline_weighted], axis=2)

    # Weighted response: (n_snps, n_spots)
    y_weighted = chisq.reshape(-1, 1) * weights_scaled

    # Truncate to a block-aligned number of SNPs, then split into blocks
    block_size = n_snps // n_blocks
    n_snps_used = block_size * n_blocks

    # x_blocks: (n_blocks, block_size, n_spots, n_features)
    x_blocks = x_focal[:n_snps_used].reshape(n_blocks, block_size, n_spots, -1)
    # y_blocks: (n_blocks, block_size, n_spots)
    y_blocks = y_weighted[:n_snps_used].reshape(n_blocks, block_size, n_spots)

    # Block X'y and X'X for all spots simultaneously
    # xty_blocks: (n_blocks, n_spots, n_features)
    xty_blocks = jnp.einsum('nbsf,nbs->nsf', x_blocks, y_blocks)

    # xtx_blocks: (n_blocks, n_spots, n_features, n_features)
    xtx_blocks = jnp.einsum('nbsf,nbsg->nsfg', x_blocks, x_blocks)

    # Totals across blocks
    xty_total = jnp.sum(xty_blocks, axis=0)  # (n_spots, n_features)
    xtx_total = jnp.sum(xtx_blocks, axis=0)  # (n_spots, n_features, n_features)

    # Full-data estimate for all spots: (n_spots, n_features)
    est = jnp.linalg.solve(xtx_total, xty_total[..., None]).squeeze(-1)

    # Delete-one estimates: (n_blocks, n_spots, n_features)
    xty_del = xty_total - xty_blocks
    xtx_del = xtx_total - xtx_blocks
    delete_ests = jnp.linalg.solve(xtx_del, xty_del[..., None]).squeeze(-1)

    # Pseudovalues: (n_blocks, n_spots, n_features)
    pseudovalues = n_blocks * est - (n_blocks - 1) * delete_ests

    # Jackknife estimates per spot: (n_spots, n_features)
    jknife_est = jnp.mean(pseudovalues, axis=0)

    # Jackknife variances per spot and feature: (n_spots, n_features).
    # Only the diagonal of the covariance is ever used, so sum the squared
    # centered pseudovalues directly instead of forming the full
    # (n_spots, n_features, n_features) covariance.
    pseudo_centered = pseudovalues - jknife_est  # broadcast over blocks
    jknife_var = jnp.sum(jnp.square(pseudo_centered), axis=0) / (n_blocks * (n_blocks - 1))
    jknife_se = jnp.sqrt(jknife_var)

    # Return spatial coefficient (first feature) for all spots
    return jknife_est[:, 0] / Nbar, jknife_se[:, 0] / Nbar
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def wrapper_of_process_chunk_jit(*args, **kwargs):
    """
    Forward all arguments to the JIT-compiled chunk kernel.

    The kernel currently dispatched to is ``process_chunk_batched_jit``;
    positional and keyword arguments are passed through unchanged and its
    return value is returned as-is.
    """
    chunk_kernel = process_chunk_batched_jit
    return chunk_kernel(*args, **kwargs)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
# ============================================================================
|
|
271
|
+
# Main entry point
|
|
272
|
+
# ============================================================================
|
|
273
|
+
|
|
274
|
+
def run_spatial_ldsc_jax(config: SpatialLDSCConfig):
    """
    Run spatial LDSC for all traits in config.sumstats_config_dict.

    Marker scores and the common resources are loaded once and shared by
    every trait. A trait is skipped when its expected output file already
    exists in ``config.ldsc_save_dir``. Marker-score resources are closed on
    exit, including when loading or processor construction fails.

    Args:
        config: Run configuration. This function reads
            ``marker_score_format``, ``sumstats_config_dict``,
            ``ldsc_save_dir``, ``project_name``, ``sample_filter``,
            ``cell_indices_range`` and the format-specific marker score path.

    Raises:
        NotImplementedError: If ``config.marker_score_format`` is not one of
            'memmap', 'h5ad' or 'feather'.
        ValueError: If there are no traits to process, or the path required
            by the selected marker score format is not provided.
        FileNotFoundError: If the H5AD marker score file does not exist.
        Exception: Any other error is logged and re-raised unchanged.
    """
    if config.marker_score_format not in ["memmap", "h5ad", "feather"]:
        raise NotImplementedError(f"Marker score format '{config.marker_score_format}' is not supported. Only 'memmap', 'h5ad', and 'feather' are supported.")

    traits_to_process = list(config.sumstats_config_dict.items())
    if not traits_to_process:
        raise ValueError("No traits to process. config.sumstats_config_dict is empty.")

    # Create output directory
    output_dir = config.ldsc_save_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    # Determine number of loader threads based on platform
    n_loader_threads = 10 if jax.default_backend() == 'gpu' else 2

    # Load marker scores once (format-agnostic)
    logger.info(f"Loading marker scores (format: {config.marker_score_format})...")
    marker_score_adata = None

    try:
        # The inner try/finally spans loading AND processing, so an opened
        # marker-score handle is released even when a later setup step
        # (common resources, processor construction) raises.
        try:
            if config.marker_score_format == "memmap":
                marker_score_adata = load_marker_scores_memmap_format(config)

            elif config.marker_score_format == "feather":
                # Same explicit validation as the h5ad branch, instead of an
                # opaque failure when the path is missing.
                if not config.marker_score_feather_path:
                    raise ValueError("marker_score_feather_path must be provided when marker_score_format is 'feather'")

                feather_path = Path(config.marker_score_feather_path)
                logger.info(f"Loading marker scores from Feather: {feather_path}")
                # Use the specialized FeatherAnnData wrapper
                marker_score_adata = FeatherAnnData(feather_path, index_col='HUMAN_GENE_SYM', transpose=True)

            elif config.marker_score_format == "h5ad":
                if not config.marker_score_h5ad_path:
                    raise ValueError("marker_score_h5ad_path must be provided when marker_score_format is 'h5ad'")

                h5ad_path = Path(config.marker_score_h5ad_path)
                if not h5ad_path.exists():
                    raise FileNotFoundError(f"Marker score H5AD file not found: {h5ad_path}")

                logger.info(f"Loading marker scores from H5AD: {h5ad_path}")
                marker_score_adata = ad.read_h5ad(h5ad_path, backed='r')

            # Load common resources once (baseline, weights, snp_gene_weights)
            baseline_ld, w_ld, snp_gene_weight_adata = load_common_resources(config)

            # Initialize processor with common resources
            logger.debug("Initializing processor...")
            processor = SpatialLDSCProcessor(
                config=config,
                output_dir=output_dir,
                marker_score_adata=marker_score_adata,
                snp_gene_weight_adata=snp_gene_weight_adata,
                baseline_ld=baseline_ld,
                w_ld=w_ld,
                n_loader_threads=n_loader_threads
            )

            for idx, (trait_name, sumstats_file) in enumerate(traits_to_process):
                logger.info("=" * 70)
                logger.info("Running Spatial LDSC (JAX Implementation)")
                logger.info(f"Project: {config.project_name}, Trait: {trait_name} ({idx+1}/{len(traits_to_process)})")
                if config.sample_filter:
                    logger.info(f"Sample filter: {config.sample_filter}")
                if config.cell_indices_range:
                    logger.info(f"Cell indices range: {config.cell_indices_range}")
                logger.info("=" * 70)

                # Skip traits whose output already exists
                expected_filename = generate_expected_output_filename(config, trait_name)
                if expected_filename is not None:
                    expected_output_path = output_dir / expected_filename
                    if expected_output_path.exists():
                        logger.info(f"Output file already exists: {expected_output_path}")
                        logger.info(f"Skipping trait {trait_name} ({idx+1}/{len(traits_to_process)})")

                        # Log statistics from existing result
                        log_existing_result_statistics(expected_output_path, trait_name)
                        continue

                # Setup processor for current trait
                processor.setup_trait(trait_name, sumstats_file)

                # Process all chunks for current trait; perf_counter is
                # monotonic, so the duration is unaffected by clock changes.
                start_time = time.perf_counter()
                processor.process_all_chunks(wrapper_of_process_chunk_jit)

                elapsed_time = time.perf_counter() - start_time
                h, rem = divmod(elapsed_time, 3600)
                m, s = divmod(rem, 60)
                logger.info(f"Trait {trait_name} completed in {int(h)}h {int(m)}m {s:.2f}s")

        finally:
            # Cleanup once: close memmap/adata if needed
            if marker_score_adata is not None:
                logger.info("Closing marker score resources...")
                # If it's our MemMap wrapper, close it explicitly
                if config.marker_score_format == "memmap" and 'memmap_manager' in marker_score_adata.uns:
                    marker_score_adata.uns['memmap_manager'].close()
                # If it's backed AnnData, close the file
                if config.marker_score_format == "h5ad" and marker_score_adata.isbacked:
                    marker_score_adata.file.close()

    except Exception as e:
        logger.error(f"An error occurred during execution: {e}")
        raise
|
|
382
|
+
|