gsMap3D 0.1.0a1 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/__init__.py
ADDED
@@ -0,0 +1,13 @@
+"""Genetically informed spatial mapping of cells for complex traits."""
+
+import logging
+from importlib.metadata import version
+
+# Package name and version
+package_name = "gsMap3D"
+__version__ = version(package_name)
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger("gsMap")
+logger.propagate = False
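
The snippet below is an illustrative usage sketch, not part of the wheel: with the gsMap3D distribution installed, the version resolved in gsMap/__init__.py can be checked directly.

# Illustrative only -- not included in the package.
import gsMap

print(gsMap.package_name)   # "gsMap3D"
print(gsMap.__version__)    # the installed distribution version, e.g. "0.1.0a1"
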
gsMap/cauchy_combination_test.py
ADDED
@@ -0,0 +1,342 @@
+import logging
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
+
+import numpy as np
+import pandas as pd
+import scanpy as sc
+import scipy as sp
+from rich.progress import Progress
+
+from gsMap.config.cauchy_config import CauchyCombinationConfig
+
+logger = logging.getLogger(__name__)
+
+def load_ldsc(path):
+    logger.debug(f'Loading {path}')
+    df = pd.read_csv(path)
+    df['log10_p'] = -np.log10(df['p'])
+    # Clean up spot index
+    df['spot'] = df['spot'].astype(str)
+    df.set_index('spot', inplace=True)
+    # drop nan
+    df = df.dropna()
+    return df
+
+def load_ldsc_with_key(key_path_tuple):
+    """Helper function to load LDSC and return with its key"""
+    key, path = key_path_tuple
+    df = load_ldsc(path)
+    # Select log10_p and rename with the key
+    df_subset = df[['log10_p']].rename(columns={'log10_p': key})
+    return key, df_subset
+
+def join_ldsc_results(paths_dict, columns_to_keep=None, max_workers=None):
+    """
+    Load and join LDSC results from multiple paths using ProcessPoolExecutor.
+    Each log10_p column is renamed with the dictionary key.
+    """
+    if columns_to_keep is None:
+        columns_to_keep = ['log10_p']
+    dfs_dict = {}
+
+    # Use ProcessPoolExecutor for parallel loading
+    with ProcessPoolExecutor(max_workers=max_workers) as executor:
+        # Submit all tasks
+        future_to_key = {
+            executor.submit(load_ldsc_with_key, (key, path)): key
+            for key, path in paths_dict.items()
+        }
+
+        # Collect results as they complete
+        for future in as_completed(future_to_key):
+            key, df_subset = future.result()
+            dfs_dict[key] = df_subset
+
+    # Maintain original order
+    dfs_to_join = [dfs_dict[key] for key in paths_dict.keys()]
+
+    # OPTIMIZED JOIN: Use pd.concat which is much faster than sequential joins
+    df_merged = pd.concat(dfs_to_join, axis=1, join='inner', sort=False)
+
+    return df_merged
+
+def _acat_test(pvalues: np.ndarray, weights=None):
+    """
+    Aggregated Cauchy Association Test (ACAT)
+    """
+    if np.any(np.isnan(pvalues)):
+        raise ValueError("Cannot have NAs in the p-values.")
+    if np.any((pvalues > 1) | (pvalues < 0)):
+        raise ValueError("P-values must be between 0 and 1.")
+
+    # Handle exact 0 or 1
+    if np.any(pvalues == 0):
+        return 0.0
+    if np.any(pvalues == 1):
+        if np.all(pvalues == 1):
+            return 1.0
+        # If mixed 1s and <1s, 1s contribute nothing to the sum of tans likely,
+        # but let's follow the standard logic: tan((0.5-1)*pi) -> tan(-0.5pi) -> -inf.
+        # This implementation handles small p-values carefully but large ones (near 1)
+        # result in negative large stats which is fine (large CDF).
+        pass
+
+    if weights is None:
+        weights = np.full(len(pvalues), 1 / len(pvalues))
+    else:
+        if len(weights) != len(pvalues):
+            raise Exception("Length of weights and p-values differs.")
+        if any(weights < 0):
+            raise Exception("All weights must be positive.")
+        weights = np.array(weights) / np.sum(weights)
+
+    is_small = pvalues < 1e-16
+    is_large = ~is_small
+
+    cct_stat = 0.0
+    if np.any(is_small):
+        cct_stat += np.sum((weights[is_small] / pvalues[is_small]) / np.pi)
+
+    if np.any(is_large):
+        cct_stat += np.sum(weights[is_large] * np.tan((0.5 - pvalues[is_large]) * np.pi))
+
+    if cct_stat > 1e15:
+        pval = (1 / cct_stat) / np.pi
+    else:
+        pval = 1 - sp.stats.cauchy.cdf(cct_stat)
+
+    return pval
+
+def remove_outliers_IQR(data, threshold_factor=3.0):
+    """
+    Remove outliers using IQR method on log10 p-values.
+    p_values_filtered = p_values[p_values_log < median_log + 3 * iqr_log]
+    """
+    if len(data) == 0:
+        return data, np.array([], dtype=bool)
+
+    median = np.median(data)
+    q75, q25 = np.percentile(data, [75, 25])
+    iqr = q75 - q25
+
+    cutoff = median + threshold_factor * iqr
+
+    # Filter: keep values less than cutoff
+    mask = data < cutoff
+    data_passed = data[mask]
+
+    return data_passed, mask
+
+def process_trait(trait, anno_data, all_data, annotation, annotation_col):
+    """
+    Process a single trait for a given annotation: calculate Cauchy combination p-value.
+    """
+    # Calculate significance threshold (Bonferroni correction)
+    sig_threshold = 0.05 / len(all_data)
+
+    # Get p-values for this annotation and trait
+    if trait not in anno_data.columns:
+        logger.warning(f"Trait {trait} not found in data columns.")
+        return None
+
+    log10p = anno_data[trait].values
+    log10p, mask = remove_outliers_IQR(log10p)
+    p_values = 10 ** (-log10p)  # convert from log10(p) to p
+
+    # Calculate Cauchy combination and mean/median metrics
+    if len(p_values) == 0:
+        p_cauchy_val = 1.0
+        p_median_val = 1.0
+        top_95_quantile = 0.0
+    else:
+        p_cauchy_val = _acat_test(p_values)
+        p_median_val = np.median(p_values)
+
+    # Calculate 95% quantile of -log10pvalue (referred to as top_95_quantile)
+    if len(log10p) > 0:
+        top_95_quantile = np.quantile(log10p, 0.95)
+    else:
+        top_95_quantile = 0.0
+
+    # Calculate significance statistics
+    sig_spots_in_anno = np.sum(p_values < sig_threshold)
+    total_spots_in_anno = len(p_values)
+
+    return {
+        'trait': trait,
+        'annotation_name': annotation_col,
+        'annotation': annotation,
+        'p_cauchy': p_cauchy_val,
+        'p_median': p_median_val,
+        'top_95_quantile': top_95_quantile,
+        'sig_spots': sig_spots_in_anno,
+        'total_spots': total_spots_in_anno,
+    }
+
+def run_cauchy_on_dataframe(df, annotation_col, trait_cols=None, extra_group_col=None):
+    """
+    Run Cauchy combination test on a dataframe.
+
+    Args:
+        df: DataFrame containing log10(p) values and annotation column.
+        annotation_col: Name of column containing annotations.
+        trait_cols: List of trait columns. If None, uses all columns except annotation_col.
+        extra_group_col: Optional extra column to group by (e.g., 'sample_name').
+
+    Returns:
+        DataFrame with results.
+    """
+    if trait_cols is None:
+        trait_cols = [c for c in df.columns if c not in [annotation_col, extra_group_col] and c != 'spot']
+
+    all_results = []
+
+    # Define the groups
+    group_cols = [annotation_col]
+    if extra_group_col:
+        group_cols.append(extra_group_col)
+
+    # Pre-calculate groups to avoid repeated filtering
+    grouped = df.groupby(group_cols, observed=True)
+
+    for trait in trait_cols:
+        def process_one_group(group_key, trait=trait):
+            df_group = grouped.get_group(group_key)
+
+            # Usually enrichment is within the same context.
+            # If extra_group_col is sample_name, background should be other annotations in the SAME sample.
+            if extra_group_col:
+                # Group key is (anno, extra)
+                anno, extra = group_key
+                df_background = df[df[extra_group_col] == extra]
+                # We need all_data for process_trait to calculate background
+                res = process_trait(trait, df_group, df_background, anno, annotation_col)
+                if res:
+                    res[extra_group_col] = extra
+            else:
+                # Group key is just (anno,)
+                anno = group_key[0] if isinstance(group_key, tuple) else group_key
+                res = process_trait(trait, df_group, df, anno, annotation_col)
+
+            return res
+
+        group_keys = list(grouped.groups.keys())
+
+        # Use rich progress bar
+        with Progress(transient=True) as progress:
+            task = progress.add_task(f"[green]Processing {trait}...", total=len(group_keys))
+
+            # Parallelize over groups
+            with ThreadPoolExecutor(max_workers=None) as executor:
+                # Submit all tasks
+                future_to_group = {executor.submit(process_one_group, g): g for g in group_keys}
+
+                for future in as_completed(future_to_group):
+                    res = future.result()
+                    if res is not None:
+                        all_results.append(res)
+                    progress.advance(task)
+
+    if not all_results:
+        return pd.DataFrame()
+
+    combined_results = pd.DataFrame(all_results)
+    sort_cols = ['trait', 'p_median']
+    if extra_group_col:
+        sort_cols = [extra_group_col] + sort_cols
+
+    combined_results.sort_values(by=sort_cols, ascending=True, inplace=True)
+
+    return combined_results
+
+
+
+def run_Cauchy_combination(config: CauchyCombinationConfig):
+    # 1. Discover traits and load combined LDSC results
+    traits_dict = config.ldsc_traits_result_path_dict
+
+    logger.info(f"Joining LDSC results for {len(traits_dict)} traits...")
+    df_combined = join_ldsc_results(traits_dict)
+
+    # 2. Add Annotation & Metadata
+    logger.info("------Loading annotations...")
+    latent_adata_path = config.concatenated_latent_adata_path
+    if not latent_adata_path.exists():
+        raise FileNotFoundError(f"Latent adata with annotations not found at: {latent_adata_path}")
+
+    logger.info(f"Loading metadata from {latent_adata_path}")
+    adata = sc.read_h5ad(latent_adata_path)
+
+    # Check for all requested annotation columns
+    for anno in config.annotation_list:
+        if anno not in adata.obs.columns:
+            raise ValueError(f"Annotation column '{anno}' not found in adata.obs.")
+
+    # Check for sample_name column
+    sample_col = 'sample_name'
+    if sample_col not in adata.obs.columns:
+        logger.warning("'sample_name' column not found in adata.obs. Sample-level Cauchy will be skipped.")
+        sample_col = None
+
+    # Filter to common spots
+    common_cells = np.intersect1d(df_combined.index, adata.obs_names)
+    logger.info(f"Found {len(common_cells)} common spots between LDSC results and metadata.")
+    if len(common_cells) == 0:
+        logger.warning("No common spots found. Skipping...")
+        return
+
+    df_combined = df_combined.loc[common_cells].copy()
+
+    # Add metadata columns
+    metadata_cols = config.annotation_list.copy()
+    if sample_col:
+        metadata_cols.append(sample_col)
+
+    for col in metadata_cols:
+        series = adata.obs.loc[common_cells, col]
+        if isinstance(series.dtype, pd.CategoricalDtype):
+            # Efficiently handle NaNs in categorical data by adding a 'NaN' category
+            if series.isna().any():
+                if 'NaN' not in series.cat.categories:
+                    series = series.cat.add_categories('NaN')
+                df_combined[col] = series.fillna('NaN')
+            else:
+                df_combined[col] = series
+        else:
+            # For non-categorical, fillna and ensure string type for consistency
+            df_combined[col] = series.fillna('NaN').astype(str).replace(['nan', 'None'], 'NaN')
+
+    # 3. Save combined data to parquet
+    logger.info(f"Saving combined data to {config.ldsc_combined_parquet_path}")
+    df_to_save = df_combined.reset_index().rename(columns={'index': 'spot'})
+    df_to_save.to_parquet(config.ldsc_combined_parquet_path)
+
+    # 4. Process each annotation and each trait
+    trait_cols = list(traits_dict.keys())
+    for annotation_col in config.annotation_list:
+        logger.info(f"=== Processing Cauchy combination for annotation: {annotation_col} ===")
+
+        # Run Cauchy Combination (Annotation Level)
+        result_df = run_cauchy_on_dataframe(df_combined,
+                                            annotation_col=annotation_col,
+                                            trait_cols=trait_cols)
+
+        # Save results per trait
+        for trait_name in trait_cols:
+            trait_result = result_df[result_df['trait'] == trait_name]
+            output_file = config.get_cauchy_result_file(trait_name, annotation=annotation_col, all_samples=True)
+            trait_result.to_csv(output_file, index=False)
+
+        # Run Cauchy Combination (Sample-Annotation level)
+        if sample_col:
+            sample_result_df = run_cauchy_on_dataframe(df_combined,
+                                                       annotation_col=annotation_col,
+                                                       trait_cols=trait_cols,
+                                                       extra_group_col=sample_col)
+
+            for trait_name in trait_cols:
+                trait_sample_result = sample_result_df[sample_result_df['trait'] == trait_name]
+                sample_output_file = config.get_cauchy_result_file(trait_name, annotation=annotation_col, all_samples=False)
+                trait_sample_result.to_csv(sample_output_file, index=False)
+
+    logger.info("Cauchy combination processing completed.")
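
For reference, the Cauchy combination implemented in _acat_test maps each p-value p_i to a Cauchy quantile tan((0.5 - p_i) * pi), takes a weighted sum, and converts the statistic back to a combined p-value. Below is a minimal self-contained sketch of that calculation with made-up p-values; it mirrors only the regular branch of _acat_test, without the small-p and 0/1 edge-case handling.

# Illustrative only -- toy p-values, not package code.
import numpy as np
from scipy import stats

p = np.array([0.01, 0.20, 0.55, 0.90])   # p-values for a handful of spots (made up)
w = np.full(len(p), 1 / len(p))          # equal weights, matching the _acat_test default

# Cauchy combination statistic: weighted sum of tan((0.5 - p) * pi)
cct_stat = np.sum(w * np.tan((0.5 - p) * np.pi))

# Combined p-value from the Cauchy survival function
p_combined = 1 - stats.cauchy.cdf(cct_stat)
print(p_combined)

In the package itself, run_Cauchy_combination applies this per annotation (and optionally per sample) to the -log10(p) columns produced by spatial LDSC, after IQR-based outlier trimming via remove_outliers_IQR.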