gsmap3d-0.1.0a1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,182 @@
+ """
+ Row ordering optimization for cache efficiency
+ """
+
+ import logging
+
+ import numpy as np
+ from numba import njit
+
+ logger = logging.getLogger(__name__)
+
+
+ @njit
+ def compute_jaccard_similarity(neighbors_a: np.ndarray, neighbors_b: np.ndarray) -> float:
+     """Compute Jaccard similarity between two neighbor sets"""
+     set_a = set(neighbors_a)
+     set_b = set(neighbors_b)
+     intersection = len(set_a & set_b)
+     union = len(set_a | set_b)
+     return intersection / union if union > 0 else 0.0
+
+
+ def optimize_row_order(
+     neighbor_indices: np.ndarray,
+     cell_indices: np.ndarray,
+     method: str | None = None,
+     neighbor_weights: np.ndarray | None = None
+ ) -> np.ndarray:
+     """
+     Sort rows by shared neighbors to improve cache locality
+
+     Args:
+         neighbor_indices: (n_cells, k) array of neighbor indices (global indices)
+         cell_indices: (n_cells,) array of global indices for each cell
+         method: None (auto), 'weighted', 'greedy', or 'none'
+         neighbor_weights: Optional (n_cells, k) array of weights for each neighbor
+
+     Returns:
+         Reordered row indices (local indices 0 to n_cells-1)
+
+     Complexity:
+         - weighted: O(n*k) where k is number of neighbors - very efficient!
+         - greedy: O(n²) - only for very small datasets
+         - none: O(1) - returns original order
+     """
+     n_cells = len(neighbor_indices)
+
+     # Create mapping from global to local indices
+     global_to_local = {global_idx: local_idx for local_idx, global_idx in enumerate(cell_indices)}
+
+     # Auto-select method if None
+     if method is None:
+         if neighbor_weights is not None and n_cells > 2000:
+             method = 'weighted'
+         else:
+             method = 'greedy'
+
+     if method == 'weighted' and neighbor_weights is not None:
+         # Efficient weighted heuristic: follow highest-weight neighbors
+         visited = np.zeros(n_cells, dtype=bool)
+         ordered = []
+
+         # Start with cell that has highest max weight to any single neighbor
+         max_weights = neighbor_weights.max(axis=1)
+         current = np.argmax(max_weights)
+
+         ordered.append(current)
+         visited[current] = True
+
+         # Build reverse lookup (using local indices)
+         reverse_neighbors = [[] for _ in range(n_cells)]
+         for i in range(n_cells):
+             for j, neighbor_global_idx in enumerate(neighbor_indices[i]):
+                 if neighbor_global_idx in global_to_local:
+                     neighbor_local_idx = global_to_local[neighbor_global_idx]
+                     reverse_neighbors[neighbor_local_idx].append((i, neighbor_weights[i, j]))
+
+         # Process all cells
+         for _ in range(n_cells - 1):
+             neighbors = neighbor_indices[current]
+             weights = neighbor_weights[current]
+
+             # Sort neighbors by weight (highest first)
+             sorted_idx = np.argsort(weights)[::-1]
+
+             # Find the unvisited neighbor with highest weight
+             next_cell = None
+             best_weight = -1
+
+             for idx in sorted_idx:
+                 neighbor_global_idx = neighbors[idx]
+                 if neighbor_global_idx in global_to_local:
+                     neighbor_local_idx = global_to_local[neighbor_global_idx]
+                     if not visited[neighbor_local_idx]:
+                         if weights[idx] > best_weight:
+                             best_weight = weights[idx]
+                             next_cell = neighbor_local_idx
+
+             # If no unvisited direct neighbors, find closest unvisited cell
+             if next_cell is None:
+                 connection_scores = np.zeros(n_cells)
+
+                 # Check connections from last few visited cells
+                 for cell_idx in ordered[-min(10, len(ordered)):]:
+                     # Add forward connections
+                     for j, neighbor_global_idx in enumerate(neighbor_indices[cell_idx]):
+                         if neighbor_global_idx in global_to_local:
+                             neighbor_local_idx = global_to_local[neighbor_global_idx]
+                             if not visited[neighbor_local_idx]:
+                                 connection_scores[neighbor_local_idx] += neighbor_weights[cell_idx, j]
+
+                     # Add reverse connections
+                     for reverse_idx, reverse_weight in reverse_neighbors[cell_idx]:
+                         if not visited[reverse_idx]:
+                             connection_scores[reverse_idx] += reverse_weight
+
+                 # Pick the unvisited cell with highest connection score
+                 if connection_scores.max() > 0:
+                     unvisited_scores = connection_scores.copy()
+                     unvisited_scores[visited] = -1
+                     next_cell = np.argmax(unvisited_scores)
+                 else:
+                     # No connections found, pick cell with highest max weight
+                     unvisited_max_weights = max_weights.copy()
+                     unvisited_max_weights[visited] = -1
+                     next_cell = np.argmax(unvisited_max_weights)
+
+             ordered.append(next_cell)
+             visited[next_cell] = True
+             current = next_cell
+
+         return np.array(ordered)
+
+     elif method == 'greedy':
+         # Original greedy approach - only for small datasets
+         if n_cells > 5000:
+             logger.warning(f"Greedy method is O(n²) - not recommended for {n_cells} cells")
+
+         # Convert neighbor indices to sets of local indices for Jaccard computation
+         neighbor_sets = []
+         for i in range(n_cells):
+             local_neighbors = set()
+             for neighbor_global_idx in neighbor_indices[i]:
+                 if neighbor_global_idx in global_to_local:
+                     local_neighbors.add(global_to_local[neighbor_global_idx])
+             neighbor_sets.append(local_neighbors)
+
+         remaining = set(range(n_cells))
+         ordered = []
+
+         current = np.random.choice(list(remaining))
+         ordered.append(current)
+         remaining.remove(current)
+
+         while remaining:
+             max_sim = -1
+             next_row = None
+
+             # Sample candidates for large datasets
+             candidates = list(remaining)
+             if len(candidates) > 100:
+                 candidates = np.random.choice(candidates, min(100, len(candidates)), replace=False)
+
+             for candidate in candidates:
+                 # Use precomputed sets for Jaccard similarity
+                 intersection = len(neighbor_sets[current] & neighbor_sets[candidate])
+                 union = len(neighbor_sets[current] | neighbor_sets[candidate])
+                 sim = intersection / union if union > 0 else 0.0
+
+                 if sim > max_sim:
+                     max_sim = sim
+                     next_row = candidate
+
+             ordered.append(next_row)
+             remaining.remove(next_row)
+             current = next_row
+
+         return np.array(ordered)
+
+     else:
+         # Simple sequential order as fallback
+         return np.arange(n_cells)
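A minimal usage sketch of optimize_row_order (hypothetical synthetic data; assumes the wheel above is installed so that gsMap.latent2gene.row_ordering is importable):

    import numpy as np
    from gsMap.latent2gene.row_ordering import optimize_row_order

    rng = np.random.default_rng(0)
    n_cells, k = 500, 15
    cell_indices = np.arange(1000, 1000 + n_cells)                  # global cell IDs (made up)
    neighbor_indices = rng.choice(cell_indices, size=(n_cells, k))  # global neighbor IDs
    neighbor_weights = rng.random((n_cells, k)).astype(np.float32)

    # Force the O(n*k) weighted heuristic; auto-selection would pick 'greedy'
    # here because n_cells <= 2000
    order = optimize_row_order(neighbor_indices, cell_indices,
                               method='weighted', neighbor_weights=neighbor_weights)
    assert sorted(order) == list(range(n_cells))  # a permutation of local row indices
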
@@ -0,0 +1,159 @@
+ """JAX-accelerated row ordering optimization using jax.lax.scan"""
+
+ import logging
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ logger = logging.getLogger(__name__)
+
+
+ # @partial(jit, static_argnames=['k'])
+ def _optimize_weighted_scan(
+     neighbor_indices: jnp.ndarray,
+     neighbor_weights: jnp.ndarray,
+     cell_indices: jnp.ndarray,
+     k: int
+ ) -> jnp.ndarray:
+     """
+     Weighted row ordering using jax.lax.scan for performance.
+
+     Args:
+         neighbor_indices: (n_cells, k) neighbor indices (global)
+         neighbor_weights: (n_cells, k) neighbor weights
+         cell_indices: (n_cells,) global cell indices
+         k: Number of neighbors per cell
+
+     Returns:
+         Reordered row indices (local 0..n-1)
+     """
+     n_cells = len(neighbor_indices)
+
+     # Pre-compute global to local mapping
+     # Handle -1 (invalid) neighbors; size the map to cover cell_indices too,
+     # since JAX silently drops out-of-bounds scatter updates
+     valid_mask = neighbor_indices >= 0
+     max_global_idx = jnp.maximum(jnp.max(jnp.where(valid_mask, neighbor_indices, 0)), jnp.max(cell_indices))
+
+     # Create inverse mapping from global to local indices
+     inverse_map = jnp.full(max_global_idx + 1, -1, dtype=jnp.int32)
+     inverse_map = inverse_map.at[cell_indices].set(jnp.arange(n_cells))
+
+     # Convert to local indices, preserving -1 for invalid neighbors
+     neighbor_indices_flat = neighbor_indices.ravel()
+     # For invalid indices (-1), keep them as -1
+     # For valid indices, map them through inverse_map
+     local_indices_flat = jnp.where(
+         neighbor_indices_flat >= 0,
+         inverse_map[jnp.where(neighbor_indices_flat >= 0, neighbor_indices_flat, 0)],
+         -1
+     )
+     local_neighbor_indices = local_indices_flat.reshape(n_cells, k)
+
+     # Initialize with highest weight cell
+     # Only consider valid neighbors when computing max weights
+     weights_masked = jnp.where(valid_mask, neighbor_weights, 0)
+     max_weights = jnp.max(weights_masked, axis=1)
+     start_node = jnp.argmax(max_weights)
+
+     # Initial state for scan
+     initial_state = {
+         "current": start_node,
+         "visited": jnp.zeros(n_cells, dtype=jnp.bool_).at[start_node].set(True),
+         "ordered": jnp.full(n_cells, -1, dtype=jnp.int32).at[0].set(start_node),
+         "max_weights": max_weights,
+         "local_neighbor_indices": local_neighbor_indices,
+         "neighbor_weights": neighbor_weights
+     }
+
+     def scan_step(state, t):
+         """Single step of the greedy ordering algorithm"""
+         current = state["current"]
+         visited = state["visited"]
+
+         # Find best unvisited neighbor
+         neighbors = state["local_neighbor_indices"][current]
+         weights = state["neighbor_weights"][current]
+
+         # Handle -1 values: create a mask for valid neighbors
+         is_valid = neighbors != -1
+         # For invalid neighbors, use False for is_unvisited to avoid indexing with -1
+         is_unvisited = jnp.where(is_valid, ~visited[jnp.where(is_valid, neighbors, 0)], False)
+         mask = is_valid & is_unvisited
+
+         neighbor_scores = jnp.where(mask, weights, -jnp.inf)
+         best_idx = jnp.argmax(neighbor_scores)
+
+         # If neighbor found, use it; otherwise pick highest weight unvisited
+         has_neighbor = neighbor_scores[best_idx] > -jnp.inf
+         next_neighbor = neighbors[best_idx]
+
+         # Fallback to highest weight unvisited cell
+         unvisited_weights = jnp.where(visited, -jnp.inf, state["max_weights"])
+         next_maxweight = jnp.argmax(unvisited_weights)
+
+         next_cell = jnp.where(has_neighbor, next_neighbor, next_maxweight)
+
+         # Update state
+         new_state = {
+             "current": next_cell,
+             "visited": visited.at[next_cell].set(True),
+             "ordered": state["ordered"].at[t].set(next_cell),
+             "max_weights": state["max_weights"],
+             "local_neighbor_indices": state["local_neighbor_indices"],
+             "neighbor_weights": state["neighbor_weights"]
+         }
+
+         return new_state, None
+
+     # Run scan for n_cells-1 iterations (first cell already placed)
+     final_state, _ = jax.lax.scan(scan_step, initial_state, jnp.arange(1, n_cells))
+
+     return final_state["ordered"]
+
+
+ def optimize_row_order_jax(
+     neighbor_indices: np.ndarray,
+     cell_indices: np.ndarray,
+     neighbor_weights: np.ndarray,
+     device: jax.Device | None = None
+ ) -> np.ndarray:
+     """
+     High-performance JAX-based row ordering for cache efficiency.
+
+     Args:
+         neighbor_indices: (n_cells, k) neighbor indices (global)
+         cell_indices: (n_cells,) global cell indices
+         neighbor_weights: (n_cells, k) neighbor weights
+         device: JAX device or None (uses default JAX device)
+
+     Returns:
+         Reordered row indices (local 0..n-1) as NumPy array
+     """
+     # Use default device if not provided (should be already configured by configure_jax_platform)
+     if device is None:
+         device = jax.devices()[0]
+
+     # Skip if on CPU - scanning is extremely slow on CPU
+     if device.platform == 'cpu':
+         n_cells = len(neighbor_indices)
+         logger.info(f"Skipping JAX-based row ordering optimization on CPU for {n_cells} cells (too slow).")
+         return np.arange(n_cells)
+
+     n_cells, k = neighbor_indices.shape
+     logger.debug(f"Running JAX scan-based weighted ordering for {n_cells} cells on {device.platform}")
+
+     with jax.default_device(device):
+         neighbor_indices_jax = jnp.asarray(neighbor_indices)
+         neighbor_weights_jax = jnp.asarray(neighbor_weights)
+         cell_indices_jax = jnp.asarray(cell_indices)
+
+         # Run optimized weighted ordering
+         ordered_jax = _optimize_weighted_scan(
+             neighbor_indices_jax,
+             neighbor_weights_jax,
+             cell_indices_jax,
+             k
+         )
+
+     return np.array(ordered_jax)
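A corresponding sketch for the JAX entry point (synthetic data again; note the deliberate CPU short-circuit in the code above, so the scan only runs on an accelerator):

    import numpy as np
    from gsMap.latent2gene.row_ordering_jax import optimize_row_order_jax

    rng = np.random.default_rng(0)
    n_cells, k = 1024, 15
    cell_indices = np.arange(n_cells, dtype=np.int32)
    neighbor_indices = rng.choice(cell_indices, size=(n_cells, k)).astype(np.int32)
    neighbor_indices[0, 0] = -1  # -1 marks an invalid neighbor; handled explicitly above
    neighbor_weights = rng.random((n_cells, k)).astype(np.float32)

    # Returns np.arange(n_cells) unchanged on CPU-only installs; on GPU/TPU it
    # runs the lax.scan ordering on the first available device
    order = optimize_row_order_jax(neighbor_indices, cell_indices, neighbor_weights)
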
@@ -0,0 +1 @@
+ from gsMap.config import LDScoreConfig
@@ -0,0 +1,163 @@
+ """
+ Simplified batch construction for LD score calculation.
+
+ This module handles:
+ 1. Splitting HM3 SNPs into fixed-size batches
+ 2. Calculating reference block boundaries for each batch based on LD window
+ """
+
+ import logging
+ from dataclasses import dataclass
+
+ import numpy as np
+ import pandas as pd
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class BatchInfo:
+     """
+     Information for a single HM3 batch.
+
+     Attributes
+     ----------
+     chromosome : str
+         Chromosome identifier
+     hm3_indices : np.ndarray
+         Indices of HM3 SNPs in this batch (relative to chromosome BIM)
+     ref_start_idx : int
+         Start index of reference block in chromosome (inclusive)
+     ref_end_idx : int
+         End index of reference block in chromosome (exclusive)
+     """
+
+     chromosome: str
+     hm3_indices: np.ndarray
+     ref_start_idx: int
+     ref_end_idx: int
+
+
+ def construct_batches(
+     bim_df: pd.DataFrame,
+     hm3_snp_names: list[str],
+     batch_size_hm3: int,
+     ld_wind: float = 1.0,
+     ld_unit: str = "CM",
+ ) -> list[BatchInfo]:
+     """
+     Construct batches of HM3 SNPs with reference block boundaries.
+
+     Optimized implementation using searchsorted for O(log N) boundary finding.
+
+     Parameters
+     ----------
+     bim_df : pd.DataFrame
+         BIM dataframe for this chromosome
+     hm3_snp_names : list[str]
+         List of HM3 SNP names for this chromosome
+     batch_size_hm3 : int
+         Number of HM3 SNPs per batch
+     ld_wind : float
+         LD window size (default: 1.0)
+     ld_unit : str
+         Unit for LD window: 'SNP', 'KB', 'CM' (default: 'CM')
+
+     Returns
+     -------
+     list[BatchInfo]
+         List of BatchInfo objects
+     """
+     chromosome = str(bim_df["CHR"].iloc[0])
+
+     # Extract the SNP name column once; it is matched against HM3 names below
+     snp_series = bim_df["SNP"]
+
+     # Select coordinates based on the LD window unit
+     if ld_unit == "SNP":
+         coords = np.arange(len(bim_df))
+         max_dist = int(ld_wind)
+     elif ld_unit == "KB":
+         coords = bim_df["BP"].values
+         max_dist = ld_wind * 1000
+     elif ld_unit == "CM":
+         coords = bim_df["CM"].values
+         max_dist = ld_wind
+         # Fallback if CM is all zero
+         if np.all(coords == 0):
+             logger.warning(f"All CM values are 0 for chromosome {chromosome}. Fallback to 1MB window (BP).")
+             coords = bim_df["BP"].values
+             max_dist = 1_000_000
+     else:
+         raise ValueError(f"Invalid ld_unit: {ld_unit}. Must be 'SNP', 'KB', or 'CM'.")
+
+     # Find indices of HM3 SNPs in BIM
+     hm3_set = set(hm3_snp_names)
+     # np.isin on strings is very slow, use pandas isin instead
+     hm3_mask = snp_series.isin(hm3_set).values
+     hm3_indices_all = np.where(hm3_mask)[0]
+
+     n_hm3 = len(hm3_indices_all)
+     if n_hm3 == 0:
+         logger.warning(f"No HM3 SNPs found in chromosome {chromosome}")
+         return []
+     if n_hm3 < len(hm3_snp_names):
+         logger.warning(f"{len(hm3_snp_names) - n_hm3} HM3 SNPs not found in chromosome {chromosome} reference plink panel")
+
+     logger.info(f"Found {n_hm3} HM3 SNPs in chromosome {chromosome}")
+
+     # Calculate batch boundaries
+     # We want batches of size batch_size_hm3
+     # Start indices: 0, batch_size, 2*batch_size, ...
+     batch_starts = np.arange(0, n_hm3, batch_size_hm3)
+     # End indices: batch_size, 2*batch_size, ..., n_hm3
+     batch_ends = np.minimum(batch_starts + batch_size_hm3, n_hm3)
+
+     n_batches = len(batch_starts)
+
+     # Get indices in BIM for start and end of each batch
+     # hm3_indices_all contains the BIM indices of HM3 SNPs
+     # We need the BIM index of the first and last HM3 SNP in each batch
+
+     # First HM3 SNP in each batch
+     batch_start_bim_indices = hm3_indices_all[batch_starts]
+     # Last HM3 SNP in each batch (indices are exclusive in slice, so -1 for element access)
+     batch_end_bim_indices = hm3_indices_all[batch_ends - 1]
+
+     # Get coordinates for these batch boundaries
+     min_coords = coords[batch_start_bim_indices]
+     max_coords = coords[batch_end_bim_indices]
+
+     # Calculate window boundaries
+     window_starts = min_coords - max_dist
+     window_ends = max_coords + max_dist
+
+     if ld_unit == "SNP":
+         # For SNP unit, coordinates are indices: clip and make the end exclusive
+         ref_starts = np.maximum(window_starts, 0).astype(int)
+         ref_ends = np.minimum(window_ends + 1, len(coords)).astype(int)
+     else:
+         # Vectorized searchsorted
+         # Find insertion points for all window starts and ends at once
+         ref_starts = np.searchsorted(coords, window_starts, side='left')
+         ref_ends = np.searchsorted(coords, window_ends, side='right')
+
+     # Construct BatchInfo objects
+     batch_infos = []
+     for i in range(n_batches):
+         # Slice the HM3 indices for this batch
+         b_hm3_indices = hm3_indices_all[batch_starts[i]:batch_ends[i]]
+
+         batch_infos.append(BatchInfo(
+             chromosome=chromosome,
+             hm3_indices=b_hm3_indices,
+             ref_start_idx=int(ref_starts[i]),
+             ref_end_idx=int(ref_ends[i]),
+         ))
+
+     logger.info(f"Created {len(batch_infos)} batches for chromosome {chromosome}")
+     if len(batch_infos) > 0:
+         avg_size = np.mean(ref_ends - ref_starts)
+         logger.info(f"  Average reference block size: {avg_size:.0f} SNPs")
+
+     return batch_infos
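A sketch of construct_batches on a toy single-chromosome BIM frame (synthetic data; column names follow the code above):

    import numpy as np
    import pandas as pd
    from gsMap.ldscore.batch_construction import construct_batches

    n_snps = 10_000
    bim_df = pd.DataFrame({
        "CHR": 1,
        "SNP": [f"rs{i}" for i in range(n_snps)],
        "CM": np.linspace(0.0, 100.0, n_snps),   # sorted genetic positions
        "BP": np.arange(1, n_snps + 1) * 1_000,  # sorted physical positions
    })
    hm3_snps = [f"rs{i}" for i in range(0, n_snps, 10)]  # pretend every 10th SNP is HM3

    batches = construct_batches(bim_df, hm3_snps, batch_size_hm3=100,
                                ld_wind=1.0, ld_unit="CM")
    print(len(batches), batches[0].ref_start_idx, batches[0].ref_end_idx)
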
@@ -0,0 +1,126 @@
+ """
+ Simplified LD score computation without masking or padding.
+
+ Direct computation of unbiased L2 statistics from genotype matrices using NumPy and SciPy.
+ """
+
+
+ import numpy as np
+ import scipy.sparse
+
+ from .constants import LDSC_BIAS_CORRECTION_DF
+
+
+ def compute_unbiased_l2_batch(
+     X_hm3: np.ndarray,
+     X_ref_block: np.ndarray,
+ ) -> np.ndarray:
+     """
+     Compute unbiased LD scores (L2) for HM3 SNPs against reference block.
+
+     The unbiased L2 estimator is:
+         L2 = r^2 - (1 - r^2) / (N - 2)
+
+     where r^2 is the squared correlation and N is the number of individuals.
+
+     Parameters
+     ----------
+     X_hm3 : np.ndarray
+         Standardized genotypes for HM3 SNPs, shape (n_individuals, n_hm3_snps)
+     X_ref_block : np.ndarray
+         Standardized genotypes for reference block, shape (n_individuals, n_ref_snps)
+
+     Returns
+     -------
+     np.ndarray
+         Unbiased r^2 estimates, shape (n_hm3_snps, n_ref_snps)
+     """
+     n_individuals = X_hm3.shape[0]
+
+     # Compute correlation matrix: r = (1/N) * X_hm3^T @ X_ref_block
+     # shape: (n_hm3_snps, n_ref_snps)
+     r = np.dot(X_hm3.T, X_ref_block) / n_individuals
+
+     # Compute r^2
+     r_squared = r ** 2
+
+     # Apply bias correction
+     bias_correction = (1.0 - r_squared) / (n_individuals - LDSC_BIAS_CORRECTION_DF)
+     l2_unbiased = r_squared - bias_correction
+
+     return l2_unbiased
+
+
+ def compute_ld_scores(
+     X_hm3: np.ndarray,
+     X_ref_block: np.ndarray,
+ ) -> np.ndarray:
+     """
+     Compute LD scores by summing unbiased L2 over reference SNPs.
+
+     Parameters
+     ----------
+     X_hm3 : np.ndarray
+         Standardized genotypes for HM3 SNPs, shape (n_individuals, n_hm3_snps)
+     X_ref_block : np.ndarray
+         Standardized genotypes for reference block, shape (n_individuals, n_ref_snps)
+
+     Returns
+     -------
+     np.ndarray
+         LD scores for each HM3 SNP, shape (n_hm3_snps,)
+     """
+     # Compute unbiased L2 matrix
+     l2_unbiased = compute_unbiased_l2_batch(X_hm3, X_ref_block)
+
+     # Sum over reference SNPs (axis=1)
+     ld_scores = np.sum(l2_unbiased, axis=1)
+
+     return ld_scores
+
+
+ def compute_batch_weights_sparse(
+     X_hm3: np.ndarray,
+     X_ref_block: np.ndarray,
+     block_mapping_matrix: scipy.sparse.csr_matrix | np.ndarray,
+ ) -> np.ndarray:
+     """
+     Compute LD score weight matrix using matrix multiplication.
+
+     Works for both sparse and dense mapping matrices.
+         Weights = L2_Unbiased @ Mapping_Matrix
+
+     Parameters
+     ----------
+     X_hm3 : np.ndarray
+         Standardized genotypes for HM3 SNPs, shape (n_individuals, n_hm3_snps)
+     X_ref_block : np.ndarray
+         Standardized genotypes for reference block, shape (n_individuals, n_ref_snps)
+     block_mapping_matrix : scipy.sparse.csr_matrix | np.ndarray
+         Mapping matrix for the reference block, shape (n_ref_snps, n_features).
+         Can be a sparse CSR matrix (from creating_snp_feature_map) or a dense array (from annotations).
+
+     Returns
+     -------
+     weights : np.ndarray
+         Weight matrix, shape (n_hm3_snps, n_features)
+     """
+     # 1. Compute unbiased L2 matrix: (n_hm3_snps, n_ref_snps)
+     l2_unbiased = compute_unbiased_l2_batch(X_hm3, X_ref_block)
+
+     # 2. Compute Weights: W = L2 @ M
+     #    (n_hm3, n_ref) @ (n_ref, n_features) -> (n_hm3, n_features)
+     #    numpy dot handles dense @ dense
+     #    scipy.sparse handles dense @ sparse
+     if scipy.sparse.issparse(block_mapping_matrix):
+         weights = l2_unbiased @ block_mapping_matrix
+     else:
+         weights = np.dot(l2_unbiased, block_mapping_matrix)
+
+     # Ensure output is a dense numpy array
+     if scipy.sparse.issparse(weights):
+         weights = weights.toarray()
+     elif isinstance(weights, np.matrix):
+         weights = np.asarray(weights)
+
+     return weights
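The bias correction can be checked numerically with synthetic standardized genotypes; for independent SNPs the corrected L2 values should average near zero (a sanity-check sketch, not part of the package):

    import numpy as np
    from gsMap.ldscore.compute import compute_ld_scores, compute_unbiased_l2_batch

    rng = np.random.default_rng(0)
    n_ind, n_hm3, n_ref = 500, 20, 200
    X = rng.standard_normal((n_ind, n_hm3 + n_ref))
    X = (X - X.mean(axis=0)) / X.std(axis=0)  # standardize columns, as the functions assume
    X_hm3, X_ref = X[:, :n_hm3], X[:, n_hm3:]

    l2 = compute_unbiased_l2_batch(X_hm3, X_ref)  # (n_hm3, n_ref) unbiased r^2 values
    scores = compute_ld_scores(X_hm3, X_ref)      # row sums of l2
    assert np.allclose(scores, l2.sum(axis=1))
    print(l2.mean())  # near 0 for independent SNPs once the ~1/N sampling bias is removed
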
@@ -0,0 +1,70 @@
+ """
+ Constants and magic numbers used in LD score regression and spatial LDSC.
+
+ This module centralizes all numerical constants to improve code readability
+ and maintainability.
+ """
+
+ # === LD Score Calculation Constants ===
+
+ # Degrees of freedom for unbiased L2 estimator bias correction
+ # The unbiased estimator is: r^2 - (1 - r^2) / (N - LDSC_BIAS_CORRECTION_DF)
+ LDSC_BIAS_CORRECTION_DF = 2.0
+
+ # Default LD window sizes
+ DEFAULT_LD_WINDOW_CM = 1.0  # centiMorgans
+ DEFAULT_LD_WINDOW_KB = 1000  # kilobases
+ DEFAULT_LD_WINDOW_MB = 1.0  # megabases
+
+ # === Genomic Control Constants ===
+
+ # Median of chi-squared distribution with 1 degree of freedom
+ # Used for calculating genomic control lambda (λGC):
+ #   λGC = median(χ²) / CHI_SQUARED_1DF_MEDIAN
+ # where χ² = Z^2 from GWAS summary statistics
+ CHI_SQUARED_1DF_MEDIAN = 0.4549364
+
+ # === Statistical Thresholds ===
+
+ # Standard genome-wide significance level (GWAS)
+ GWAS_SIGNIFICANCE_ALPHA = 5e-8
+
+ # Standard nominal significance level
+ NOMINAL_SIGNIFICANCE_ALPHA = 0.05
+
+ # FDR significance threshold (commonly used for spatial analysis)
+ FDR_SIGNIFICANCE_ALPHA = 0.001
+
+ # === Numerical Stability Constants ===
+
+ # Minimum p-value for log transformation (prevents log(0) = -inf)
+ # log10(1e-300) = -300, which is reasonable for visualization
+ MIN_P_VALUE = 1e-300
+
+ # Maximum p-value (ceiling for numerical stability)
+ MAX_P_VALUE = 1.0
+
+ # Minimum MAF (minor allele frequency) for SNP filtering
+ DEFAULT_MAF_THRESHOLD = 0.05
+
+ # === Storage and Precision Constants ===
+
+ # Default dtype for LD score storage (saves ~50% disk space vs float32)
+ LDSCORE_STORAGE_DTYPE = "float16"
+
+ # Default dtype for computation
+ LDSCORE_COMPUTE_DTYPE = "float32"
+
+ # === Regression Constants ===
+
+ # Default number of jackknife blocks for standard error estimation
+ DEFAULT_N_BLOCKS = 200
+
+ # Minimum number of SNPs required for reliable LD score estimation
+ MIN_SNPS_FOR_LDSC = 200
+
+ # === Window Quantization Constants ===
+
+ # Default number of bins for dynamic programming quantization
+ # (balances JIT compilation overhead vs padding waste)
+ DEFAULT_QUANTIZATION_BINS = 20
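The genomic-control median in constants.py above can be cross-checked against SciPy (used here only for verification): it is the median of the χ² distribution with 1 df, i.e. the square of the median of |Z| for a standard normal Z.

    import numpy as np
    from scipy.stats import chi2, norm

    print(chi2.median(1))       # 0.4549364... (median of chi-squared with 1 df)
    print(norm.ppf(0.75) ** 2)  # same value: square of the median of |Z|

    # lambda_GC from hypothetical GWAS z-scores; roughly 1.0 under the null
    z = np.random.default_rng(0).standard_normal(100_000)
    print(np.median(z ** 2) / 0.4549364)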