gsMap3D 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/latent2gene/row_ordering.py
ADDED
@@ -0,0 +1,182 @@
+"""
+Row ordering optimization for cache efficiency
+"""
+
+import logging
+
+import numpy as np
+from numba import njit
+
+logger = logging.getLogger(__name__)
+
+
+@njit
+def compute_jaccard_similarity(neighbors_a: np.ndarray, neighbors_b: np.ndarray) -> float:
+    """Compute Jaccard similarity between two neighbor sets"""
+    set_a = set(neighbors_a)
+    set_b = set(neighbors_b)
+    intersection = len(set_a & set_b)
+    union = len(set_a | set_b)
+    return intersection / union if union > 0 else 0.0
+
+
+def optimize_row_order(
+    neighbor_indices: np.ndarray,
+    cell_indices: np.ndarray,
+    method: str | None = None,
+    neighbor_weights: np.ndarray | None = None
+) -> np.ndarray:
+    """
+    Sort rows by shared neighbors to improve cache locality
+
+    Args:
+        neighbor_indices: (n_cells, k) array of neighbor indices (global indices)
+        cell_indices: (n_cells,) array of global indices for each cell
+        method: None (auto), 'weighted', 'greedy', or 'none'
+        neighbor_weights: Optional (n_cells, k) array of weights for each neighbor
+
+    Returns:
+        Reordered row indices (local indices 0 to n_cells-1)
+
+    Complexity:
+        - weighted: O(n*k) where k is number of neighbors - very efficient!
+        - greedy: O(n²) - only for very small datasets
+        - none: O(1) - returns original order
+    """
+    n_cells = len(neighbor_indices)
+
+    # Create mapping from global to local indices
+    global_to_local = {global_idx: local_idx for local_idx, global_idx in enumerate(cell_indices)}
+
+    # Auto-select method if None
+    if method is None:
+        if neighbor_weights is not None and n_cells > 2000:
+            method = 'weighted'
+        else:
+            method = 'greedy'
+
+    if method == 'weighted' and neighbor_weights is not None:
+        # Efficient weighted heuristic: follow highest-weight neighbors
+        visited = np.zeros(n_cells, dtype=bool)
+        ordered = []
+
+        # Start with cell that has highest max weight to any single neighbor
+        max_weights = neighbor_weights.max(axis=1)
+        current = np.argmax(max_weights)
+
+        ordered.append(current)
+        visited[current] = True
+
+        # Build reverse lookup (using local indices)
+        reverse_neighbors = [[] for _ in range(n_cells)]
+        for i in range(n_cells):
+            for j, neighbor_global_idx in enumerate(neighbor_indices[i]):
+                if neighbor_global_idx in global_to_local:
+                    neighbor_local_idx = global_to_local[neighbor_global_idx]
+                    reverse_neighbors[neighbor_local_idx].append((i, neighbor_weights[i, j]))
+
+        # Process all cells
+        for _ in range(n_cells - 1):
+            neighbors = neighbor_indices[current]
+            weights = neighbor_weights[current]
+
+            # Sort neighbors by weight (highest first)
+            sorted_idx = np.argsort(weights)[::-1]
+
+            # Find the unvisited neighbor with highest weight
+            next_cell = None
+            best_weight = -1
+
+            for idx in sorted_idx:
+                neighbor_global_idx = neighbors[idx]
+                if neighbor_global_idx in global_to_local:
+                    neighbor_local_idx = global_to_local[neighbor_global_idx]
+                    if not visited[neighbor_local_idx]:
+                        if weights[idx] > best_weight:
+                            best_weight = weights[idx]
+                            next_cell = neighbor_local_idx
+
+            # If no unvisited direct neighbors, find closest unvisited cell
+            if next_cell is None:
+                connection_scores = np.zeros(n_cells)
+
+                # Check connections from last few visited cells
+                for cell_idx in ordered[-min(10, len(ordered)):]:
+                    # Add forward connections
+                    for j, neighbor_global_idx in enumerate(neighbor_indices[cell_idx]):
+                        if neighbor_global_idx in global_to_local:
+                            neighbor_local_idx = global_to_local[neighbor_global_idx]
+                            if not visited[neighbor_local_idx]:
+                                connection_scores[neighbor_local_idx] += neighbor_weights[cell_idx, j]
+
+                    # Add reverse connections
+                    for reverse_idx, reverse_weight in reverse_neighbors[cell_idx]:
+                        if not visited[reverse_idx]:
+                            connection_scores[reverse_idx] += reverse_weight
+
+                # Pick the unvisited cell with highest connection score
+                if connection_scores.max() > 0:
+                    unvisited_scores = connection_scores.copy()
+                    unvisited_scores[visited] = -1
+                    next_cell = np.argmax(unvisited_scores)
+                else:
+                    # No connections found, pick cell with highest max weight
+                    unvisited_max_weights = max_weights.copy()
+                    unvisited_max_weights[visited] = -1
+                    next_cell = np.argmax(unvisited_max_weights)
+
+            ordered.append(next_cell)
+            visited[next_cell] = True
+            current = next_cell
+
+        return np.array(ordered)
+
+    elif method == 'greedy':
+        # Original greedy approach - only for small datasets
+        if n_cells > 5000:
+            logger.warning(f"Greedy method is O(n²) - not recommended for {n_cells} cells")
+
+        # Convert neighbor indices to sets of local indices for Jaccard computation
+        neighbor_sets = []
+        for i in range(n_cells):
+            local_neighbors = set()
+            for neighbor_global_idx in neighbor_indices[i]:
+                if neighbor_global_idx in global_to_local:
+                    local_neighbors.add(global_to_local[neighbor_global_idx])
+            neighbor_sets.append(local_neighbors)
+
+        remaining = set(range(n_cells))
+        ordered = []
+
+        current = np.random.choice(list(remaining))
+        ordered.append(current)
+        remaining.remove(current)
+
+        while remaining:
+            max_sim = -1
+            next_row = None
+
+            # Sample candidates for large datasets
+            candidates = list(remaining)
+            if len(candidates) > 100:
+                candidates = np.random.choice(candidates, min(100, len(candidates)), replace=False)
+
+            for candidate in candidates:
+                # Use precomputed sets for Jaccard similarity
+                intersection = len(neighbor_sets[current] & neighbor_sets[candidate])
+                union = len(neighbor_sets[current] | neighbor_sets[candidate])
+                sim = intersection / union if union > 0 else 0.0
+
+                if sim > max_sim:
+                    max_sim = sim
+                    next_row = candidate
+
+            ordered.append(next_row)
+            remaining.remove(next_row)
+            current = next_row
+
+        return np.array(ordered)
+
+    else:
+        # Simple sequential order as fallback
+        return np.arange(n_cells)
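A minimal usage sketch of `optimize_row_order` (not from the package; the global index range, neighbor count, and weights below are invented for illustration):

```python
import numpy as np
from gsMap.latent2gene.row_ordering import optimize_row_order

rng = np.random.default_rng(0)
n_cells, k = 500, 15

# Hypothetical slab of a larger dataset: these 500 cells carry global IDs 1000..1499
cell_indices = np.arange(1000, 1000 + n_cells)
# k-NN table in global coordinates, plus a weight per neighbor
neighbor_indices = rng.choice(cell_indices, size=(n_cells, k))
neighbor_weights = rng.random((n_cells, k)).astype(np.float32)

order = optimize_row_order(
    neighbor_indices,
    cell_indices,
    method="weighted",
    neighbor_weights=neighbor_weights,
)
# The result is a permutation of the local indices 0..n_cells-1
assert sorted(order) == list(range(n_cells))
```

Cells that share high-weight neighbors end up adjacent in the returned order, so batched reads over the reordered rows hit overlapping neighbor data.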
gsMap/latent2gene/row_ordering_jax.py
ADDED
@@ -0,0 +1,159 @@
+"""JAX-accelerated row ordering optimization using jax.lax.scan"""
+
+import logging
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+
+# @partial(jit, static_argnames=['k'])
+def _optimize_weighted_scan(
+    neighbor_indices: jnp.ndarray,
+    neighbor_weights: jnp.ndarray,
+    cell_indices: jnp.ndarray,
+    k: int
+) -> jnp.ndarray:
+    """
+    JIT-compiled weighted row ordering using jax.lax.scan for optimal performance.
+
+    Args:
+        neighbor_indices: (n_cells, k) neighbor indices (global)
+        neighbor_weights: (n_cells, k) neighbor weights
+        cell_indices: (n_cells,) global cell indices
+        k: Number of neighbors per cell
+
+    Returns:
+        Reordered row indices (local 0..n-1)
+    """
+    n_cells = len(neighbor_indices)
+
+    # Pre-compute global to local mapping
+    # Handle -1 values in neighbor_indices (invalid neighbors)
+    valid_mask = neighbor_indices >= 0
+    max_global_idx = jnp.max(jnp.where(valid_mask, neighbor_indices, 0))
+
+    # Create inverse mapping from global to local indices
+    inverse_map = jnp.full(max_global_idx + 1, -1, dtype=jnp.int32)
+    inverse_map = inverse_map.at[cell_indices].set(jnp.arange(n_cells))
+
+    # Convert to local indices, preserving -1 for invalid neighbors
+    neighbor_indices_flat = neighbor_indices.ravel()
+    # For invalid indices (-1), keep them as -1
+    # For valid indices, map them through inverse_map
+    local_indices_flat = jnp.where(
+        neighbor_indices_flat >= 0,
+        inverse_map[jnp.where(neighbor_indices_flat >= 0, neighbor_indices_flat, 0)],
+        -1
+    )
+    local_neighbor_indices = local_indices_flat.reshape(n_cells, k)
+
+    # Initialize with highest weight cell
+    # Only consider valid neighbors when computing max weights
+    weights_masked = jnp.where(valid_mask, neighbor_weights, 0)
+    max_weights = jnp.max(weights_masked, axis=1)
+    start_node = jnp.argmax(max_weights)
+
+    # Initial state for scan
+    initial_state = {
+        "current": start_node,
+        "visited": jnp.zeros(n_cells, dtype=jnp.bool_).at[start_node].set(True),
+        "ordered": jnp.full(n_cells, -1, dtype=jnp.int32).at[0].set(start_node),
+        "max_weights": max_weights,
+        "local_neighbor_indices": local_neighbor_indices,
+        "neighbor_weights": neighbor_weights
+    }
+
+    def scan_step(state, t):
+        """Single step of the greedy ordering algorithm"""
+        current = state["current"]
+        visited = state["visited"]
+
+        # Find best unvisited neighbor
+        neighbors = state["local_neighbor_indices"][current]
+        weights = state["neighbor_weights"][current]
+
+        # Handle -1 values: create a mask for valid neighbors
+        is_valid = neighbors != -1
+        # For invalid neighbors, use False for is_unvisited to avoid indexing with -1
+        is_unvisited = jnp.where(is_valid, ~visited[jnp.where(is_valid, neighbors, 0)], False)
+        mask = is_valid & is_unvisited
+
+        neighbor_scores = jnp.where(mask, weights, -jnp.inf)
+        best_idx = jnp.argmax(neighbor_scores)
+
+        # If neighbor found, use it; otherwise pick highest weight unvisited
+        has_neighbor = neighbor_scores[best_idx] > -jnp.inf
+        next_neighbor = neighbors[best_idx]
+
+        # Fallback to highest weight unvisited cell
+        unvisited_weights = jnp.where(visited, -jnp.inf, state["max_weights"])
+        next_maxweight = jnp.argmax(unvisited_weights)
+
+        next_cell = jnp.where(has_neighbor, next_neighbor, next_maxweight)
+
+        # Update state
+        new_state = {
+            "current": next_cell,
+            "visited": visited.at[next_cell].set(True),
+            "ordered": state["ordered"].at[t].set(next_cell),
+            "max_weights": state["max_weights"],
+            "local_neighbor_indices": state["local_neighbor_indices"],
+            "neighbor_weights": state["neighbor_weights"]
+        }
+
+        return new_state, None
+
+    # Run scan for n_cells-1 iterations (first cell already placed)
+    final_state, _ = jax.lax.scan(scan_step, initial_state, jnp.arange(1, n_cells))
+
+    return final_state["ordered"]
+
+
+def optimize_row_order_jax(
+    neighbor_indices: np.ndarray,
+    cell_indices: np.ndarray,
+    neighbor_weights: np.ndarray | None,
+    device: jax.Device | None = None
+) -> np.ndarray:
+    """
+    High-performance JAX-based row ordering for cache efficiency.
+
+    Args:
+        neighbor_indices: (n_cells, k) neighbor indices (global)
+        cell_indices: (n_cells,) global cell indices
+        neighbor_weights: (n_cells, k) neighbor weights
+        device: JAX device or None (uses default JAX device)
+
+    Returns:
+        Reordered row indices (local 0..n-1) as NumPy array
+    """
+    # Use default device if not provided (should be already configured by configure_jax_platform)
+    if device is None:
+        device = jax.devices()[0]
+
+    # Skip if on CPU - scanning is extremely slow on CPU
+    if device.platform == 'cpu':
+        n_cells = len(neighbor_indices)
+        logger.info(f"Skipping JAX-based row ordering optimization on CPU for {n_cells} cells (too slow).")
+        return np.arange(n_cells)
+
+    n_cells, k = neighbor_indices.shape
+    logger.debug(f"Running JAX scan-based weighted ordering for {n_cells} cells on {device.platform}")
+
+    with jax.default_device(device):
+        neighbor_indices_jax = jnp.asarray(neighbor_indices)
+        neighbor_weights_jax = jnp.asarray(neighbor_weights)
+        cell_indices_jax = jnp.asarray(cell_indices)
+
+        # Run optimized weighted ordering
+        ordered_jax = _optimize_weighted_scan(
+            neighbor_indices_jax,
+            neighbor_weights_jax,
+            cell_indices_jax,
+            k
+        )
+
+    return np.array(ordered_jax)
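The scan-carry pattern above is easier to see in miniature. Below is a toy sketch (not from the package: dense pairwise weights instead of a k-NN table, and no -1 padding or global-to-local remapping) of the same idea, where `jax.lax.scan` threads the visited mask and partial ordering through each greedy step:

```python
import jax
import jax.numpy as jnp

# Toy dense pairwise weights for 3 nodes (illustrative values)
weights = jnp.array([
    [0.0, 0.9, 0.1],
    [0.9, 0.0, 0.8],
    [0.1, 0.8, 0.0],
])
n = weights.shape[0]

def step(state, t):
    current, visited, ordered = state
    # Mask out visited nodes, then take the heaviest remaining one
    scores = jnp.where(visited, -jnp.inf, weights[current])
    nxt = jnp.argmax(scores)
    return (nxt, visited.at[nxt].set(True), ordered.at[t].set(nxt)), None

init = (
    jnp.int32(0),                                   # start at node 0
    jnp.zeros(n, dtype=bool).at[0].set(True),       # visited mask
    jnp.full(n, -1, dtype=jnp.int32).at[0].set(0),  # ordering buffer
)
(last, visited, order), _ = jax.lax.scan(step, init, jnp.arange(1, n))
print(order)  # [0 1 2]: node 0, its heaviest neighbor 1, then 2
```

The production version additionally falls back to the globally heaviest unvisited cell when the current cell has no unvisited neighbors, which keeps every scan step shape-stable and compilable.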
gsMap/ldscore/__init__.py
ADDED
@@ -0,0 +1 @@
+from gsMap.config import LDScoreConfig
gsMap/ldscore/batch_construction.py
ADDED
@@ -0,0 +1,163 @@
+"""
+Simplified batch construction for LD score calculation.
+
+This module handles:
+1. Splitting HM3 SNPs into fixed-size batches
+2. Calculating reference block boundaries for each batch based on LD window
+"""
+
+import logging
+from dataclasses import dataclass
+
+import numpy as np
+import pandas as pd
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class BatchInfo:
+    """
+    Information for a single HM3 batch.
+
+    Attributes
+    ----------
+    chromosome : str
+        Chromosome identifier
+    hm3_indices : np.ndarray
+        Indices of HM3 SNPs in this batch (relative to chromosome BIM)
+    ref_start_idx : int
+        Start index of reference block in chromosome (inclusive)
+    ref_end_idx : int
+        End index of reference block in chromosome (exclusive)
+    """
+
+    chromosome: str
+    hm3_indices: np.ndarray
+    ref_start_idx: int
+    ref_end_idx: int
+
+
+def construct_batches(
+    bim_df: pd.DataFrame,
+    hm3_snp_names: list[str],
+    batch_size_hm3: int,
+    ld_wind: float = 1.0,
+    ld_unit: str = "CM",
+) -> list[BatchInfo]:
+    """
+    Construct batches of HM3 SNPs with reference block boundaries.
+
+    Optimized implementation using searchsorted for O(log N) boundary finding.
+
+    Parameters
+    ----------
+    bim_df : pd.DataFrame
+        BIM dataframe for this chromosome
+    hm3_snp_names : List[str]
+        List of HM3 SNP names for this chromosome
+    batch_size_hm3 : int
+        Number of HM3 SNPs per batch
+    ld_wind : float
+        LD window size (default: 1.0)
+    ld_unit : str
+        Unit for LD window: 'SNP', 'KB', 'CM' (default: 'CM')
+
+    Returns
+    -------
+    List[BatchInfo]
+        List of BatchInfo objects
+    """
+    chromosome = str(bim_df["CHR"].iloc[0])
+
+    # Pre-extract arrays for performance
+    bim_df["SNP"].values
+
+    # helper to get coordinates based on unit
+    if ld_unit == "SNP":
+        coords = np.arange(len(bim_df))
+        max_dist = int(ld_wind)
+    elif ld_unit == "KB":
+        coords = bim_df["BP"].values
+        max_dist = ld_wind * 1000
+    elif ld_unit == "CM":
+        coords = bim_df["CM"].values
+        max_dist = ld_wind
+        # Fallback if CM is all zero
+        if np.all(coords == 0):
+            logger.warning(f"All CM values are 0 for chromosome {chromosome}. Fallback to 1MB window (BP).")
+            coords = bim_df["BP"].values
+            max_dist = 1_000_000
+    else:
+        raise ValueError(f"Invalid ld_unit: {ld_unit}. Must be 'SNP', 'KB', or 'CM'.")
+
+    # Find indices of HM3 SNPs in BIM
+    hm3_set = set(hm3_snp_names)
+    # np.isin on strings is very slow, use pandas isin instead
+    hm3_mask = bim_df["SNP"].isin(hm3_set).values
+    hm3_indices_all = np.where(hm3_mask)[0]
+
+    n_hm3 = len(hm3_indices_all)
+    if n_hm3 == 0:
+        logger.warning(f"No HM3 SNPs found in chromosome {chromosome}")
+        return []
+    if n_hm3 < len(hm3_snp_names):
+        logger.warning(f"{len(hm3_snp_names) - n_hm3} HM3 SNPs not found in chromosome {chromosome} reference plink panel")
+
+    logger.info(f"Found {n_hm3} HM3 SNPs in chromosome {chromosome}")
+
+    # Calculate batch boundaries
+    # We want batches of size batch_size_hm3
+    # Start indices: 0, batch_size, 2*batch_size, ...
+    batch_starts = np.arange(0, n_hm3, batch_size_hm3)
+    # End indices: batch_size, 2*batch_size, ..., n_hm3
+    batch_ends = np.minimum(batch_starts + batch_size_hm3, n_hm3)
+
+    n_batches = len(batch_starts)
+
+    # Get indices in BIM for start and end of each batch
+    # hm3_indices_all contains the BIM indices of HM3 SNPs
+    # We need the BIM index of the first and last HM3 SNP in each batch
+
+    # First HM3 SNP in each batch
+    batch_start_bim_indices = hm3_indices_all[batch_starts]
+    # Last HM3 SNP in each batch (indices are exclusive in slice, so -1 for element access)
+    batch_end_bim_indices = hm3_indices_all[batch_ends - 1]
+
+    # Get coordinates for these batch boundaries
+    min_coords = coords[batch_start_bim_indices]
+    max_coords = coords[batch_end_bim_indices]
+
+    # Calculate window boundaries
+    window_starts = min_coords - max_dist
+    window_ends = max_coords + max_dist
+
+    if ld_unit == "SNP":
+        # For SNP unit, coordinates are indices, so we just clip
+        ref_starts = np.maximum(window_starts, 0).astype(int)
+        ref_ends = np.minimum(window_ends, len(coords)).astype(int)
+    else:
+        # Vectorized searchsorted
+        # Find insertion points for all window starts and ends at once
+        ref_starts = np.searchsorted(coords, window_starts, side='left')
+        ref_ends = np.searchsorted(coords, window_ends, side='right')
+
+    # Construct BatchInfo objects
+    batch_infos = []
+    for i in range(n_batches):
+        # Slice the HM3 indices for this batch
+        b_hm3_indices = hm3_indices_all[batch_starts[i]:batch_ends[i]]
+
+        batch_infos.append(BatchInfo(
+            chromosome=chromosome,
+            hm3_indices=b_hm3_indices,
+            ref_start_idx=int(ref_starts[i]),
+            ref_end_idx=int(ref_ends[i]),
+        ))
+
+    logger.info(f"Created {len(batch_infos)} batches for chromosome {chromosome}")
+    if len(batch_infos) > 0:
+        avg_size = np.mean(ref_ends - ref_starts)
+        logger.info(f" Average reference block size: {avg_size:.0f} SNPs")
+
+    return batch_infos
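A toy run of `construct_batches` (all values invented): ten SNPs on one chromosome with a made-up 1 cM spacing, four of them in the HM3 panel, batches of two HM3 SNPs, and a ±2 cM reference window:

```python
import numpy as np
import pandas as pd
from gsMap.ldscore.batch_construction import construct_batches

bim_df = pd.DataFrame({
    "CHR": ["22"] * 10,
    "SNP": [f"rs{i}" for i in range(10)],
    "CM": np.arange(10, dtype=float),  # genetic positions: 0, 1, ..., 9 cM
    "BP": np.arange(10) * 50_000,
})
hm3 = ["rs2", "rs3", "rs6", "rs7"]

batches = construct_batches(bim_df, hm3, batch_size_hm3=2, ld_wind=2.0, ld_unit="CM")
for b in batches:
    print(b.hm3_indices, b.ref_start_idx, b.ref_end_idx)
# [2 3] 0 6   <- reference block: SNPs within 2 cM of the batch's HM3 SNPs
# [6 7] 4 10
```

Because `searchsorted` only needs the sorted coordinate array, each batch's reference block is located in O(log N) rather than by scanning the chromosome.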
gsMap/ldscore/compute.py
ADDED
@@ -0,0 +1,126 @@
+"""
+Simplified LD score computation without masking or padding.
+
+Direct computation of unbiased L2 statistics from genotype matrices using NumPy and Scipy.
+"""
+
+
+import numpy as np
+import scipy.sparse
+
+from .constants import LDSC_BIAS_CORRECTION_DF
+
+
+def compute_unbiased_l2_batch(
+    X_hm3: np.ndarray,
+    X_ref_block: np.ndarray,
+) -> np.ndarray:
+    """
+    Compute unbiased LD scores (L2) for HM3 SNPs against reference block.
+
+    The unbiased L2 estimator is:
+        L2 = r^2 - (1 - r^2) / (N - 2)
+
+    where r^2 is the squared correlation and N is the number of individuals.
+
+    Parameters
+    ----------
+    X_hm3 : np.ndarray
+        Standardized genotypes for HM3 SNPs, shape (n_individuals, n_hm3_snps)
+    X_ref_block : np.ndarray
+        Standardized genotypes for reference block, shape (n_individuals, n_ref_snps)
+
+    Returns
+    -------
+    np.ndarray
+        Unbiased LD scores, shape (n_hm3_snps, n_ref_snps)
+    """
+    n_individuals = X_hm3.shape[0]
+
+    # Compute correlation matrix: r = (1/N) * X_hm3^T @ X_ref_block
+    # shape: (n_hm3_snps, n_ref_snps)
+    r = np.dot(X_hm3.T, X_ref_block) / n_individuals
+
+    # Compute r^2
+    r_squared = r ** 2
+
+    # Apply bias correction
+    bias_correction = (1.0 - r_squared) / (n_individuals - LDSC_BIAS_CORRECTION_DF)
+    l2_unbiased = r_squared - bias_correction
+
+    return l2_unbiased
+
+
+def compute_ld_scores(
+    X_hm3: np.ndarray,
+    X_ref_block: np.ndarray,
+) -> np.ndarray:
+    """
+    Compute LD scores by summing unbiased L2 over reference SNPs.
+
+    Parameters
+    ----------
+    X_hm3 : np.ndarray
+        Standardized genotypes for HM3 SNPs, shape (n_individuals, n_hm3_snps)
+    X_ref_block : np.ndarray
+        Standardized genotypes for reference block, shape (n_individuals, n_ref_snps)
+
+    Returns
+    -------
+    np.ndarray
+        LD scores for each HM3 SNP, shape (n_hm3_snps,)
+    """
+    # Compute unbiased L2 matrix
+    l2_unbiased = compute_unbiased_l2_batch(X_hm3, X_ref_block)
+
+    # Sum over reference SNPs (axis=1)
+    ld_scores = np.sum(l2_unbiased, axis=1)
+
+    return ld_scores
+
+
+def compute_batch_weights_sparse(
+    X_hm3: np.ndarray,
+    X_ref_block: np.ndarray,
+    block_mapping_matrix: scipy.sparse.csr_matrix | np.ndarray,
+) -> np.ndarray:
+    """
+    Compute LD score weight matrix using matrix multiplication.
+
+    Works for both sparse and dense mapping matrices.
+    Weights = L2_Unbiased @ Mapping_Matrix
+
+    Parameters
+    ----------
+    X_hm3 : np.ndarray
+        Standardized genotypes for HM3 SNPs, shape (n_individuals, n_hm3_snps)
+    X_ref_block : np.ndarray
+        Standardized genotypes for reference block, shape (n_individuals, n_ref_snps)
+    block_mapping_matrix : Union[scipy.sparse.csr_matrix, np.ndarray]
+        Mapping matrix for the reference block, shape (n_ref_snps, n_features).
+        Can be a sparse CSR matrix (from creating_snp_feature_map) or a dense array (from annotations).
+
+    Returns
+    -------
+    weights : np.ndarray
+        Weight matrix, shape (n_hm3_snps, n_features)
+    """
+    # 1. Compute unbiased L2 matrix: (n_hm3_snps, n_ref_snps)
+    l2_unbiased = compute_unbiased_l2_batch(X_hm3, X_ref_block)
+
+    # 2. Compute Weights: W = L2 @ M
+    # (n_hm3, n_ref) @ (n_ref, n_features) -> (n_hm3, n_features)
+    # numpy dot handles dense @ dense
+    # scipy.sparse handles dense @ sparse
+    if scipy.sparse.issparse(block_mapping_matrix):
+        weights = l2_unbiased @ block_mapping_matrix
+    else:
+        weights = np.dot(l2_unbiased, block_mapping_matrix)
+
+    # Ensure output is a dense numpy array
+    if scipy.sparse.issparse(weights):
+        weights = weights.toarray()
+    elif isinstance(weights, np.matrix):
+        weights = np.asarray(weights)
+
+    return weights
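A quick numerical sketch of these functions on synthetic data (shapes and names are illustrative). Columns are standardized to mean 0 and unit variance, as the estimator assumes, and each LD score is sum_j [ r_ij^2 - (1 - r_ij^2) / (N - 2) ] over the reference block:

```python
import numpy as np
import scipy.sparse
from gsMap.ldscore.compute import compute_batch_weights_sparse, compute_ld_scores

rng = np.random.default_rng(0)
n, m_ref = 1_000, 50

# Standardized genotype matrix for the reference block
X_ref = rng.normal(size=(n, m_ref))
X_ref = (X_ref - X_ref.mean(axis=0)) / X_ref.std(axis=0)
X_hm3 = X_ref[:, :5]  # treat the first 5 reference SNPs as the HM3 panel

scores = compute_ld_scores(X_hm3, X_ref)
print(scores.shape)  # (5,) -- one LD score per HM3 SNP

# Per-feature weights through a sparse SNP-to-feature mapping (8 dummy features)
mapping = scipy.sparse.random(m_ref, 8, density=0.1, format="csr", random_state=0)
weights = compute_batch_weights_sparse(X_hm3, X_ref, mapping)
print(weights.shape)  # (5, 8)
```

For unrelated SNPs, E[r^2] is about 1/N, and the (1 - r^2)/(N - 2) term removes exactly that finite-sample inflation, so here the summed scores stay near 1 (each SNP's correlation with itself) instead of growing with the block size.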
gsMap/ldscore/constants.py
ADDED
@@ -0,0 +1,70 @@
+"""
+Constants and magic numbers used in LD score regression and spatial LDSC.
+
+This module centralizes all numerical constants to improve code readability
+and maintainability.
+"""
+
+# === LD Score Calculation Constants ===
+
+# Degrees of freedom for unbiased L2 estimator bias correction
+# The unbiased estimator is: r^2 - (1 - r^2) / (N - LDSC_BIAS_CORRECTION_DF)
+LDSC_BIAS_CORRECTION_DF = 2.0
+
+# Default LD window sizes
+DEFAULT_LD_WINDOW_CM = 1.0  # centiMorgans
+DEFAULT_LD_WINDOW_KB = 1000  # kilobases
+DEFAULT_LD_WINDOW_MB = 1.0  # megabases (in base pairs)
+
+# === Genomic Control Constants ===
+
+# Median of chi-squared distribution with 1 degree of freedom
+# Used for calculating genomic control lambda (λGC):
+# λGC = median(χ²) / CHI_SQUARED_1DF_MEDIAN
+# where χ² = Z^2 from GWAS summary statistics
+CHI_SQUARED_1DF_MEDIAN = 0.4549364
+
+# === Statistical Thresholds ===
+
+# Standard genome-wide significance level (GWAS)
+GWAS_SIGNIFICANCE_ALPHA = 5e-8
+
+# Standard nominal significance level
+NOMINAL_SIGNIFICANCE_ALPHA = 0.05
+
+# FDR significance threshold (commonly used for spatial analysis)
+FDR_SIGNIFICANCE_ALPHA = 0.001
+
+# === Numerical Stability Constants ===
+
+# Minimum p-value for log transformation (prevents log(0) = -inf)
+# log10(1e-300) ≈ -300, which is reasonable for visualization
+MIN_P_VALUE = 1e-300
+
+# Maximum p-value (ceiling for numerical stability)
+MAX_P_VALUE = 1.0
+
+# Minimum MAF (minor allele frequency) for SNP filtering
+DEFAULT_MAF_THRESHOLD = 0.05
+
+# === Storage and Precision Constants ===
+
+# Default dtype for LD score storage (saves ~50% disk space vs float32)
+LDSCORE_STORAGE_DTYPE = "float16"
+
+# Default dtype for computation
+LDSCORE_COMPUTE_DTYPE = "float32"
+
+# === Regression Constants ===
+
+# Default number of jackknife blocks for standard error estimation
+DEFAULT_N_BLOCKS = 200
+
+# Minimum number of SNPs required for reliable LD score estimation
+MIN_SNPS_FOR_LDSC = 200
+
+# === Window Quantization Constants ===
+
+# Default number of bins for dynamic programming quantization
+# (balances JIT compilation overhead vs padding waste)
+DEFAULT_QUANTIZATION_BINS = 20
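As a worked example of the λGC formula documented above (simulated Z-scores, not real GWAS data):

```python
import numpy as np
from gsMap.ldscore.constants import CHI_SQUARED_1DF_MEDIAN

rng = np.random.default_rng(0)
z = rng.normal(size=100_000)  # Z-scores under the null hypothesis

lambda_gc = np.median(z**2) / CHI_SQUARED_1DF_MEDIAN
print(f"lambda_GC = {lambda_gc:.3f}")  # ~1.0 under the null; values >1 suggest inflation
```

Under the null, median(Z^2) converges to the median of a chi-squared distribution with 1 degree of freedom, so λGC hovers around 1; confounding or polygenicity pushes it above 1.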