gsMap3D 0.1.0a1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/spatial_ldsc/spatial_ldsc_jax.py
@@ -0,0 +1,382 @@
+ """
+ JAX-optimized implementation of spatial LDSC.
+ """
+
+ import logging
+ import time
+ from functools import partial
+ from pathlib import Path
+
+ import anndata as ad
+ import jax
+ import jax.numpy as jnp
+ from jax import jit, vmap
+
+ from gsMap.config import SpatialLDSCConfig
+
+ from .io import (
+     FeatherAnnData,
+     generate_expected_output_filename,
+     load_common_resources,
+     load_marker_scores_memmap_format,
+     log_existing_result_statistics,
+ )
+ from .ldscore_quick_mode import SpatialLDSCProcessor
+
+ logger = logging.getLogger("gsMap.spatial_ldsc_jax")
+
+ # Configure JAX for performance and memory efficiency
+ jax.config.update('jax_enable_x64', False)  # use float32 for speed and lower memory
+
+ # Platform selection - comment/uncomment as needed
+ # jax.config.update('jax_platform_name', 'cpu')  # Force CPU usage
+ # jax.config.update('jax_platform_name', 'gpu')  # Force GPU usage
+
+ # Memory configuration for environments with limited resources
+ # os.environ.setdefault('XLA_PYTHON_CLIENT_PREALLOCATE', 'false')
+ # os.environ.setdefault('XLA_PYTHON_CLIENT_MEM_FRACTION', '0.5')
+
+ # ============================================================================
+ # Core computational functions
+ # ============================================================================
+
+
+
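Editor's note: the platform and memory knobs above only take effect if they are set before JAX initializes. A minimal sketch of the environment-variable route (XLA_PYTHON_CLIENT_* and JAX_PLATFORM_NAME are standard XLA/JAX settings; the values shown are illustrative, not gsMap defaults):

    import os

    # Must run before `import jax`; XLA reads these once at startup.
    os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")  # allocate GPU memory on demand
    os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.5")   # cap at ~50% of GPU memory
    os.environ.setdefault("JAX_PLATFORM_NAME", "cpu")                # or "gpu"

    import jax
    print(jax.default_backend())  # confirm which backend was selected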
+ @jax.profiler.annotate_function
+ @partial(jit, static_argnums=(0, 1))
+ def process_chunk_jit(n_blocks: int,
+                       batch_size: int,
+                       spatial_ld: jnp.ndarray,
+                       baseline_ld_sum: jnp.ndarray,
+                       chisq: jnp.ndarray,
+                       N: jnp.ndarray,
+                       baseline_ann: jnp.ndarray,
+                       w_ld: jnp.ndarray,
+                       Nbar: float) -> tuple[jnp.ndarray, jnp.ndarray]:
+     """
+     Process an entire chunk of spots with JIT compilation and batch processing.
+     Spots are processed in batches to bound peak memory usage.
+     """
+     def process_single_spot(spot_ld):
+         """Process a single spot."""
+         # Compute initial weights
+         with jax.profiler.StepTraceAnnotation("weight_computation"):
+             x_tot = spot_ld + baseline_ld_sum
+
+             # Aggregate heritability estimate for the weight calculation
+             hsq = 10000.0 * (jnp.mean(chisq) - 1.0) / jnp.mean(x_tot * N)
+             hsq = jnp.clip(hsq, 0.0, 1.0)
+
+             # Compute weights efficiently
+             ld_clip = jnp.maximum(x_tot, 1.0)
+             w_ld_clip = jnp.maximum(w_ld, 1.0)
+             c = hsq * N / 10000.0
+             weights = jnp.sqrt(1.0 / (2 * jnp.square(1.0 + c * ld_clip) * w_ld_clip))
+
+             # Scale weights to sum to one
+             weights = weights.reshape(-1, 1)
+             weights_scaled = weights / jnp.sum(weights)
+
+         # Apply weights and combine features
+         with jax.profiler.StepTraceAnnotation("feature_preparation"):
+             x_focal = jnp.concatenate([
+                 (spot_ld.reshape(-1, 1) * weights_scaled),
+                 (baseline_ann * weights_scaled)
+             ], axis=1)
+             y_weighted = chisq.reshape(-1, 1) * weights_scaled
+
+             # Reshape for block computation
+             n_snps_used = x_focal.shape[0]
+             block_size = n_snps_used // n_blocks
+
+             x_blocks = x_focal[:n_blocks * block_size].reshape(n_blocks, block_size, -1)  # truncate to a block-aligned size
+             y_blocks = y_weighted[:n_blocks * block_size].reshape(n_blocks, block_size, -1)
+
+         # Compute block values
+         with jax.profiler.StepTraceAnnotation("block_computation"):
+             xty_blocks = jnp.einsum('nbp,nb->np', x_blocks, y_blocks.squeeze())
+             xtx_blocks = jnp.einsum('nbp,nbq->npq', x_blocks, x_blocks)
+
+         # Jackknife regression
+         with jax.profiler.StepTraceAnnotation("jackknife_regression"):
+             xty_total = jnp.sum(xty_blocks, axis=0)
+             xtx_total = jnp.sum(xtx_blocks, axis=0)
+             est = jnp.linalg.solve(xtx_total, xty_total)
+
+             # Delete-one estimates using a vectorized solve
+             xty_del = xty_total - xty_blocks
+             xtx_del = xtx_total - xtx_blocks
+             delete_ests = jnp.linalg.solve(xtx_del, xty_del[..., None]).squeeze(-1)
+
+             # Pseudovalues and standard error
+             pseudovalues = n_blocks * est - (n_blocks - 1) * delete_ests
+             jknife_est = jnp.mean(pseudovalues, axis=0)
+             jknife_cov = jnp.cov(pseudovalues.T, ddof=1) / n_blocks
+             jknife_se = jnp.sqrt(jnp.diag(jknife_cov))
+
+         # Return the spatial coefficient (first element)
+         return jknife_est[0] / Nbar, jknife_se[0] / Nbar
+
+     # Process in batches to reduce memory usage
+     n_spots = spatial_ld.shape[1]
+
+     if batch_size == 0 or batch_size >= n_spots:
+         # Process all spots at once (batch_size=0 means no batching)
+         with jax.profiler.StepTraceAnnotation("vmap_all_spots"):
+             betas, ses = vmap(process_single_spot, in_axes=1, out_axes=0)(spatial_ld)
+     else:
+         # Process in smaller batches
+         betas_list = []
+         ses_list = []
+
+         with jax.profiler.StepTraceAnnotation("batch_processing"):
+             for start_idx in range(0, n_spots, batch_size):
+                 end_idx = min(start_idx + batch_size, n_spots)
+                 batch_ld = spatial_ld[:, start_idx:end_idx]
+
+                 with jax.profiler.StepTraceAnnotation(f"vmap_batch_{start_idx}_{end_idx}"):
+                     batch_betas, batch_ses = vmap(process_single_spot, in_axes=1, out_axes=0)(batch_ld)
+                     betas_list.append(batch_betas)
+                     ses_list.append(batch_ses)
+
+         with jax.profiler.StepTraceAnnotation("concatenate_results"):
+             betas = jnp.concatenate(betas_list)
+             ses = jnp.concatenate(ses_list)
+
+     return betas, ses
+
+
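Editor's note: the per-spot regression above is a weighted LDSC fit with a delete-one block jackknife: blocks of consecutive SNPs are removed in turn, pseudovalues n_blocks * est - (n_blocks - 1) * delete_est are formed, and their mean and scaled covariance give the estimate and its standard error. A self-contained toy version of the same estimator (illustrative only, not part of the package):

    import jax.numpy as jnp

    def block_jackknife(x, y, n_blocks=4):
        """Delete-one block jackknife for least squares, mirroring the logic above."""
        n = (x.shape[0] // n_blocks) * n_blocks           # block-aligned length
        xb = x[:n].reshape(n_blocks, -1, x.shape[1])      # (blocks, block_size, p)
        yb = y[:n].reshape(n_blocks, -1)                  # (blocks, block_size)
        xty = jnp.einsum('nbp,nb->np', xb, yb)            # per-block X'y
        xtx = jnp.einsum('nbp,nbq->npq', xb, xb)          # per-block X'X
        est = jnp.linalg.solve(xtx.sum(0), xty.sum(0))    # full-sample fit
        dele = jnp.linalg.solve(xtx.sum(0) - xtx,         # fit with each block deleted
                                (xty.sum(0) - xty)[..., None]).squeeze(-1)
        pseudo = n_blocks * est - (n_blocks - 1) * dele   # jackknife pseudovalues
        se = jnp.sqrt(jnp.diag(jnp.cov(pseudo.T, ddof=1) / n_blocks))
        return pseudo.mean(0), se

    x = jnp.stack([jnp.arange(40.0), jnp.ones(40)], axis=1)
    y = 3.0 * x[:, 0] + 1.0
    est, se = block_jackknife(x, y)                       # est ~ (3.0, 1.0), se ~ 0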
+ @partial(jit, static_argnums=(0,))
+ def process_chunk_batched_jit(n_blocks: int,
+                               spatial_ld: jnp.ndarray,
+                               baseline_ld_sum: jnp.ndarray,
+                               chisq: jnp.ndarray,
+                               N: jnp.ndarray,
+                               baseline_ann: jnp.ndarray,
+                               w_ld: jnp.ndarray,
+                               Nbar: float) -> tuple[jnp.ndarray, jnp.ndarray]:
+     """
+     Process an entire chunk of spots with JIT compilation and BATCHED matrix operations.
+
+     OPTIMIZATION: Uses batched matrix operations instead of vmap to improve GPU utilization.
+     All spots are processed simultaneously using efficient matrix operations.
+
+     Args:
+         n_blocks: Number of jackknife blocks
+         spatial_ld: (n_snps, n_spots) array of spatial LD scores
+         baseline_ld_sum: (n_snps,) baseline LD scores, summed across annotations
+         chisq: (n_snps,) chi-squared statistics
+         N: (n_snps,) sample sizes
+         baseline_ann: (n_snps, n_baseline_features) baseline annotations
+         w_ld: (n_snps,) regression weights
+         Nbar: Average sample size
+
+     Returns:
+         betas: (n_spots,) regression coefficients
+         ses: (n_spots,) standard errors
+     """
+     n_snps, n_spots = spatial_ld.shape
+     n_features = 1 + baseline_ann.shape[1]  # spatial LD column plus baseline annotations
+
+     # Compute x_tot for all spots: (n_snps, n_spots)
+     x_tot = spatial_ld + baseline_ld_sum.reshape(-1, 1)
+
+     # Compute hsq for each spot: (n_spots,)
+     # hsq = 10000 * (mean(chisq) - 1) / mean(x_tot * N)
+     N_expanded = N.reshape(-1, 1)  # (n_snps, 1)
+     x_tot_N = x_tot * N_expanded  # (n_snps, n_spots)
+     mean_chisq = jnp.mean(chisq)
+     mean_x_tot_N = jnp.mean(x_tot_N, axis=0)  # (n_spots,)
+     hsq = 10000.0 * (mean_chisq - 1.0) / mean_x_tot_N  # (n_spots,)
+     hsq = jnp.clip(hsq, 0.0, 1.0)
+
+     # Compute weights for all spots: (n_snps, n_spots)
+     ld_clip = jnp.maximum(x_tot, 1.0)
+     w_ld_clip = jnp.maximum(w_ld.reshape(-1, 1), 1.0)
+     c = (hsq.reshape(1, -1) * N_expanded) / 10000.0  # (n_snps, n_spots)
+     weights = jnp.sqrt(1.0 / (2 * jnp.square(1.0 + c * ld_clip) * w_ld_clip))
+
+     # Normalize weights per spot
+     weights_sum = jnp.sum(weights, axis=0, keepdims=True)  # (1, n_spots)
+     weights_scaled = weights / weights_sum  # (n_snps, n_spots)
+
+     # Prepare features for all spots
+     # x_focal shape: (n_snps, n_spots, 1 + n_baseline_features)
+     spatial_weighted = (spatial_ld * weights_scaled)[..., None]  # (n_snps, n_spots, 1)
+     baseline_weighted = baseline_ann[:, None, :] * weights_scaled[..., None]  # (n_snps, n_spots, n_baseline)
+     x_focal = jnp.concatenate([spatial_weighted, baseline_weighted], axis=2)
+
+     # y_weighted: (n_snps, n_spots, 1)
+     y_weighted = (chisq.reshape(-1, 1) * weights_scaled)[..., None]
+
+     # Reshape for block computation
+     block_size = n_snps // n_blocks
+     n_snps_used = block_size * n_blocks
+
+     # Truncate to a block-aligned size
+     x_focal = x_focal[:n_snps_used]
+     y_weighted = y_weighted[:n_snps_used]
+
+     # Reshape: (n_blocks, block_size, n_spots, n_features)
+     x_blocks = x_focal.reshape(n_blocks, block_size, n_spots, n_features)
+     y_blocks = y_weighted.reshape(n_blocks, block_size, n_spots, 1)
+
+     # Compute block XtY and XtX for all spots simultaneously
+     # xty_blocks: (n_blocks, n_spots, n_features)
+     xty_blocks = jnp.einsum('nbsf,nbs->nsf', x_blocks, y_blocks.squeeze(-1))
+
+     # xtx_blocks: (n_blocks, n_spots, n_features, n_features)
+     xtx_blocks = jnp.einsum('nbsf,nbsg->nsfg', x_blocks, x_blocks)
+
+     # Totals across blocks
+     xty_total = jnp.sum(xty_blocks, axis=0)  # (n_spots, n_features)
+     xtx_total = jnp.sum(xtx_blocks, axis=0)  # (n_spots, n_features, n_features)
+
+     # Solve for all spots: (n_spots, n_features)
+     est = jnp.linalg.solve(xtx_total, xty_total[..., None]).squeeze(-1)
+
+     # Delete-one estimates: (n_blocks, n_spots, n_features)
+     xty_del = xty_total - xty_blocks  # (n_blocks, n_spots, n_features)
+     xtx_del = xtx_total - xtx_blocks  # (n_blocks, n_spots, n_features, n_features)
+     delete_ests = jnp.linalg.solve(xtx_del, xty_del[..., None]).squeeze(-1)
+
+     # Pseudovalues: (n_blocks, n_spots, n_features)
+     pseudovalues = n_blocks * est - (n_blocks - 1) * delete_ests
+
+     # Jackknife estimates per spot
+     jknife_est = jnp.mean(pseudovalues, axis=0)  # (n_spots, n_features)
+
+     # Jackknife covariance for each spot:
+     # center the pseudovalues, then average their outer products
+     pseudo_centered = pseudovalues - jknife_est  # broadcasts to (n_blocks, n_spots, n_features)
+
+     # Covariance: (n_spots, n_features, n_features)
+     jknife_cov = jnp.einsum('nsf,nsg->sfg', pseudo_centered, pseudo_centered) / (n_blocks * (n_blocks - 1))
+
+     # Extract the diagonal for SEs: (n_spots, n_features)
+     jknife_se = jnp.sqrt(jnp.diagonal(jknife_cov, axis1=1, axis2=2))
+
+     # Return the spatial coefficient (first feature) for all spots
+     return jknife_est[:, 0] / Nbar, jknife_se[:, 0] / Nbar
+
+
+
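Editor's note: the two einsum contractions above are the heart of the batched path: for every (block, spot) pair they build X'y and X'X without materializing per-spot design matrices. A toy equivalence check against an explicit loop (illustrative shapes; not part of the package):

    import numpy as np
    import jax.numpy as jnp

    n_blocks, block_size, n_spots, n_features = 3, 5, 2, 4
    rng = np.random.default_rng(0)
    x = jnp.asarray(rng.normal(size=(n_blocks, block_size, n_spots, n_features)))
    y = jnp.asarray(rng.normal(size=(n_blocks, block_size, n_spots)))

    # Batched contractions, as in process_chunk_batched_jit
    xty = jnp.einsum('nbsf,nbs->nsf', x, y)
    xtx = jnp.einsum('nbsf,nbsg->nsfg', x, x)

    # Explicit per-(block, spot) equivalents
    xty_ref = jnp.stack([jnp.stack([x[n, :, s].T @ y[n, :, s] for s in range(n_spots)])
                         for n in range(n_blocks)])
    xtx_ref = jnp.stack([jnp.stack([x[n, :, s].T @ x[n, :, s] for s in range(n_spots)])
                         for n in range(n_blocks)])
    assert jnp.allclose(xty, xty_ref, atol=1e-5) and jnp.allclose(xtx, xtx_ref, atol=1e-5)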
+ def wrapper_of_process_chunk_jit(*args, **kwargs):
+     """Dispatch to the JIT-compiled batched chunk processor."""
+     # return process_chunk_jit(*args, **kwargs)  # per-spot vmap alternative
+     return process_chunk_batched_jit(*args, **kwargs)
+
+
+ # ============================================================================
+ # Main entry point
+ # ============================================================================
+
+ def run_spatial_ldsc_jax(config: SpatialLDSCConfig):
+     """
+     Run spatial LDSC for all traits in config.sumstats_config_dict.
+     """
+     if config.marker_score_format not in ["memmap", "h5ad", "feather"]:
+         raise NotImplementedError(f"Marker score format '{config.marker_score_format}' is not supported. Only 'memmap', 'h5ad', and 'feather' are supported.")
+
+     traits_to_process = list(config.sumstats_config_dict.items())
+     if not traits_to_process:
+         raise ValueError("No traits to process. config.sumstats_config_dict is empty.")
+
+     # Create the output directory
+     output_dir = config.ldsc_save_dir
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     # Determine the number of loader threads based on the platform
+     n_loader_threads = 10 if jax.default_backend() == 'gpu' else 2
+
+     # Load marker scores once (format-agnostic)
+     logger.info(f"Loading marker scores (format: {config.marker_score_format})...")
+     marker_score_adata = None
+
+     try:
+         if config.marker_score_format == "memmap":
+             marker_score_adata = load_marker_scores_memmap_format(config)
+
+         elif config.marker_score_format == "feather":
+
+             feather_path = Path(config.marker_score_feather_path)
+             logger.info(f"Loading marker scores from Feather: {feather_path}")
+             # Use the specialized FeatherAnnData wrapper
+             marker_score_adata = FeatherAnnData(feather_path, index_col='HUMAN_GENE_SYM', transpose=True)
+
+         elif config.marker_score_format == "h5ad":
+             if not config.marker_score_h5ad_path:
+                 raise ValueError("marker_score_h5ad_path must be provided when marker_score_format is 'h5ad'")
+
+             h5ad_path = Path(config.marker_score_h5ad_path)
+             if not h5ad_path.exists():
+                 raise FileNotFoundError(f"Marker score H5AD file not found: {h5ad_path}")
+
+             logger.info(f"Loading marker scores from H5AD: {h5ad_path}")
+             marker_score_adata = ad.read_h5ad(h5ad_path, backed='r')
+
+         # Load common resources once (baseline, weights, snp_gene_weights)
+         baseline_ld, w_ld, snp_gene_weight_adata = load_common_resources(config)
+
+         # Initialize the processor with the common resources
+         logger.debug("Initializing processor...")
+         processor = SpatialLDSCProcessor(
+             config=config,
+             output_dir=output_dir,
+             marker_score_adata=marker_score_adata,
+             snp_gene_weight_adata=snp_gene_weight_adata,
+             baseline_ld=baseline_ld,
+             w_ld=w_ld,
+             n_loader_threads=n_loader_threads
+         )
+
+         try:
+             for idx, (trait_name, sumstats_file) in enumerate(traits_to_process):
+                 logger.info("=" * 70)
+                 logger.info("Running Spatial LDSC (JAX Implementation)")
+                 logger.info(f"Project: {config.project_name}, Trait: {trait_name} ({idx+1}/{len(traits_to_process)})")
+                 if config.sample_filter:
+                     logger.info(f"Sample filter: {config.sample_filter}")
+                 if config.cell_indices_range:
+                     logger.info(f"Cell indices range: {config.cell_indices_range}")
+                 logger.info("=" * 70)
+
+                 # Skip if the output already exists
+                 expected_filename = generate_expected_output_filename(config, trait_name)
+                 if expected_filename is not None:
+                     expected_output_path = output_dir / expected_filename
+                     if expected_output_path.exists():
+                         logger.info(f"Output file already exists: {expected_output_path}")
+                         logger.info(f"Skipping trait {trait_name} ({idx+1}/{len(traits_to_process)})")
+
+                         # Log statistics from the existing result
+                         log_existing_result_statistics(expected_output_path, trait_name)
+                         continue
+
+                 # Set up the processor for the current trait
+                 processor.setup_trait(trait_name, sumstats_file)
+
+                 # Process all chunks for the current trait
+                 start_time = time.time()
+                 processor.process_all_chunks(wrapper_of_process_chunk_jit)
+
+                 elapsed_time = time.time() - start_time
+                 h, rem = divmod(elapsed_time, 3600)
+                 m, s = divmod(rem, 60)
+                 logger.info(f"Trait {trait_name} completed in {int(h)}h {int(m)}m {s:.2f}s")
+
+         finally:
+             # Clean up once: close memmap/adata resources if needed
+             if marker_score_adata is not None:
+                 logger.info("Closing marker score resources...")
+                 # If it is the memmap wrapper, close it explicitly
+                 if config.marker_score_format == "memmap" and 'memmap_manager' in marker_score_adata.uns:
+                     marker_score_adata.uns['memmap_manager'].close()
+                 # If it is backed AnnData, close the underlying file
+                 if config.marker_score_format == "h5ad" and marker_score_adata.isbacked:
+                     marker_score_adata.file.close()
+
+     except Exception as e:
+         logger.error(f"An error occurred during execution: {e}")
+         raise
+
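Editor's note: for orientation, a hedged usage sketch of the entry point. Only fields that actually appear in this diff are used (project_name, marker_score_format, marker_score_h5ad_path, ldsc_save_dir, sumstats_config_dict); the real constructor and defaults live in gsMap/config/spatial_ldsc_config.py and may differ:

    from pathlib import Path

    from gsMap.config import SpatialLDSCConfig
    from gsMap.spatial_ldsc.spatial_ldsc_jax import run_spatial_ldsc_jax

    # Hypothetical construction; field names are taken from this diff,
    # but the actual signature is defined in gsMap/config/.
    config = SpatialLDSCConfig(
        project_name="demo",
        marker_score_format="h5ad",
        marker_score_h5ad_path="out/marker_scores.h5ad",
        ldsc_save_dir=Path("out/ldsc"),
        sumstats_config_dict={"height": "data/height_sumstats.txt.gz"},
    )
    run_spatial_ldsc_jax(config)  # writes one result file per trait under out/ldsc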