gsMap3D-0.1.0a1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/__init__.py ADDED
@@ -0,0 +1,13 @@
+ """Genetically informed spatial mapping of cells for complex traits."""
+
+ import logging
+ from importlib.metadata import version
+
+ # Package name and version
+ package_name = "gsMap3D"
+ __version__ = version(package_name)
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger("gsMap")
+ logger.propagate = False
gsMap/__main__.py ADDED
@@ -0,0 +1,4 @@
+ from .cli import app
+
+ if __name__ == "__main__":
+     app()
gsMap/cauchy_combination_test.py ADDED
@@ -0,0 +1,342 @@
+ import logging
+ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
+
+ import numpy as np
+ import pandas as pd
+ import scanpy as sc
+ import scipy as sp
+ from rich.progress import Progress
+
+ from gsMap.config.cauchy_config import CauchyCombinationConfig
+
+ logger = logging.getLogger(__name__)
+
+ def load_ldsc(path):
+     logger.debug(f'Loading {path}')
+     df = pd.read_csv(path)
+     df['log10_p'] = -np.log10(df['p'])
+     # Clean up spot index
+     df['spot'] = df['spot'].astype(str)
+     df.set_index('spot', inplace=True)
+     # drop nan
+     df = df.dropna()
+     return df
+
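From `load_ldsc`, each per-trait spatial LDSC result appears to be a CSV with at least a `spot` identifier and a `p` column; the function adds a `-log10(p)` column, indexes by spot, and drops NaN rows. A minimal sketch of that shape with made-up values (not data from the package):

```python
import numpy as np
import pandas as pd

# Toy stand-in for one spatial LDSC result file (columns inferred from load_ldsc above)
df = pd.DataFrame({"spot": ["s1", "s2", "s3"], "p": [1e-2, 5e-1, 1e-6]})
df["log10_p"] = -np.log10(df["p"])
df["spot"] = df["spot"].astype(str)
df.set_index("spot", inplace=True)
df = df.dropna()
print(df)  # columns: p, log10_p; index: spot
```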
+ def load_ldsc_with_key(key_path_tuple):
+     """Helper function to load LDSC and return with its key"""
+     key, path = key_path_tuple
+     df = load_ldsc(path)
+     # Select log10_p and rename with the key
+     df_subset = df[['log10_p']].rename(columns={'log10_p': key})
+     return key, df_subset
+
+ def join_ldsc_results(paths_dict, columns_to_keep=None, max_workers=None):
+     """
+     Load and join LDSC results from multiple paths using ProcessPoolExecutor.
+     Each log10_p column is renamed with the dictionary key.
+     """
+     if columns_to_keep is None:
+         columns_to_keep = ['log10_p']
+     dfs_dict = {}
+
+     # Use ProcessPoolExecutor for parallel loading
+     with ProcessPoolExecutor(max_workers=max_workers) as executor:
+         # Submit all tasks
+         future_to_key = {
+             executor.submit(load_ldsc_with_key, (key, path)): key
+             for key, path in paths_dict.items()
+         }
+
+         # Collect results as they complete
+         for future in as_completed(future_to_key):
+             key, df_subset = future.result()
+             dfs_dict[key] = df_subset
+
+     # Maintain original order
+     dfs_to_join = [dfs_dict[key] for key in paths_dict.keys()]
+
+     # OPTIMIZED JOIN: Use pd.concat which is much faster than sequential joins
+     df_merged = pd.concat(dfs_to_join, axis=1, join='inner', sort=False)
+
+     return df_merged
+
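A usage sketch for `join_ldsc_results`. The trait names and file paths below are hypothetical; the function returns one wide DataFrame, inner-joined on the spot index, with one -log10(p) column per trait named by the dictionary key:

```python
from gsMap.cauchy_combination_test import join_ldsc_results

# Hypothetical per-trait result paths; dictionary keys become column names
ldsc_paths = {
    "Height": "ldsc_results/Height.csv",
    "SCZ": "ldsc_results/SCZ.csv",
}
wide = join_ldsc_results(ldsc_paths, max_workers=2)
print(wide.columns.tolist())  # ['Height', 'SCZ'], one -log10(p) column per trait
```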
+ def _acat_test(pvalues: np.ndarray, weights=None):
+     """
+     Aggregated Cauchy Association Test (ACAT)
+     """
+     if np.any(np.isnan(pvalues)):
+         raise ValueError("Cannot have NAs in the p-values.")
+     if np.any((pvalues > 1) | (pvalues < 0)):
+         raise ValueError("P-values must be between 0 and 1.")
+
+     # Handle exact 0 or 1
+     if np.any(pvalues == 0):
+         return 0.0
+     if np.any(pvalues == 1):
+         if np.all(pvalues == 1):
+             return 1.0
+         # Mixed case: some p-values equal 1, others are < 1. Follow the standard logic:
+         # tan((0.5 - 1) * pi) -> tan(-0.5 * pi), i.e. a very large negative term, so
+         # p-values near 1 simply drag the statistic down (large Cauchy CDF), while
+         # very small p-values are handled separately below.
+         pass
+
+     if weights is None:
+         weights = np.full(len(pvalues), 1 / len(pvalues))
+     else:
+         if len(weights) != len(pvalues):
+             raise Exception("Length of weights and p-values differs.")
+         if any(weights < 0):
+             raise Exception("All weights must be positive.")
+         weights = np.array(weights) / np.sum(weights)
+
+     is_small = pvalues < 1e-16
+     is_large = ~is_small
+
+     cct_stat = 0.0
+     if np.any(is_small):
+         cct_stat += np.sum((weights[is_small] / pvalues[is_small]) / np.pi)
+
+     if np.any(is_large):
+         cct_stat += np.sum(weights[is_large] * np.tan((0.5 - pvalues[is_large]) * np.pi))
+
+     if cct_stat > 1e15:
+         pval = (1 / cct_stat) / np.pi
+     else:
+         pval = 1 - sp.stats.cauchy.cdf(cct_stat)
+
+     return pval
+
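The statistic assembled above is the usual ACAT/Cauchy combination, T = sum_i w_i * tan((0.5 - p_i) * pi), with the combined p-value taken from the standard Cauchy survival function (plus the 1/(p_i * pi) shortcut for p-values below 1e-16 or a huge T). A small numeric check of that identity, independent of the package:

```python
import numpy as np
from scipy import stats

p = np.array([1e-4, 0.2, 0.5])
w = np.full(p.size, 1 / p.size)            # equal weights, as in _acat_test's default

t = np.sum(w * np.tan((0.5 - p) * np.pi))  # Cauchy combination statistic
p_combined = stats.cauchy.sf(t)            # survival function == 1 - cdf
print(p_combined)                          # ~3e-4, dominated by the smallest p-value
```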
+ def remove_outliers_IQR(data, threshold_factor=3.0):
+     """
+     Remove outliers using IQR method on log10 p-values.
+     p_values_filtered = p_values[p_values_log < median_log + 3 * iqr_log]
+     """
+     if len(data) == 0:
+         return data, np.array([], dtype=bool)
+
+     median = np.median(data)
+     q75, q25 = np.percentile(data, [75, 25])
+     iqr = q75 - q25
+
+     cutoff = median + threshold_factor * iqr
+
+     # Filter: keep values less than cutoff
+     mask = data < cutoff
+     data_passed = data[mask]
+
+     return data_passed, mask
+
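Note that `remove_outliers_IQR` is not the textbook Q3 + 1.5 * IQR rule; it works on -log10(p) values and drops anything above median + threshold_factor * IQR (3 * IQR by default). A quick illustration with made-up numbers:

```python
import numpy as np

log10p = np.array([0.5, 1.0, 1.5, 2.0, 40.0])    # one extreme -log10(p) spot
q75, q25 = np.percentile(log10p, [75, 25])
cutoff = np.median(log10p) + 3.0 * (q75 - q25)   # 1.5 + 3 * 1.0 = 4.5
print(log10p[log10p < cutoff])                   # [0.5 1.  1.5 2. ]; the 40.0 spot is dropped
```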
+ def process_trait(trait, anno_data, all_data, annotation, annotation_col):
+     """
+     Process a single trait for a given annotation: calculate Cauchy combination p-value.
+     """
+     # Calculate significance threshold (Bonferroni correction)
+     sig_threshold = 0.05 / len(all_data)
+
+     # Get p-values for this annotation and trait
+     if trait not in anno_data.columns:
+         logger.warning(f"Trait {trait} not found in data columns.")
+         return None
+
+     log10p = anno_data[trait].values
+     log10p, mask = remove_outliers_IQR(log10p)
+     p_values = 10 ** (-log10p)  # convert from log10(p) to p
+
+     # Calculate Cauchy combination and mean/median metrics
+     if len(p_values) == 0:
+         p_cauchy_val = 1.0
+         p_median_val = 1.0
+         top_95_quantile = 0.0
+     else:
+         p_cauchy_val = _acat_test(p_values)
+         p_median_val = np.median(p_values)
+
+         # Calculate 95% quantile of -log10pvalue (referred to as top_95_quantile)
+         if len(log10p) > 0:
+             top_95_quantile = np.quantile(log10p, 0.95)
+         else:
+             top_95_quantile = 0.0
+
+     # Calculate significance statistics
+     sig_spots_in_anno = np.sum(p_values < sig_threshold)
+     total_spots_in_anno = len(p_values)
+
+     return {
+         'trait': trait,
+         'annotation_name': annotation_col,
+         'annotation': annotation,
+         'p_cauchy': p_cauchy_val,
+         'p_median': p_median_val,
+         'top_95_quantile': top_95_quantile,
+         'sig_spots': sig_spots_in_anno,
+         'total_spots': total_spots_in_anno,
+     }
+
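The significance count in `process_trait` uses a Bonferroni threshold of 0.05 divided by the number of rows in the background DataFrame (`all_data`, i.e. all spots in the dataset or sample), not by the number of spots within the annotation. With an illustrative spot count:

```python
n_background_spots = 100_000              # illustrative; corresponds to len(all_data) in process_trait
sig_threshold = 0.05 / n_background_spots
print(sig_threshold)                      # 5e-07: spots with p below this are counted as sig_spots
```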
+ def run_cauchy_on_dataframe(df, annotation_col, trait_cols=None, extra_group_col=None):
+     """
+     Run Cauchy combination test on a dataframe.
+
+     Args:
+         df: DataFrame containing log10(p) values and annotation column.
+         annotation_col: Name of column containing annotations.
+         trait_cols: List of trait columns. If None, uses all columns except annotation_col.
+         extra_group_col: Optional extra column to group by (e.g., 'sample_name').
+
+     Returns:
+         DataFrame with results.
+     """
+     if trait_cols is None:
+         trait_cols = [c for c in df.columns if c not in [annotation_col, extra_group_col] and c != 'spot']
+
+     all_results = []
+
+     # Define the groups
+     group_cols = [annotation_col]
+     if extra_group_col:
+         group_cols.append(extra_group_col)
+
+     # Pre-calculate groups to avoid repeated filtering
+     grouped = df.groupby(group_cols, observed=True)
+
+     for trait in trait_cols:
+         def process_one_group(group_key, trait=trait):
+             df_group = grouped.get_group(group_key)
+
+             # Usually enrichment is within the same context.
+             # If extra_group_col is sample_name, background should be other annotations in the SAME sample.
+             if extra_group_col:
+                 # Group key is (anno, extra)
+                 anno, extra = group_key
+                 df_background = df[df[extra_group_col] == extra]
+                 # We need all_data for process_trait to calculate background
+                 res = process_trait(trait, df_group, df_background, anno, annotation_col)
+                 if res:
+                     res[extra_group_col] = extra
+             else:
+                 # Group key is just (anno,)
+                 anno = group_key[0] if isinstance(group_key, tuple) else group_key
+                 res = process_trait(trait, df_group, df, anno, annotation_col)
+
+             return res
+
+         group_keys = list(grouped.groups.keys())
+
+         # Use rich progress bar
+         with Progress(transient=True) as progress:
+             task = progress.add_task(f"[green]Processing {trait}...", total=len(group_keys))
+
+             # Parallelize over groups
+             with ThreadPoolExecutor(max_workers=None) as executor:
+                 # Submit all tasks
+                 future_to_group = {executor.submit(process_one_group, g): g for g in group_keys}
+
+                 for future in as_completed(future_to_group):
+                     res = future.result()
+                     if res is not None:
+                         all_results.append(res)
+                     progress.advance(task)
+
+     if not all_results:
+         return pd.DataFrame()
+
+     combined_results = pd.DataFrame(all_results)
+     sort_cols = ['trait', 'p_median']
+     if extra_group_col:
+         sort_cols = [extra_group_col] + sort_cols
+
+     combined_results.sort_values(by=sort_cols, ascending=True, inplace=True)
+
+     return combined_results
+
+
+
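A self-contained usage sketch for `run_cauchy_on_dataframe`. The trait name "Height" and annotation column "region" are invented; the input is a spot-indexed DataFrame of -log10(p) values plus an annotation column, and the output has one row per annotation value:

```python
import numpy as np
import pandas as pd
from gsMap.cauchy_combination_test import run_cauchy_on_dataframe

rng = np.random.default_rng(0)
df = pd.DataFrame(
    {
        "Height": -np.log10(rng.uniform(1e-6, 1, size=200)),  # -log10(p) per spot
        "region": np.repeat(["cortex", "striatum"], 100),     # annotation column
    },
    index=[f"spot_{i}" for i in range(200)],
)

res = run_cauchy_on_dataframe(df, annotation_col="region", trait_cols=["Height"])
print(res[["annotation", "p_cauchy", "p_median", "sig_spots", "total_spots"]])
```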
+ def run_Cauchy_combination(config: CauchyCombinationConfig):
+     # 1. Discover traits and load combined LDSC results
+     traits_dict = config.ldsc_traits_result_path_dict
+
+     logger.info(f"Joining LDSC results for {len(traits_dict)} traits...")
+     df_combined = join_ldsc_results(traits_dict)
+
+     # 2. Add Annotation & Metadata
+     logger.info("------Loading annotations...")
+     latent_adata_path = config.concatenated_latent_adata_path
+     if not latent_adata_path.exists():
+         raise FileNotFoundError(f"Latent adata with annotations not found at: {latent_adata_path}")
+
+     logger.info(f"Loading metadata from {latent_adata_path}")
+     adata = sc.read_h5ad(latent_adata_path)
+
+     # Check for all requested annotation columns
+     for anno in config.annotation_list:
+         if anno not in adata.obs.columns:
+             raise ValueError(f"Annotation column '{anno}' not found in adata.obs.")
+
+     # Check for sample_name column
+     sample_col = 'sample_name'
+     if sample_col not in adata.obs.columns:
+         logger.warning("'sample_name' column not found in adata.obs. Sample-level Cauchy will be skipped.")
+         sample_col = None
+
+     # Filter to common spots
+     common_cells = np.intersect1d(df_combined.index, adata.obs_names)
+     logger.info(f"Found {len(common_cells)} common spots between LDSC results and metadata.")
+     if len(common_cells) == 0:
+         logger.warning("No common spots found. Skipping...")
+         return
+
+     df_combined = df_combined.loc[common_cells].copy()
+
+     # Add metadata columns
+     metadata_cols = config.annotation_list.copy()
+     if sample_col:
+         metadata_cols.append(sample_col)
+
+     for col in metadata_cols:
+         series = adata.obs.loc[common_cells, col]
+         if isinstance(series.dtype, pd.CategoricalDtype):
+             # Efficiently handle NaNs in categorical data by adding a 'NaN' category
+             if series.isna().any():
+                 if 'NaN' not in series.cat.categories:
+                     series = series.cat.add_categories('NaN')
+                 df_combined[col] = series.fillna('NaN')
+             else:
+                 df_combined[col] = series
+         else:
+             # For non-categorical, fillna and ensure string type for consistency
+             df_combined[col] = series.fillna('NaN').astype(str).replace(['nan', 'None'], 'NaN')
+
+     # 3. Save combined data to parquet
+     logger.info(f"Saving combined data to {config.ldsc_combined_parquet_path}")
+     df_to_save = df_combined.reset_index().rename(columns={'index': 'spot'})
+     df_to_save.to_parquet(config.ldsc_combined_parquet_path)
+
+     # 4. Process each annotation and each trait
+     trait_cols = list(traits_dict.keys())
+     for annotation_col in config.annotation_list:
+         logger.info(f"=== Processing Cauchy combination for annotation: {annotation_col} ===")
+
+         # Run Cauchy Combination (Annotation Level)
+         result_df = run_cauchy_on_dataframe(df_combined,
+                                             annotation_col=annotation_col,
+                                             trait_cols=trait_cols)
+
+         # Save results per trait
+         for trait_name in trait_cols:
+             trait_result = result_df[result_df['trait'] == trait_name]
+             output_file = config.get_cauchy_result_file(trait_name, annotation=annotation_col, all_samples=True)
+             trait_result.to_csv(output_file, index=False)
+
+         # Run Cauchy Combination (Sample-Annotation level)
+         if sample_col:
+             sample_result_df = run_cauchy_on_dataframe(df_combined,
+                                                        annotation_col=annotation_col,
+                                                        trait_cols=trait_cols,
+                                                        extra_group_col=sample_col)
+
+             for trait_name in trait_cols:
+                 trait_sample_result = sample_result_df[sample_result_df['trait'] == trait_name]
+                 sample_output_file = config.get_cauchy_result_file(trait_name, annotation=annotation_col, all_samples=False)
+                 trait_sample_result.to_csv(sample_output_file, index=False)
+
+     logger.info("Cauchy combination processing completed.")