gsMap3D 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/ldscore/io.py ADDED
@@ -0,0 +1,262 @@
1
+ """
2
+ I/O utilities for loading genomic data in the LD score framework.
3
+
4
+ This module provides optimized readers for PLINK binary files and omics feature matrices,
5
+ leveraging pandas-plink for efficient I/O and standardizing data with NumPy.
6
+ """
7
+
8
+ import logging
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ import xarray as xr
14
+ from pandas_plink import read_plink1_bin
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class PlinkBEDReader:
20
+ """
21
+ Optimized reader for PLINK binary files using pandas-plink and Xarray.
22
+
23
+ Loads genotypes lazily via Dask/Xarray, applies matrix-based MAF calculation
24
+ and filtering, and converts to standardized NumPy arrays.
25
+
26
+ Attributes
27
+ ----------
28
+ bfile : str
29
+ Base filename prefix for PLINK files
30
+ G : xr.DataArray
31
+ The underlying xarray DataArray (samples x variants) containing genotypes (0, 1, 2, nan)
32
+ bim : pd.DataFrame
33
+ BIM file data (SNP information, filtered)
34
+ fam : pd.DataFrame
35
+ FAM file data (individual information)
36
+ m : int
37
+ Number of SNPs (after filtering)
38
+ n : int
39
+ Number of individuals
40
+ genotypes : np.ndarray
41
+ Pre-loaded and standardized genotype matrix (n_individuals, m_snps)
42
+ maf : np.ndarray
43
+ Minor allele frequency for each SNP
44
+ """
45
+
46
+ def __init__(
47
+ self,
48
+ bfile_prefix: str,
49
+ maf_min: float | None = None,
50
+ keep_snps: list[str] | None = None,
51
+ preload: bool = True
52
+ ):
53
+ """
54
+ Initialize PlinkBEDReader with optional filtering.
55
+
56
+ Parameters
57
+ ----------
58
+ bfile_prefix : str
59
+ PLINK file prefix (without .bed/.bim/.fam extension)
60
+ maf_min : float, optional
61
+ Minimum MAF threshold for SNP filtering (default: None, no filtering)
62
+ keep_snps : list[str], optional
63
+ List of SNP IDs to keep (default: None, keep all)
64
+ preload : bool, optional
65
+ Whether to pre-load and standardize genotypes into memory (default: True)
66
+ """
67
+ self.bfile = bfile_prefix
68
+
69
+ # Construct paths
70
+ bed_path = f"{bfile_prefix}.bed"
71
+ bim_path = f"{bfile_prefix}.bim"
72
+ fam_path = f"{bfile_prefix}.fam"
73
+
74
+ # Validate existence
75
+ if not (Path(bed_path).exists() and Path(bim_path).exists() and Path(fam_path).exists()):
76
+ raise FileNotFoundError(f"One or more PLINK files missing for prefix: {bfile_prefix}")
77
+
78
+ logger.info(f"Loading PLINK files from: {bfile_prefix}")
79
+
80
+ # Load using pandas-plink
81
+ # This returns an xarray DataArray with dask backing (lazy loading)
82
+ # Shape is (sample, variant)
83
+ self.G = read_plink1_bin(bed_path, bim_path, fam_path, verbose=False)
84
+
85
+ # Initial dimensions
86
+ self.n_original = self.G.sizes['sample']
87
+ self.m_original = self.G.sizes['variant']
88
+ self.snp_ids_original = pd.Index(self.G.snp.values)
89
+
90
+ logger.info(f"Loaded metadata: {self.m_original} SNPs × {self.n_original} individuals")
91
+
92
+ # Calculate MAF using matrix operations (lazy execution via dask)
93
+ logger.info("Calculating MAF ...")
94
+ self.maf = self._calculate_maf()
95
+
96
+ # Apply filters to the xarray DataArray
97
+ self._apply_filters(maf_min=maf_min, keep_snps=keep_snps)
98
+
99
+ # Apply Basic QC (Filter Monomorphic and All-Missing)
100
+ # This must happen before _sync_metadata so BIM reflects valid SNPs only
101
+ self._apply_basic_qc()
102
+
103
+ # Update dimensions after filtering
104
+ self.n = self.G.sizes['sample']
105
+ self.m = self.G.sizes['variant']
106
+
107
+ # Extract metadata DataFrames from xarray coordinates for compatibility
108
+ self._sync_metadata()
109
+
110
+ # Pre-load genotypes if requested
111
+ self.genotypes = None
112
+ if preload:
113
+ logger.info(f"Pre-loading and standardizing {self.m} SNPs...")
114
+ self.genotypes = self._load_and_standardize_all()
115
+ logger.info(f"✓ Genotypes ready: {self.genotypes.shape}")
116
+
117
+ logger.info(f"PlinkBEDReader initialized: {self.m} SNPs × {self.n} individuals")
118
+
119
+ def _calculate_maf(self) -> xr.DataArray:
120
+ """
121
+ Calculate Minor Allele Frequency using matrix operations on the DataArray.
122
+ """
123
+ # Calculate mean across samples (axis 0), ignoring NaNs
124
+ # The result 'freq_a1' represents the frequency of the A1 allele (coded as 2)
125
+ freq_a1 = self.G.mean(dim="sample", skipna=True) / 2.0
126
+
127
+ # Calculate MAF: min(f, 1-f)
128
+ maf = xr.ufuncs.minimum(freq_a1, 1.0 - freq_a1)
129
+
130
+ return maf.compute()
131
+
132
+ def _apply_filters(
133
+ self,
134
+ maf_min: float | None = None,
135
+ keep_snps: list[str] | None = None
136
+ ) -> None:
137
+ """
138
+ Apply SNP filters directly to the xarray DataArray.
139
+ """
140
+ # 1. Create a boolean mask for variants
141
+ variant_ids = self.G.variant.values
142
+ mask = np.ones(len(variant_ids), dtype=bool)
143
+
144
+ # 2. Apply MAF filter
145
+ if maf_min is not None and maf_min > 0:
146
+ maf_mask = (self.maf >= maf_min).values
147
+ mask &= maf_mask
148
+
149
+ n_removed_maf = np.sum(~maf_mask)
150
+ if n_removed_maf > 0:
151
+ logger.info(f"Filtered {n_removed_maf} SNPs with MAF < {maf_min}")
152
+
153
+ # 3. Apply Keep List filter
154
+ if keep_snps is not None:
155
+ current_snps = self.G.snp.values
156
+ keep_set = set(keep_snps)
157
+ snp_mask = np.isin(current_snps, list(keep_set))
158
+
159
+ mask &= snp_mask
160
+ n_removed_snp = np.sum(~snp_mask)
161
+ if n_removed_snp > 0:
162
+ logger.info(f"Filtered {n_removed_snp} SNPs not in keep list")
163
+
164
+ # 4. Filter the main DataArray
165
+ n_before = self.G.sizes['variant']
166
+ self.G = self.G.isel(variant=mask)
167
+ self.maf = self.maf[mask]
168
+
169
+ n_after = self.G.sizes['variant']
170
+ if n_before != n_after:
171
+ logger.info(f"Total SNPs filtered: {n_before - n_after}/{n_before}")
172
+
173
+ def _apply_basic_qc(self) -> None:
174
+ """
175
+ Filter out monomorphic variants (std=0) and variants with all missing values.
176
+ """
177
+ logger.info("Applying basic QC (removing monomorphic and all-missing variants)...")
178
+
179
+ # Calculate stats lazily via dask
180
+ # 1. Standard Deviation (skipna=True handles missing)
181
+ stds = self.G.std(dim="sample", skipna=True)
182
+
183
+ # 2. Count of non-missing values
184
+ counts = self.G.count(dim="sample")
185
+
186
+ # Compute to get numpy arrays for boolean masking
187
+ stds_val = stds.values
188
+ counts_val = counts.values
189
+
190
+ # Create masks
191
+ # Monomorphic: std == 0 (or very close to 0)
192
+ # All missing: count == 0
193
+ mask_polymorphic = (stds_val > 0)
194
+ mask_not_empty = (counts_val > 0)
195
+
196
+ mask = mask_polymorphic & mask_not_empty
197
+
198
+ n_removed = np.sum(~mask)
199
+ if n_removed > 0:
200
+ logger.info(f"QC: Filtered {n_removed} variants (monomorphic or all-missing)")
201
+
202
+ # Apply filter
203
+ self.G = self.G.isel(variant=mask)
204
+ self.maf = self.maf[mask]
205
+ else:
206
+ logger.info("QC: No monomorphic or all-missing variants found.")
207
+
208
+ def _sync_metadata(self):
209
+
210
+ self.bim = pd.DataFrame({
211
+ 'CHR': self.G.chrom.values,
212
+ 'SNP': self.G.snp.values,
213
+ 'CM': self.G.cm.values,
214
+ 'BP': self.G.pos.values,
215
+ 'A1': self.G.a1.values,
216
+ 'A2': self.G.a0.values,
217
+ 'i': np.arange(self.m)
218
+ })
219
+ self.bim['MAF'] = self.maf.values
220
+
221
+ # pandas-plink stores FAM info in coordinates
222
+ self.fam = pd.DataFrame({
223
+ 'fid': self.G.fid.values,
224
+ 'iid': self.G.iid.values,
225
+ 'father': self.G.father.values,
226
+ 'mother': self.G.mother.values,
227
+ 'gender': self.G.gender.values,
228
+ 'trait': self.G.trait.values
229
+ })
230
+
231
+ def _load_and_standardize_all(self) -> np.ndarray:
232
+ """
233
+ Load genotypes from Dask into memory, convert to NumPy, and standardize.
234
+
235
+ We assume basic QC (monomorphic/all-missing removal) has already run.
236
+
237
+ Returns
238
+ -------
239
+ np.ndarray
240
+ Standardized genotype matrix of shape (n_individuals, m_snps)
241
+ """
242
+ logger.info("Reading filtered genotype matrix into memory...")
243
+ X = self.G.values.astype(np.float32)
244
+
245
+ # Compute stats (ignoring NaNs)
246
+ means = np.nanmean(X, axis=0)
247
+ stds = np.nanstd(X, axis=0)
248
+
249
+ # Impute missing values with column means
250
+ nan_mask = np.isnan(X)
251
+
252
+ # Broadcasting means: (m,) -> (1, m) to match (n, m) for X which is (n, m)
253
+ # Note: In X, rows=individuals, cols=snps. means shape is (m_snps,).
254
+ # We need to broadcast properly.
255
+ # X is (n, m), means is (m,). numpy broadcasts (m,) to (n, m) automatically on last dim.
256
+ X = np.where(nan_mask, means, X)
257
+
258
+ # Standardize: (X - mean) / std
259
+ # Standard broadcasting rules apply: (n, m) - (m,) -> (n, m)
260
+ X_std = (X - means) / stds
261
+
262
+ return X_std
gsMap/ldscore/mapping.py ADDED
@@ -0,0 +1,262 @@
1
+ import logging
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import pyranges as pr
6
+ import scipy.sparse
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ def create_snp_feature_map(
12
+ bim_df: pd.DataFrame,
13
+ mapping_type: str,
14
+ mapping_data: pd.DataFrame | dict[str, str],
15
+ feature_window_size: int = 0,
16
+ strategy: str = "score",
17
+ ) -> tuple[scipy.sparse.csr_matrix, list[str], pd.DataFrame | None]:
18
+ """
19
+ Create a sparse mapping matrix assigning each SNP in the BIM file to feature indices.
20
+
21
+ Returns:
22
+ 1. Sparse matrix where rows correspond to SNPs (in bim_df order) and columns correspond to features.
23
+ 2. List of feature names corresponding to indices 0 to F-1.
24
+ 3. (Optional) DataFrame containing the curated SNP-feature mappings (for BED type).
25
+
26
+ Feature Indexing Scheme
27
+ -----------------------
28
+ - Mapped features get indices: 0, 1, 2, ..., F-1
29
+ - Unmapped bin gets index: F
30
+ - Total columns: F + 1
31
+ - Values: Score from BED file (if available, else 1.0)
32
+
33
+ Parameters
34
+ ----------
35
+ bim_df : pd.DataFrame
36
+ PLINK BIM file data (must have 'CHR', 'SNP', 'BP').
37
+ mapping_type : str
38
+ 'bed' or 'dict'.
39
+ mapping_data : Union[pd.DataFrame, Dict[str, str]]
40
+ Mapping source data. For 'bed' type, this should be a DataFrame read from
41
+ a standard BED file using pr.read_bed() with columns:
42
+ - Chromosome (chr)
43
+ - Start (0-based start position)
44
+ - End (1-based end position)
45
+ - Feature (name/identifier from 4th column)
46
+ - Score (optional, from 5th column)
47
+ - Strand (optional, from 6th column)
48
+
49
+ Note: BED files should be in standard BED6 format WITHOUT a header line.
50
+ feature_window_size : int
51
+ Window extension (for 'bed').
52
+ strategy : str
53
+ 'score', 'tss', 'center', or 'allow_repeat'.
54
+ - 'allow_repeat': A SNP can map to multiple features (values are summed or kept).
55
+ - 'score': Keep mapping with highest score per SNP.
56
+ - 'tss': Keep mapping closest to TSS per SNP.
57
+ - 'center': Keep mapping closest to the center of the feature interval per SNP.
58
+
59
+ Returns
60
+ -------
61
+ mapping_matrix : scipy.sparse.csr_matrix
62
+ Sparse matrix of shape (M, F+1).
63
+ feature_names : List[str]
64
+ List of feature names corresponding to indices 0 to F-1.
65
+ mapping_df : Optional[pd.DataFrame]
66
+ DataFrame with columns [SNP, Feature, ...] showing final mappings. None if using 'dict'.
67
+ """
68
+ m_ref = len(bim_df)
69
+ unique_feature_names = []
70
+ curated_mapping_df = None
71
+
72
+ # Prepare basic SNP info for joining
73
+ # We assign an integer index to every SNP in BIM to construct the sparse matrix later
74
+ bim_df = bim_df.copy()
75
+ bim_df['snp_row_idx'] = np.arange(m_ref)
76
+
77
+ # Intermediate storage for sparse construction
78
+ row_indices = []
79
+ col_indices = []
80
+ data_values = []
81
+
82
+ # === STRATEGY A: Dictionary Mapping ===
83
+ if mapping_type == 'dict':
84
+ if not isinstance(mapping_data, dict):
85
+ raise ValueError("mapping_data must be a dictionary when mapping_type='dict'")
86
+
87
+ # 1. Identify all unique features
88
+ unique_feature_names = sorted(set(mapping_data.values()))
89
+ feature_to_idx = {f: i for i, f in enumerate(unique_feature_names)}
90
+ n_features = len(unique_feature_names)
91
+
92
+ # 2. Map Dictionary Values to Indices
93
+ # Filter mapping data to only include SNPs present in BIM to save memory/time
94
+ bim_snps = set(bim_df['SNP'])
95
+ valid_mapping = {k: v for k, v in mapping_data.items() if k in bim_snps}
96
+
97
+ # 3. Create Sparse Entries
98
+ # Create a Series for mapping: SNP -> Feature Index
99
+ snp_to_feat_idx = pd.Series(valid_mapping).map(feature_to_idx)
100
+
101
+ # Map BIM SNPs to feature indices
102
+ mapped_feat_indices = bim_df['SNP'].map(snp_to_feat_idx)
103
+
104
+ # Drop NaNs (Unmapped)
105
+ valid_mask = mapped_feat_indices.notna()
106
+
107
+ if valid_mask.any():
108
+ row_indices = bim_df.loc[valid_mask, 'snp_row_idx'].values
109
+ col_indices = mapped_feat_indices[valid_mask].values.astype(int)
110
+ data_values = np.ones(len(row_indices), dtype=np.float32)
111
+
112
+ # === STRATEGY B: BED/Genomic Interval Mapping ===
113
+ elif mapping_type == 'bed':
114
+ if not isinstance(mapping_data, pd.DataFrame):
115
+ raise ValueError("mapping_data must be a DataFrame when mapping_type='bed'")
116
+
117
+ df_features = mapping_data.copy()
118
+ required_cols = ['Feature', 'Chromosome', 'Start', 'End']
119
+ if not all(c in df_features.columns for c in required_cols):
120
+ raise ValueError(f"BED DataFrame missing required columns: {required_cols}")
121
+
122
+ # 1. Identify all unique features
123
+ unique_feature_names = sorted(df_features['Feature'].unique())
124
+ feature_to_idx = {f: i for i, f in enumerate(unique_feature_names)}
125
+ n_features = len(unique_feature_names)
126
+
127
+ # 2. Prepare BIM for PyRanges
128
+ bim_pr_df = bim_df[['CHR', 'BP', 'SNP', 'snp_row_idx']].rename(columns={
129
+ 'CHR': 'Chromosome',
130
+ 'BP': 'Start'
131
+ })
132
+ bim_pr_df['End'] = bim_pr_df['Start'] + 1
133
+
134
+ # Clean chromosome names (remove 'chr' prefix if present)
135
+ # PLINK BIM files typically use numeric chromosome identifiers
136
+ bim_pr_df['Chromosome'] = bim_pr_df['Chromosome'].astype(str).str.replace('chr', '', case=False)
137
+ df_features['Chromosome'] = df_features['Chromosome'].astype(str).str.replace('chr', '', case=False)
138
+
139
+ pr_bim = pr.PyRanges(bim_pr_df)
140
+
141
+ # 3. Pre-calculate TSS
142
+ if strategy == 'tss':
143
+ if 'Strand' not in df_features.columns:
144
+ raise ValueError("Strategy 'tss' requires 'Strand' column in mapping data.")
145
+ df_features['RefTSS'] = np.where(
146
+ df_features['Strand'] == '+',
147
+ df_features['Start'],
148
+ df_features['End']
149
+ )
150
+
151
+ # 4. Apply Window Size
152
+ df_features['Start'] = np.maximum(0, df_features['Start'] - feature_window_size)
153
+ df_features['End'] = df_features['End'] + feature_window_size
154
+
155
+ pr_features = pr.PyRanges(df_features)
156
+
157
+ # 5. Join
158
+ # Columns from pr_features (Right) that overlap with pr_bim (Left) get suffix '_b'
159
+ # Overlapping cols usually: Start, End.
160
+ # pr_bim (SNP) Start is 'Start'. pr_features (Window) Start is 'Start_b'.
161
+ joined = pr_bim.join(pr_features, apply_strand_suffix=False).df
162
+
163
+ logger.info(f"Initial SNP-feature overlaps: {len(joined)} pairs")
164
+ if not joined.empty:
165
+ logger.info(f" Unique SNPs in overlaps: {joined['SNP'].nunique()} | Unique Features in overlaps: {joined['Feature'].nunique()}")
166
+
167
+ if not joined.empty:
168
+ # 6. Resolve Conflicts / Filter
169
+ if strategy == 'score':
170
+ if 'Score' not in joined.columns:
171
+ raise ValueError("Strategy 'score' requires 'Score' column in mapping data.")
172
+ joined = joined.sort_values(by=['SNP', 'Score'], ascending=[True, False])
173
+ joined = joined.drop_duplicates(subset=['SNP'], keep='first')
174
+
175
+ elif strategy == 'tss':
176
+ joined['distance_to_tss'] = np.abs(joined['Start'] - joined['RefTSS'])
177
+ joined = joined.sort_values(by=['SNP', 'distance_to_tss'], ascending=[True, True])
178
+ joined = joined.drop_duplicates(subset=['SNP'], keep='first')
179
+
180
+ elif strategy == 'center':
181
+ # Calculate center of the feature interval (the window around feature)
182
+ # 'Start_b' and 'End_b' are the feature window coordinates from PyRanges join
183
+ joined['feature_center'] = (joined['Start_b'] + joined['End_b']) / 2.0
184
+ # Calculate distance from SNP position ('Start') to center
185
+ joined['distance_to_center'] = np.abs(joined['Start'] - joined['feature_center'])
186
+ # Sort and pick closest
187
+ joined = joined.sort_values(by=['SNP', 'distance_to_center'], ascending=[True, True])
188
+ joined = joined.drop_duplicates(subset=['SNP'], keep='first')
189
+
190
+ elif strategy == 'allow_repeat':
191
+ # No de-duplication. One SNP can map to multiple features.
192
+ pass
193
+
194
+ # 7. Prepare Sparse Data
195
+ # Row indices come from BIM 'snp_row_idx' which is preserved in join
196
+ row_indices = joined['snp_row_idx'].values
197
+ col_indices = joined['Feature'].map(feature_to_idx).values
198
+
199
+ logger.info(f"Final SNP-feature pairs after strategy '{strategy}': {len(row_indices)}")
200
+
201
+ # Data values: Use Score if available and requested, otherwise 1.0
202
+ if 'Score' in joined.columns and strategy in ['score', 'allow_repeat']:
203
+ # Use provided score
204
+ data_values = joined['Score'].values.astype(np.float32)
205
+ else:
206
+ # Default to 1.0 for geometric strategies (tss, center) or missing score
207
+ data_values = np.ones(len(row_indices), dtype=np.float32)
208
+
209
+ # Save the result for output
210
+ # We filter columns to make it cleaner
211
+ output_cols = ['SNP', 'Chromosome', 'Start', 'Feature']
212
+ if 'Score' in joined.columns: output_cols.append('Score')
213
+ if 'distance_to_tss' in joined.columns: output_cols.append('distance_to_tss')
214
+ if 'distance_to_center' in joined.columns: output_cols.append('distance_to_center')
215
+
216
+ # Ensure columns exist before selecting
217
+ output_cols = [c for c in output_cols if c in joined.columns]
218
+ curated_mapping_df = joined[output_cols].copy()
219
+ curated_mapping_df.rename(columns={'Start': 'SNP_BP'}, inplace=True)
220
+
221
+ else:
222
+ raise ValueError(f"Unknown mapping_type: {mapping_type}")
223
+
224
+ # === Final Matrix Construction ===
225
+
226
+ # Determine Unmapped SNPs
227
+ # We create a matrix of shape (M, F+1). The last column (index F) is for unmapped.
228
+ # Any row_idx NOT present in row_indices gets a 1.0 in the last column.
229
+
230
+ mapped_rows_set = set(row_indices)
231
+ all_rows = set(range(m_ref))
232
+ unmapped_rows = list(all_rows - mapped_rows_set)
233
+
234
+ if unmapped_rows:
235
+ unmapped_rows = np.array(unmapped_rows, dtype=int)
236
+ unmapped_cols = np.full(len(unmapped_rows), n_features, dtype=int) # Last column index
237
+ unmapped_data = np.ones(len(unmapped_rows), dtype=np.float32)
238
+
239
+ # Append unmapped data
240
+ if len(row_indices) > 0:
241
+ row_indices = np.concatenate([row_indices, unmapped_rows])
242
+ col_indices = np.concatenate([col_indices, unmapped_cols])
243
+ data_values = np.concatenate([data_values, unmapped_data])
244
+ else:
245
+ row_indices = unmapped_rows
246
+ col_indices = unmapped_cols
247
+ data_values = unmapped_data
248
+
249
+ else:
250
+ # If all SNPs mapped, we technically don't need to append anything,
251
+ # but we still need the matrix shape to include the extra column.
252
+ pass
253
+
254
+ # Construct CSR Matrix
255
+ # Shape: (m_ref, n_features + 1)
256
+ mapping_matrix = scipy.sparse.csr_matrix(
257
+ (data_values, (row_indices, col_indices)),
258
+ shape=(m_ref, n_features + 1),
259
+ dtype=np.float32
260
+ )
261
+
262
+ return mapping_matrix, unique_feature_names, curated_mapping_df