gsMap 1.72.3__py3-none-any.whl → 1.73.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/GNN/train.py +1 -1
- gsMap/__init__.py +1 -1
- gsMap/cauchy_combination_test.py +5 -5
- gsMap/config.py +170 -37
- gsMap/create_slice_mean.py +33 -18
- gsMap/diagnosis.py +4 -14
- gsMap/find_latent_representation.py +19 -3
- gsMap/format_sumstats.py +6 -0
- gsMap/generate_ldscore.py +1071 -476
- gsMap/latent_to_gene.py +57 -19
- gsMap/run_all_mode.py +2 -0
- gsMap/utils/generate_r2_matrix.py +15 -294
- gsMap/utils/regression_read.py +0 -76
- gsmap-1.73.1.dist-info/METADATA +177 -0
- gsmap-1.73.1.dist-info/RECORD +31 -0
- {gsmap-1.72.3.dist-info → gsmap-1.73.1.dist-info}/WHEEL +1 -1
- {gsmap-1.72.3.dist-info → gsmap-1.73.1.dist-info/licenses}/LICENSE +6 -6
- gsmap-1.72.3.dist-info/METADATA +0 -120
- gsmap-1.72.3.dist-info/RECORD +0 -31
- {gsmap-1.72.3.dist-info → gsmap-1.73.1.dist-info}/entry_points.txt +0 -0
gsMap/GNN/train.py
CHANGED
gsMap/__init__.py
CHANGED
gsMap/cauchy_combination_test.py
CHANGED
@@ -48,16 +48,16 @@ def acat_test(pvalues, weights=None):
     elif any(i < 0 for i in weights):
         raise Exception("All weights must be positive.")
     else:
-        weights = [i /
+        weights = [i / np.sum(weights) for i in weights]

     pvalues = np.array(pvalues)
     weights = np.array(weights)

-    if not any(i < 1e-
+    if not any(i < 1e-15 for i in pvalues):
         cct_stat = sum(weights * np.tan((0.5 - pvalues) * np.pi))
     else:
-        is_small = [i < (1e-
-        is_large = [i >= (1e-
+        is_small = [i < (1e-15) for i in pvalues]
+        is_large = [i >= (1e-15) for i in pvalues]
         cct_stat = sum((weights[is_small] / pvalues[is_small]) / np.pi)
         cct_stat += sum(weights[is_large] * np.tan((0.5 - pvalues[is_large]) * np.pi))

@@ -118,7 +118,7 @@ def run_Cauchy_combination(config: CauchyCombinationConfig):
     n_removed = len(p_values) - len(p_values_filtered)

     # Remove outliers if the number is reasonable
-    if 0 < n_removed < 20:
+    if 0 < n_removed < max(len(p_values) * 0.01, 20):
        logger.info(f"Removed {n_removed}/{len(p_values)} outliers (median + 3IQR) for {ct}.")
        p_cauchy_temp = acat_test(p_values_filtered)
    else:
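Note on the two hunks above: acat_test implements the ACAT (Cauchy combination) rule, and run_Cauchy_combination now scales its outlier-removal cap with sample size, dropping at most max(1%, 20) of the per-spot p-values instead of a flat 20. As a minimal standalone sketch of the combination rule (a simplification assuming equal weights by default; the conversion of the statistic back to a p-value via the standard Cauchy distribution is not shown in this diff):

import numpy as np
from scipy.stats import cauchy

def acat_sketch(pvalues, weights=None):
    pvalues = np.asarray(pvalues, dtype=float)
    if weights is None:
        weights = np.full_like(pvalues, 1.0 / len(pvalues))
    else:
        weights = np.asarray(weights, dtype=float) / np.sum(weights)
    # tan((0.5 - p) * pi) overflows as p -> 0, so below 1e-15 the patched
    # code switches to the asymptotically equivalent form w / (p * pi)
    small = pvalues < 1e-15
    stat = np.sum(weights[small] / pvalues[small] / np.pi)
    stat += np.sum(weights[~small] * np.tan((0.5 - pvalues[~small]) * np.pi))
    return cauchy.sf(stat)  # combined p-value under a standard Cauchy null

print(acat_sketch([1e-20, 0.3, 0.8]))  # dominated by the smallest p-value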
gsMap/config.py
CHANGED
@@ -1,7 +1,10 @@
 import argparse
 import dataclasses
 import logging
+import os
 import sys
+import threading
+import time
 from collections import OrderedDict, namedtuple
 from collections.abc import Callable
 from dataclasses import dataclass
@@ -10,6 +13,7 @@ from pathlib import Path
 from pprint import pprint
 from typing import Literal

+import psutil
 import pyfiglet
 import yaml

@@ -34,9 +38,109 @@ def get_gsMap_logger(logger_name):
 logger = get_gsMap_logger("gsMap")


+def track_resource_usage(func):
+    """
+    Decorator to track resource usage during function execution.
+    Logs memory usage, CPU time, and wall clock time at the end of the function.
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # Get the current process
+        process = psutil.Process(os.getpid())
+
+        # Initialize tracking variables
+        peak_memory = 0
+        cpu_percent_samples = []
+        stop_thread = False
+
+        # Function to monitor resource usage
+        def resource_monitor():
+            nonlocal peak_memory, cpu_percent_samples
+            while not stop_thread:
+                try:
+                    # Get current memory usage in MB
+                    current_memory = process.memory_info().rss / (1024 * 1024)
+                    peak_memory = max(peak_memory, current_memory)
+
+                    # Get CPU usage percentage
+                    cpu_percent = process.cpu_percent(interval=None)
+                    if cpu_percent > 0:  # Skip initial zero readings
+                        cpu_percent_samples.append(cpu_percent)
+
+                    time.sleep(0.5)
+                except Exception:  # Catching all exceptions here because...  # noqa: BLE001
+                    pass
+
+        # Start resource monitoring in a separate thread
+        monitor_thread = threading.Thread(target=resource_monitor)
+        monitor_thread.daemon = True
+        monitor_thread.start()
+
+        # Get start times
+        start_wall_time = time.time()
+        start_cpu_time = process.cpu_times().user + process.cpu_times().system
+
+        try:
+            # Run the actual function
+            result = func(*args, **kwargs)
+            return result
+        finally:
+            # Stop the monitoring thread
+            stop_thread = True
+            monitor_thread.join(timeout=1.0)
+
+            # Calculate elapsed times
+            end_wall_time = time.time()
+            end_cpu_time = process.cpu_times().user + process.cpu_times().system
+
+            wall_time = end_wall_time - start_wall_time
+            cpu_time = end_cpu_time - start_cpu_time
+
+            # Calculate average CPU percentage
+            avg_cpu_percent = (
+                sum(cpu_percent_samples) / len(cpu_percent_samples) if cpu_percent_samples else 0
+            )
+
+            # Format memory for display
+            if peak_memory < 1024:
+                memory_str = f"{peak_memory:.2f} MB"
+            else:
+                memory_str = f"{peak_memory / 1024:.2f} GB"
+
+            # Format times for display
+            if wall_time < 60:
+                wall_time_str = f"{wall_time:.2f} seconds"
+            elif wall_time < 3600:
+                wall_time_str = f"{wall_time / 60:.2f} minutes"
+            else:
+                wall_time_str = f"{wall_time / 3600:.2f} hours"
+
+            if cpu_time < 60:
+                cpu_time_str = f"{cpu_time:.2f} seconds"
+            elif cpu_time < 3600:
+                cpu_time_str = f"{cpu_time / 60:.2f} minutes"
+            else:
+                cpu_time_str = f"{cpu_time / 3600:.2f} hours"
+
+            # Log the resource usage
+            import logging
+
+            logger = logging.getLogger("gsMap")
+            logger.info("Resource usage summary:")
+            logger.info(f"  • Wall clock time: {wall_time_str}")
+            logger.info(f"  • CPU time: {cpu_time_str}")
+            logger.info(f"  • Average CPU utilization: {avg_cpu_percent:.1f}%")
+            logger.info(f"  • Peak memory usage: {memory_str}")
+
+    return wrapper
+
+
 # Decorator to register functions for cli parsing
 def register_cli(name: str, description: str, add_args_function: Callable) -> Callable:
     def decorator(func: Callable) -> Callable:
+        @track_resource_usage  # Use enhanced resource tracking
+        @wraps(func)
         def wrapper(*args, **kwargs):
             name.replace("_", " ")
             gsMap_main_logo = pyfiglet.figlet_format(
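The new track_resource_usage decorator samples the process every 0.5 s from a daemon thread and logs its summary in the wrapped function's finally block, so the report appears even when the command fails. A hedged usage sketch (busy_step is a hypothetical function, not part of gsMap; requires psutil and assumes the decorator above is in scope):

import logging
import time

logging.basicConfig(level=logging.INFO)

@track_resource_usage
def busy_step():
    # Allocate some memory and burn CPU so the monitor thread has samples
    squares = [i * i for i in range(2_000_000)]
    time.sleep(1.0)
    return sum(squares)

busy_step()
# Logs a summary of the form (values are illustrative):
#   Resource usage summary:
#     • Wall clock time: 1.85 seconds
#     • CPU time: 0.70 seconds
#     • Average CPU utilization: 34.1%
#     • Peak memory usage: 215.40 MB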
@@ -50,8 +154,16 @@ def register_cli(name: str, description: str, add_args_function: Callable) -> Callable:
             print(version_number.center(80), flush=True)
             print("=" * 80, flush=True)
             logger.info(f"Running {name}...")
+
+            # Record start time for the log message
+            start_time = time.strftime("%Y-%m-%d %H:%M:%S")
+            logger.info(f"Started at: {start_time}")
+
             func(*args, **kwargs)
-
+
+            # Record end time for the log message
+            end_time = time.strftime("%Y-%m-%d %H:%M:%S")
+            logger.info(f"Finished running {name} at: {end_time}.")

         cli_function_registry[name] = subcommand(
             name=name, func=wrapper, add_args_function=add_args_function, description=description
@@ -61,6 +173,13 @@ def register_cli(name: str, description: str, add_args_function: Callable) -> Callable:
     return decorator


+def str_or_float(value):
+    try:
+        return int(value)
+    except ValueError:
+        return value
+
+
 def add_shared_args(parser):
     parser.add_argument(
         "--workdir", type=str, required=True, help="Path to the working directory."
@@ -113,6 +232,9 @@ def add_find_latent_representations_args(parser):
         action="store_true",
         help="Enable hierarchical latent representation finding.",
     )
+    parser.add_argument(
+        "--pearson_residuals", action="store_true", help="Using the pearson residuals."
+    )


 def chrom_choice(value):
@@ -189,7 +311,7 @@ def add_generate_ldscore_args(parser):
         help="Root path for genotype plink bfiles (.bim, .bed, .fam).",
     )
     parser.add_argument(
-        "--keep_snp_root", type=str, required=
+        "--keep_snp_root", type=str, required=False, help="Root path for SNP files"
     )
     parser.add_argument(
         "--gtf_annotation_file", type=str, required=True, help="Path to GTF annotation file."
@@ -238,7 +360,11 @@ def add_spatial_ldsc_args(parser):
         "--sumstats_file", type=str, required=True, help="Path to GWAS summary statistics file."
     )
     parser.add_argument(
-        "--w_file",
+        "--w_file",
+        type=str,
+        required=False,
+        default=None,
+        help="Path to regression weight file. If not provided, will use weights generated in the generate_ldscore step.",
     )
     parser.add_argument(
         "--trait_name", type=str, required=True, help="Name of the trait being analyzed."
@@ -429,7 +555,7 @@ def add_format_sumstats_args(parser):
     parser.add_argument(
         "--n",
         default=None,
-        type=
+        type=str_or_float,
         help="Name of sample size column (if not a name that gsMap understands)",
     )
     parser.add_argument(
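Despite its name, str_or_float returns an int when the argument parses as one and otherwise returns the raw string, which lets --n accept either a sample-size column name or a literal cohort size. A self-contained sketch of the behavior:

import argparse

def str_or_float(value):
    try:
        return int(value)
    except ValueError:
        return value

parser = argparse.ArgumentParser()
parser.add_argument("--n", default=None, type=str_or_float)
print(parser.parse_args(["--n", "350000"]).n)  # 350000 (int): literal sample size
print(parser.parse_args(["--n", "Neff"]).n)    # 'Neff' (str): column name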
@@ -559,6 +685,9 @@ def add_run_all_mode_args(parser):
     parser.add_argument(
         "--gM_slices", type=str, default=None, help="Path to the slice mean file (optional)."
     )
+    parser.add_argument(
+        "--pearson_residuals", action="store_true", help="Using the pearson residuals."
+    )


 def ensure_path_exists(func):
@@ -735,6 +864,7 @@ class FindLatentRepresentationsConfig(ConfigWithAutoPaths):
     var: bool = False
     convergence_threshold: float = 1e-4
     hierarchically: bool = False
+    pearson_residuals: bool = False

     def __post_init__(self):
         # self.output_hdf5_path = self.hdf5_with_latent_path
@@ -823,11 +953,11 @@ class GenerateLDScoreConfig(ConfigWithAutoPaths):
     chrom: int | str

     bfile_root: str
-    keep_snp_root: str | None

     # annotation by gene distance
     gtf_annotation_file: str
     gene_window_size: int = 50000
+    keep_snp_root: str | None = None

     # annotation by enhancer
     enhancer_annotation_file: str = None
@@ -936,7 +1066,7 @@ class GenerateLDScoreConfig(ConfigWithAutoPaths):

 @dataclass
 class SpatialLDSCConfig(ConfigWithAutoPaths):
-    w_file: str
+    w_file: str | None = None
     # ldscore_save_dir: str
     use_additional_baseline_annotation: bool = True
     trait_name: str | None = None
@@ -986,8 +1116,19 @@ class SpatialLDSCConfig(ConfigWithAutoPaths):
         for sumstats_file in self.sumstats_config_dict.values():
             assert Path(sumstats_file).exists(), f"{sumstats_file} does not exist."

-        #
-
+        # Handle w_file
+        if self.w_file is None:
+            w_ld_dir = Path(self.ldscore_save_dir) / "w_ld"
+            if w_ld_dir.exists():
+                self.w_file = str(w_ld_dir / "weights.")
+                logger.info(f"Using weights generated in the generate_ldscore step: {self.w_file}")
+            else:
+                raise ValueError(
+                    "No w_file provided and no weights found in generate_ldscore output. "
+                    "Either provide --w_file or run generate_ldscore first."
+                )
+        else:
+            logger.info(f"Using provided weights file: {self.w_file}")

         if self.use_additional_baseline_annotation:
             self.process_additional_baseline_annotation()
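Note that the fallback assigns w_file the prefix "weights." rather than a full filename. Assuming gsMap completes such prefixes per chromosome in the usual LDSC fashion (an assumption; the diff shows only the prefix, and the directory layout below is hypothetical), the resolved files would look like:

from pathlib import Path

w_file = str(Path("workdir/ldscore/w_ld") / "weights.")
per_chrom = [f"{w_file}{chrom}.l2.ldscore.gz" for chrom in range(1, 23)]
print(per_chrom[0])   # workdir/ldscore/w_ld/weights.1.l2.ldscore.gz
print(per_chrom[-1])  # workdir/ldscore/w_ld/weights.22.l2.ldscore.gz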
@@ -998,16 +1139,6 @@ class SpatialLDSCConfig(ConfigWithAutoPaths):

         if not dir_exists:
             self.use_additional_baseline_annotation = False
-            # if self.use_additional_baseline_annotation:
-            #     logger.warning(f"additional_baseline directory is not found in {self.ldscore_save_dir}.")
-            #     print('''\
-            #     if you want to use additional baseline annotation,
-            #     please provide additional baseline annotation when calculating ld score.
-            #     ''')
-            #     raise FileNotFoundError(
-            #         f'additional_baseline directory is not found.')
-            # return
-            # self.use_additional_baseline_annotation = self.use_additional_baseline_annotation or True
         else:
             logger.info(
                 "------Additional baseline annotation is provided. It will be used with the default baseline annotation."
@@ -1037,7 +1168,7 @@ class CauchyCombinationConfig(ConfigWithAutoPaths):

     def __post_init__(self):
         if self.sample_name is not None:
-            if len(self.sample_name_list) > 0:
+            if self.sample_name_list and len(self.sample_name_list) > 0:
                 raise ValueError("Only one of sample_name and sample_name_list must be provided.")
             else:
                 self.sample_name_list = [self.sample_name]
@@ -1106,6 +1237,10 @@ class RunAllModeConfig(ConfigWithAutoPaths):
     annotation: str
     data_layer: str = "X"

+    # == Find Latent Representation PARAMETERS ==
+    n_comps: int = 300
+    pearson_residuals: bool = False
+
     # == latent 2 Gene PARAMETERS ==
     gM_slices: str | None = None
     latent_representation: str = None
@@ -1124,9 +1259,7 @@ class RunAllModeConfig(ConfigWithAutoPaths):

     def __post_init__(self):
         super().__post_init__()
-        self.gtffile = (
-            f"{self.gsMap_resource_dir}/genome_annotation/gtf/gencode.v39lift37.annotation.gtf"
-        )
+        self.gtffile = f"{self.gsMap_resource_dir}/genome_annotation/gtf/gencode.v46lift37.basic.annotation.gtf"
         self.bfile_root = (
             f"{self.gsMap_resource_dir}/LD_Reference_Panel/1000G_EUR_Phase3_plink/1000G.EUR.QC"
         )
@@ -1191,7 +1324,7 @@ class FormatSumstatsConfig:
     se: str = None
     p: str = None
     frq: str = None
-    n: str = None
+    n: str | int = None
     z: str = None
     OR: str = None
     se_OR: str = None
@@ -1204,9 +1337,21 @@ class FormatSumstatsConfig:
     keep_chr_pos: bool = False


+@register_cli(
+    name="quick_mode",
+    description="Run the entire gsMap pipeline in quick mode, utilizing pre-computed weights for faster execution.",
+    add_args_function=add_run_all_mode_args,
+)
+def run_all_mode_from_cli(args: argparse.Namespace):
+    from gsMap.run_all_mode import run_pipeline
+
+    config = get_dataclass_from_parser(args, RunAllModeConfig)
+    run_pipeline(config)
+
+
 @register_cli(
     name="run_find_latent_representations",
-    description="Run Find_latent_representations \nFind the latent representations of each spot by running GNN
+    description="Run Find_latent_representations \nFind the latent representations of each spot by running GNN",
     add_args_function=add_find_latent_representations_args,
 )
 def run_find_latent_representation_from_cli(args: argparse.Namespace):
@@ -1278,7 +1423,7 @@ def run_Report_from_cli(args: argparse.Namespace):

 @register_cli(
     name="format_sumstats",
-    description="Format
+    description="Format GWAS summary statistics",
     add_args_function=add_format_sumstats_args,
 )
 def gwas_format_from_cli(args: argparse.Namespace):
@@ -1288,18 +1433,6 @@ def gwas_format_from_cli(args: argparse.Namespace):
     gwas_format(config)


-@register_cli(
-    name="quick_mode",
-    description="Run all the gsMap pipeline in quick mode",
-    add_args_function=add_run_all_mode_args,
-)
-def run_all_mode_from_cli(args: argparse.Namespace):
-    from gsMap.run_all_mode import run_pipeline
-
-    config = get_dataclass_from_parser(args, RunAllModeConfig)
-    run_pipeline(config)
-
-
 @register_cli(
     name="create_slice_mean",
     description="Create slice mean from multiple h5ad files",
gsMap/create_slice_mean.py
CHANGED
@@ -5,8 +5,9 @@ import anndata
 import numpy as np
 import pandas as pd
 import scanpy as sc
+import scipy
 import zarr
-from scipy.stats import rankdata
+from scipy.stats import gmean, rankdata
 from tqdm import tqdm

 from gsMap.config import CreateSliceMeanConfig
@@ -22,6 +23,7 @@ def get_common_genes(h5ad_files, config: CreateSliceMeanConfig):
     common_genes = None
     for file in tqdm(h5ad_files, desc="Finding common genes"):
         adata = sc.read_h5ad(file)
+        sc.pp.filter_genes(adata, min_cells=1)
         adata.var_names_make_unique()
         if common_genes is None:
             common_genes = adata.var_names
@@ -62,22 +64,27 @@ def calculate_one_slice_mean(

     adata = adata[:, common_genes].copy()
     n_cells = adata.shape[0]
-    log_ranks = np.zeros((n_cells, adata.n_vars), dtype=np.float32)
-    # Compute log of ranks to avoid overflow when computing geometric mean
-    for i in tqdm(range(n_cells), desc=f"Computing log ranks for {sample_name}"):
-        data = adata.X[i, :].toarray().flatten()
-        ranks = rankdata(data, method="average")
-        log_ranks[i, :] = np.log(ranks)  # Adding small value to avoid log(0)

-
-
+    if not scipy.sparse.issparse(adata.X):
+        adata_X = scipy.sparse.csr_matrix(adata.X)
+    elif isinstance(adata.X, scipy.sparse.csr_matrix):
+        adata_X = adata.X  # Avoid copying if already CSR
+    else:
+        adata_X = adata.X.tocsr()
+
+    ranks = np.zeros((n_cells, adata.n_vars), dtype=np.float16)
+    for i in tqdm(range(n_cells), desc="Computing ranks per cell"):
+        data = adata_X[i, :].toarray().flatten()
+        ranks[i, :] = rankdata(data, method="average")
+
+    gM = gmean(ranks, axis=0).reshape(-1, 1)

     # Calculate the expression fraction
     adata_X_bool = adata.X.astype(bool)
     frac = (np.asarray(adata_X_bool.sum(axis=0)).flatten()).reshape(-1, 1)

     # Save to zarr group
-    gmean_frac = np.concatenate([
+    gmean_frac = np.concatenate([gM, frac], axis=1)
     s1_zarr = gmean_zarr_group.array(sample_name, data=gmean_frac, chunks=None, dtype="f4")
     s1_zarr.attrs["spot_number"] = adata.shape[0]
@@ -85,34 +92,42 @@ def calculate_one_slice_mean(
 def merge_zarr_means(zarr_group_path, output_file, common_genes):
     """
     Merge all Zarr arrays into a weighted geometric mean and save to a Parquet file.
-    Instead of calculating the mean, it sums the logs and applies the exponential.
     """
     gmean_zarr_group = zarr.open(zarr_group_path, mode="a")
-
+
+    sample_gmeans = []
+    sample_weights = []
     frac_sum = None
     total_spot_number = 0
+
+    # Collect all geometric means and their weights (spot numbers)
     for key in tqdm(gmean_zarr_group.array_keys(), desc="Merging Zarr arrays"):
         s1 = gmean_zarr_group[key]
         s1_array_gmean = s1[:][:, 0]
         s1_array_frac = s1[:][:, 1]
         n = s1.attrs["spot_number"]

-
-
+        sample_gmeans.append(s1_array_gmean)
+        sample_weights.append(n)
+
+        if frac_sum is None:
            frac_sum = s1_array_frac
        else:
-            log_sum += np.log(s1_array_gmean) * n
            frac_sum += s1_array_frac

        total_spot_number += n

-    #
-
+    # Convert to arrays
+    sample_gmeans = np.array(sample_gmeans)
+    sample_weights = np.array(sample_weights)
+
+    final_gmean = gmean(sample_gmeans, axis=0, weights=sample_weights[:, np.newaxis])
+
     final_frac = frac_sum / total_spot_number

     # Save the final mean to a Parquet file
     gene_names = common_genes
-    final_df = pd.DataFrame({"gene": gene_names, "G_Mean":
+    final_df = pd.DataFrame({"gene": gene_names, "G_Mean": final_gmean, "frac": final_frac})
     final_df.set_index("gene", inplace=True)
     final_df.to_parquet(output_file)
     return final_df
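The refactor replaces the old manual sum-of-logs accumulation with scipy's gmean, weighting each slice by its spot count. The two are equivalent because the log of a geometric mean is an arithmetic mean of logs, so a spot-count-weighted gmean of per-slice gmeans reproduces the pooled gmean over all spots. A quick numerical check of that identity:

import numpy as np
from scipy.stats import gmean

rng = np.random.default_rng(0)
slice_a = rng.integers(1, 100, size=(30, 5)).astype(float)  # 30 spots, 5 genes
slice_b = rng.integers(1, 100, size=(70, 5)).astype(float)  # 70 spots, 5 genes

pooled = gmean(np.vstack([slice_a, slice_b]), axis=0)
per_slice = np.array([gmean(slice_a, axis=0), gmean(slice_b, axis=0)])
weights = np.array([30, 70])
merged = gmean(per_slice, axis=0, weights=weights[:, np.newaxis])

print(np.allclose(pooled, merged))  # True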
gsMap/diagnosis.py
CHANGED
@@ -49,7 +49,10 @@ def compute_gene_diagnostic_info(config: DiagnosisConfig):

     # Align marker scores with trait LDSC results
     mk_score = mk_score.loc[trait_ldsc_result.index]
-
+
+    # Filter out genes with no variation
+    non_zero_std_cols = mk_score.columns[mk_score.std() > 0]
+    mk_score = mk_score.loc[:, non_zero_std_cols]

     logger.info("Calculating correlation between gene marker scores and trait logp-values...")
     corr = mk_score.corrwith(trait_ldsc_result["logp"])
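The new filter guards the corrwith call: a gene whose marker score is constant across spots has zero standard deviation, so its Pearson correlation with the trait p-values is undefined and pandas returns NaN. A small demonstration:

import pandas as pd

mk = pd.DataFrame({"geneA": [1.0, 2.0, 3.0], "geneB": [5.0, 5.0, 5.0]})
logp = pd.Series([0.1, 0.2, 0.3])
print(mk.corrwith(logp))
# geneA    1.0
# geneB    NaN   <- zero-variance column; the std() > 0 filter removes it first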
@@ -88,19 +91,6 @@ def compute_gene_diagnostic_info(config: DiagnosisConfig):
     gene_diagnostic_info.to_csv(gene_diagnostic_info_save_path, index=False)
     logger.info(f"Gene diagnostic information saved to {gene_diagnostic_info_save_path}.")

-    # TODO: A new script is needed to save the gene diagnostic info to adata.var and trait_ldsc_result to adata.obs when running multiple traits
-    # # Save to adata.var with the trait_name prefix
-    # logger.info('Saving gene diagnostic info to adata.var...')
-    # gene_diagnostic_info.set_index('Gene', inplace=True)  # Use 'Gene' as the index to align with adata.var
-    # adata.var[f'{config.trait_name}_Annotation'] = gene_diagnostic_info['Annotation']
-    # adata.var[f'{config.trait_name}_Median_GSS'] = gene_diagnostic_info['Median_GSS']
-    # adata.var[f'{config.trait_name}_PCC'] = gene_diagnostic_info['PCC']
-    #
-    # # Save trait_ldsc_result to adata.obs
-    # logger.info(f'Saving trait LDSC results to adata.obs as gsMap_{config.trait_name}_p_value...')
-    # adata.obs[f'gsMap_{config.trait_name}_p_value'] = trait_ldsc_result['p']
-    # adata.write(config.hdf5_with_latent_path, )
-
     return gene_diagnostic_info.reset_index()

gsMap/find_latent_representation.py
CHANGED
@@ -38,7 +38,7 @@ def preprocess_data(adata, params):

     if params.data_layer in adata.layers.keys():
         logger.info(f"Using data layer: {params.data_layer}...")
-        adata.X = adata.layers[params.data_layer]
+        adata.X = adata.layers[params.data_layer].copy()
     elif params.data_layer == "X":
         logger.info(f"Using data layer: {params.data_layer}...")
         if adata.X.dtype == "float32" or adata.X.dtype == "float64":
@@ -50,6 +50,15 @@ def preprocess_data(adata, params):
     # HVGs based on count
     logger.info("Dealing with count data...")
     sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=params.feat_cell)
+
+    # Get the pearson residuals
+    if params.pearson_residuals:
+        sc.experimental.pp.normalize_pearson_residuals(adata, inplace=False)
+        pearson_residuals = sc.experimental.pp.normalize_pearson_residuals(
+            adata, inplace=False, clip=10
+        )
+        adata.layers["pearson_residuals"] = pearson_residuals["X"]
+
     # Normalize the data
     sc.pp.normalize_total(adata, target_sum=1e4)
     sc.pp.log1p(adata)
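For context, scanpy's experimental normalize_pearson_residuals computes analytic Pearson residuals (Lause et al., 2021): each count is standardized against a negative-binomial null whose mean is (spot total × gene total) / grand total, with overdispersion theta (scanpy defaults to 100) and clipping. A minimal dense-matrix sketch of the formula only, not of scanpy's implementation:

import numpy as np

def pearson_residuals_sketch(counts, theta=100.0, clip=10.0):
    counts = np.asarray(counts, dtype=float)
    total = counts.sum()
    # Expected counts if spot depth and gene abundance were independent
    mu = counts.sum(axis=1, keepdims=True) * counts.sum(axis=0, keepdims=True) / total
    residuals = (counts - mu) / np.sqrt(mu + mu**2 / theta)
    return np.clip(residuals, -clip, clip)

X = np.random.default_rng(1).poisson(2.0, size=(5, 4))
print(pearson_residuals_sketch(X).round(2))

An all-zero gene would make mu = 0 and divide by zero here, which is presumably why this release also adds sc.pp.filter_genes(adata, min_cells=1) before preprocessing.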
@@ -64,8 +73,13 @@ class LatentRepresentationFinder:
     def __init__(self, adata, args: FindLatentRepresentationsConfig):
         self.params = args

-
-
+        if "pearson_residuals" in adata.layers:
+            self.expression_array = (
+                adata[:, adata.var.highly_variable].layers["pearson_residuals"].copy()
+            )
+        else:
+            self.expression_array = adata[:, adata.var.highly_variable].X.copy()
+            self.expression_array = sc.pp.scale(self.expression_array, max_value=10)

         # Construct the neighboring graph
         self.graph_dict = construct_adjacency_matrix(adata, self.params)
@@ -103,6 +117,8 @@ def run_find_latent_representation(args: FindLatentRepresentationsConfig):
     # Load the ST data
     logger.info(f"Loading ST data of {args.sample_name}...")
     adata = sc.read_h5ad(args.input_hdf5_path)
+    sc.pp.filter_genes(adata, min_cells=1)
+
     logger.info(f"The ST data contains {adata.shape[0]} cells, {adata.shape[1]} genes.")

     # Load the cell type annotation
gsMap/format_sumstats.py
CHANGED
@@ -409,6 +409,12 @@ def gwas_format(config: FormatSumstatsConfig):
         compression=compression_type,
         na_values=[".", "NA"],
     )
+
+    if isinstance(config.n, int | float):
+        logger.info(f"Set the sample size of gwas data as {config.n}.")
+        gwas["N"] = config.n
+        config.n = "N"
+
     logger.info(f"Read {len(gwas)} SNPs from {config.sumstats}.")

     # Check name and format
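Together with the str_or_float argument type in config.py, this lets --n carry either a column name or a literal sample size. A toy sketch of just the broadcasting step (not gsMap's full formatting pipeline):

import pandas as pd

gwas = pd.DataFrame({"SNP": ["rs1", "rs2"], "Z": [5.1, -0.3]})
n = 350000  # what str_or_float yields from `--n 350000`

if isinstance(n, (int, float)):
    gwas["N"] = n  # broadcast one sample size to every SNP
    n = "N"        # downstream code then reads N like any other column

print(gwas)
#    SNP    Z       N
# 0  rs1  5.1  350000
# 1  rs2 -0.3  350000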