gsMap 1.72.3__py3-none-any.whl → 1.73.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gsMap/GNN/train.py CHANGED
@@ -17,7 +17,7 @@ def reconstruction_loss(decoded, x):
 
 def label_loss(pred_label, true_label):
     """Compute the cross-entropy loss."""
-    return F.cross_entropy(pred_label, true_label)
+    return F.cross_entropy(pred_label, true_label.long())
 
 
 class ModelTrainer:
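
The `.long()` cast is the substance of this fix: `F.cross_entropy` expects integer (LongTensor) class indices as targets, and float targets either trip a dtype error or, when their shape matches the logits (PyTorch ≥ 1.10), are silently reinterpreted as class probabilities. A minimal sketch with hypothetical tensors, not from the package:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 3)                    # 4 spots, 3 classes
    labels = torch.tensor([0.0, 2.0, 1.0, 2.0])   # float labels, e.g. converted from numpy
    loss = F.cross_entropy(logits, labels.long()) # cast to class indices before the loss
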
gsMap/__init__.py CHANGED
@@ -2,4 +2,4 @@
 Genetics-informed pathogenic spatial mapping
 """
 
-__version__ = "1.72.3"
+__version__ = "1.73.1"
gsMap/cauchy_combination_test.py CHANGED
@@ -48,16 +48,16 @@ def acat_test(pvalues, weights=None):
     elif any(i < 0 for i in weights):
         raise Exception("All weights must be positive.")
     else:
-        weights = [i / len(weights) for i in weights]
+        weights = [i / np.sum(weights) for i in weights]
 
     pvalues = np.array(pvalues)
     weights = np.array(weights)
 
-    if not any(i < 1e-16 for i in pvalues):
+    if not any(i < 1e-15 for i in pvalues):
         cct_stat = sum(weights * np.tan((0.5 - pvalues) * np.pi))
     else:
-        is_small = [i < (1e-16) for i in pvalues]
-        is_large = [i >= (1e-16) for i in pvalues]
+        is_small = [i < (1e-15) for i in pvalues]
+        is_large = [i >= (1e-15) for i in pvalues]
         cct_stat = sum((weights[is_small] / pvalues[is_small]) / np.pi)
         cct_stat += sum(weights[is_large] * np.tan((0.5 - pvalues[is_large]) * np.pi))
 
@@ -118,7 +118,7 @@ def run_Cauchy_combination(config: CauchyCombinationConfig):
     n_removed = len(p_values) - len(p_values_filtered)
 
     # Remove outliers if the number is reasonable
-    if 0 < n_removed < 20:
+    if 0 < n_removed < max(len(p_values) * 0.01, 20):
         logger.info(f"Removed {n_removed}/{len(p_values)} outliers (median + 3IQR) for {ct}.")
         p_cauchy_temp = acat_test(p_values_filtered)
     else:
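
For context, `acat_test` implements the ACAT/Cauchy combination: each p-value is mapped through tan((0.5 − p)·π), averaged with the weights, and the statistic is referred back to a standard Cauchy distribution. A minimal sketch of the no-tiny-p branch, with scipy's `cauchy.sf` standing in for the package's own conversion step (which is outside the hunks shown):

    import numpy as np
    from scipy.stats import cauchy

    pvalues = np.array([0.01, 0.20, 0.35])
    weights = np.ones(3) / 3                          # normalized to sum to 1, as in the fix above
    cct_stat = np.sum(weights * np.tan((0.5 - pvalues) * np.pi))
    p_combined = cauchy.sf(cct_stat)                  # survival function of the standard Cauchy

The `np.sum(weights)` fix matters whenever user-supplied weights do not already sum to 1; dividing by `len(weights)` only normalizes the uniform case.
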
gsMap/config.py CHANGED
@@ -1,7 +1,10 @@
 import argparse
 import dataclasses
 import logging
+import os
 import sys
+import threading
+import time
 from collections import OrderedDict, namedtuple
 from collections.abc import Callable
 from dataclasses import dataclass
@@ -10,6 +13,7 @@ from pathlib import Path
 from pprint import pprint
 from typing import Literal
 
+import psutil
 import pyfiglet
 import yaml
 
@@ -34,9 +38,109 @@ def get_gsMap_logger(logger_name):
 logger = get_gsMap_logger("gsMap")
 
 
+def track_resource_usage(func):
+    """
+    Decorator to track resource usage during function execution.
+    Logs memory usage, CPU time, and wall clock time at the end of the function.
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # Get the current process
+        process = psutil.Process(os.getpid())
+
+        # Initialize tracking variables
+        peak_memory = 0
+        cpu_percent_samples = []
+        stop_thread = False
+
+        # Function to monitor resource usage
+        def resource_monitor():
+            nonlocal peak_memory, cpu_percent_samples
+            while not stop_thread:
+                try:
+                    # Get current memory usage in MB
+                    current_memory = process.memory_info().rss / (1024 * 1024)
+                    peak_memory = max(peak_memory, current_memory)
+
+                    # Get CPU usage percentage
+                    cpu_percent = process.cpu_percent(interval=None)
+                    if cpu_percent > 0:  # Skip initial zero readings
+                        cpu_percent_samples.append(cpu_percent)
+
+                    time.sleep(0.5)
+                except Exception:  # Catching all exceptions here because...  # noqa: BLE001
+                    pass
+
+        # Start resource monitoring in a separate thread
+        monitor_thread = threading.Thread(target=resource_monitor)
+        monitor_thread.daemon = True
+        monitor_thread.start()
+
+        # Get start times
+        start_wall_time = time.time()
+        start_cpu_time = process.cpu_times().user + process.cpu_times().system
+
+        try:
+            # Run the actual function
+            result = func(*args, **kwargs)
+            return result
+        finally:
+            # Stop the monitoring thread
+            stop_thread = True
+            monitor_thread.join(timeout=1.0)
+
+            # Calculate elapsed times
+            end_wall_time = time.time()
+            end_cpu_time = process.cpu_times().user + process.cpu_times().system
+
+            wall_time = end_wall_time - start_wall_time
+            cpu_time = end_cpu_time - start_cpu_time
+
+            # Calculate average CPU percentage
+            avg_cpu_percent = (
+                sum(cpu_percent_samples) / len(cpu_percent_samples) if cpu_percent_samples else 0
+            )
+
+            # Format memory for display
+            if peak_memory < 1024:
+                memory_str = f"{peak_memory:.2f} MB"
+            else:
+                memory_str = f"{peak_memory / 1024:.2f} GB"
+
+            # Format times for display
+            if wall_time < 60:
+                wall_time_str = f"{wall_time:.2f} seconds"
+            elif wall_time < 3600:
+                wall_time_str = f"{wall_time / 60:.2f} minutes"
+            else:
+                wall_time_str = f"{wall_time / 3600:.2f} hours"
+
+            if cpu_time < 60:
+                cpu_time_str = f"{cpu_time:.2f} seconds"
+            elif cpu_time < 3600:
+                cpu_time_str = f"{cpu_time / 60:.2f} minutes"
+            else:
+                cpu_time_str = f"{cpu_time / 3600:.2f} hours"
+
+            # Log the resource usage
+            import logging
+
+            logger = logging.getLogger("gsMap")
+            logger.info("Resource usage summary:")
+            logger.info(f"  • Wall clock time: {wall_time_str}")
+            logger.info(f"  • CPU time: {cpu_time_str}")
+            logger.info(f"  • Average CPU utilization: {avg_cpu_percent:.1f}%")
+            logger.info(f"  • Peak memory usage: {memory_str}")
+
+    return wrapper
+
+
 # Decorator to register functions for cli parsing
 def register_cli(name: str, description: str, add_args_function: Callable) -> Callable:
     def decorator(func: Callable) -> Callable:
+        @track_resource_usage  # Use enhanced resource tracking
+        @wraps(func)
         def wrapper(*args, **kwargs):
             name.replace("_", " ")
             gsMap_main_logo = pyfiglet.figlet_format(
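
Because the decorator is generic, its effect is easiest to see on a toy function; a sketch with a hypothetical function, not from the package. On return, it emits the "Resource usage summary:" lines defined above through the "gsMap" logger:

    @track_resource_usage
    def dummy_workload():
        return sum(i * i for i in range(10_000_000))  # CPU-bound work for the monitor to sample

    dummy_workload()

Note the monitor thread samples RSS and CPU only every 0.5 s, so peak-memory figures for very short functions are approximate.
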
@@ -50,8 +154,16 @@ def register_cli(name: str, description: str, add_args_function: Callable) -> Callable:
             print(version_number.center(80), flush=True)
             print("=" * 80, flush=True)
             logger.info(f"Running {name}...")
+
+            # Record start time for the log message
+            start_time = time.strftime("%Y-%m-%d %H:%M:%S")
+            logger.info(f"Started at: {start_time}")
+
             func(*args, **kwargs)
-            logger.info(f"Finished running {name}.")
+
+            # Record end time for the log message
+            end_time = time.strftime("%Y-%m-%d %H:%M:%S")
+            logger.info(f"Finished running {name} at: {end_time}.")
 
         cli_function_registry[name] = subcommand(
             name=name, func=wrapper, add_args_function=add_args_function, description=description
@@ -61,6 +173,13 @@ def register_cli(name: str, description: str, add_args_function: Callable) -> Callable:
     return decorator
 
 
+def str_or_float(value):
+    try:
+        return int(value)
+    except ValueError:
+        return value
+
+
 def add_shared_args(parser):
     parser.add_argument(
         "--workdir", type=str, required=True, help="Path to the working directory."
@@ -113,6 +232,9 @@ def add_find_latent_representations_args(parser):
         action="store_true",
         help="Enable hierarchical latent representation finding.",
     )
+    parser.add_argument(
+        "--pearson_residuals", action="store_true", help="Using the pearson residuals."
+    )
 
 
 def chrom_choice(value):
@@ -189,7 +311,7 @@ def add_generate_ldscore_args(parser):
         help="Root path for genotype plink bfiles (.bim, .bed, .fam).",
     )
     parser.add_argument(
-        "--keep_snp_root", type=str, required=True, help="Root path for SNP files."
+        "--keep_snp_root", type=str, required=False, help="Root path for SNP files"
     )
     parser.add_argument(
         "--gtf_annotation_file", type=str, required=True, help="Path to GTF annotation file."
@@ -238,7 +360,11 @@ def add_spatial_ldsc_args(parser):
         "--sumstats_file", type=str, required=True, help="Path to GWAS summary statistics file."
     )
     parser.add_argument(
-        "--w_file", type=str, required=True, help="Path to regression weight file."
+        "--w_file",
+        type=str,
+        required=False,
+        default=None,
+        help="Path to regression weight file. If not provided, will use weights generated in the generate_ldscore step.",
     )
     parser.add_argument(
         "--trait_name", type=str, required=True, help="Name of the trait being analyzed."
@@ -429,7 +555,7 @@ def add_format_sumstats_args(parser):
     parser.add_argument(
         "--n",
         default=None,
-        type=str,
+        type=str_or_float,
        help="Name of sample size column (if not a name that gsMap understands)",
    )
    parser.add_argument(
@@ -559,6 +685,9 @@ def add_run_all_mode_args(parser):
     parser.add_argument(
         "--gM_slices", type=str, default=None, help="Path to the slice mean file (optional)."
     )
+    parser.add_argument(
+        "--pearson_residuals", action="store_true", help="Using the pearson residuals."
+    )
 
 
 def ensure_path_exists(func):
@@ -735,6 +864,7 @@ class FindLatentRepresentationsConfig(ConfigWithAutoPaths):
     var: bool = False
     convergence_threshold: float = 1e-4
     hierarchically: bool = False
+    pearson_residuals: bool = False
 
     def __post_init__(self):
         # self.output_hdf5_path = self.hdf5_with_latent_path
@@ -823,11 +953,11 @@ class GenerateLDScoreConfig(ConfigWithAutoPaths):
     chrom: int | str
 
     bfile_root: str
-    keep_snp_root: str | None
 
     # annotation by gene distance
     gtf_annotation_file: str
     gene_window_size: int = 50000
+    keep_snp_root: str | None = None
 
     # annotation by enhancer
     enhancer_annotation_file: str = None
@@ -936,7 +1066,7 @@ class GenerateLDScoreConfig(ConfigWithAutoPaths):
 
 @dataclass
 class SpatialLDSCConfig(ConfigWithAutoPaths):
-    w_file: str
+    w_file: str | None = None
     # ldscore_save_dir: str
     use_additional_baseline_annotation: bool = True
     trait_name: str | None = None
@@ -986,8 +1116,19 @@ class SpatialLDSCConfig(ConfigWithAutoPaths):
         for sumstats_file in self.sumstats_config_dict.values():
             assert Path(sumstats_file).exists(), f"{sumstats_file} does not exist."
 
-        # check if additional baseline annotation is exist
-        # self.use_additional_baseline_annotation = False
+        # Handle w_file
+        if self.w_file is None:
+            w_ld_dir = Path(self.ldscore_save_dir) / "w_ld"
+            if w_ld_dir.exists():
+                self.w_file = str(w_ld_dir / "weights.")
+                logger.info(f"Using weights generated in the generate_ldscore step: {self.w_file}")
+            else:
+                raise ValueError(
+                    "No w_file provided and no weights found in generate_ldscore output. "
+                    "Either provide --w_file or run generate_ldscore first."
+                )
+        else:
+            logger.info(f"Using provided weights file: {self.w_file}")
 
         if self.use_additional_baseline_annotation:
             self.process_additional_baseline_annotation()
@@ -998,16 +1139,6 @@ class SpatialLDSCConfig(ConfigWithAutoPaths):
 
         if not dir_exists:
             self.use_additional_baseline_annotation = False
-            # if self.use_additional_baseline_annotation:
-            #     logger.warning(f"additional_baseline directory is not found in {self.ldscore_save_dir}.")
-            #     print('''\
-            #         if you want to use additional baseline annotation,
-            #         please provide additional baseline annotation when calculating ld score.
-            #         ''')
-            #     raise FileNotFoundError(
-            #         f'additional_baseline directory is not found.')
-            #     return
-            # self.use_additional_baseline_annotation = self.use_additional_baseline_annotation or True
         else:
             logger.info(
                 "------Additional baseline annotation is provided. It will be used with the default baseline annotation."
@@ -1037,7 +1168,7 @@ class CauchyCombinationConfig(ConfigWithAutoPaths):
 
     def __post_init__(self):
         if self.sample_name is not None:
-            if len(self.sample_name_list) > 0:
+            if self.sample_name_list and len(self.sample_name_list) > 0:
                 raise ValueError("Only one of sample_name and sample_name_list must be provided.")
             else:
                 self.sample_name_list = [self.sample_name]
@@ -1106,6 +1237,10 @@ class RunAllModeConfig(ConfigWithAutoPaths):
     annotation: str
     data_layer: str = "X"
 
+    # == Find Latent Representation PARAMETERS ==
+    n_comps: int = 300
+    pearson_residuals: bool = False
+
     # == latent 2 Gene PARAMETERS ==
     gM_slices: str | None = None
     latent_representation: str = None
@@ -1124,9 +1259,7 @@
 
     def __post_init__(self):
         super().__post_init__()
-        self.gtffile = (
-            f"{self.gsMap_resource_dir}/genome_annotation/gtf/gencode.v39lift37.annotation.gtf"
-        )
+        self.gtffile = f"{self.gsMap_resource_dir}/genome_annotation/gtf/gencode.v46lift37.basic.annotation.gtf"
         self.bfile_root = (
             f"{self.gsMap_resource_dir}/LD_Reference_Panel/1000G_EUR_Phase3_plink/1000G.EUR.QC"
         )
@@ -1191,7 +1324,7 @@ class FormatSumstatsConfig:
     se: str = None
     p: str = None
     frq: str = None
-    n: str = None
+    n: str | int = None
     z: str = None
     OR: str = None
     se_OR: str = None
@@ -1204,9 +1337,21 @@
     keep_chr_pos: bool = False
 
 
+@register_cli(
+    name="quick_mode",
+    description="Run the entire gsMap pipeline in quick mode, utilizing pre-computed weights for faster execution.",
+    add_args_function=add_run_all_mode_args,
+)
+def run_all_mode_from_cli(args: argparse.Namespace):
+    from gsMap.run_all_mode import run_pipeline
+
+    config = get_dataclass_from_parser(args, RunAllModeConfig)
+    run_pipeline(config)
+
+
 @register_cli(
     name="run_find_latent_representations",
-    description="Run Find_latent_representations \nFind the latent representations of each spot by running GNN-VAE",
+    description="Run Find_latent_representations \nFind the latent representations of each spot by running GNN",
     add_args_function=add_find_latent_representations_args,
 )
 def run_find_latent_representation_from_cli(args: argparse.Namespace):
@@ -1278,7 +1423,7 @@ def run_Report_from_cli(args: argparse.Namespace):
 
 @register_cli(
     name="format_sumstats",
-    description="Format gwas summary statistics",
+    description="Format GWAS summary statistics",
     add_args_function=add_format_sumstats_args,
 )
 def gwas_format_from_cli(args: argparse.Namespace):
@@ -1288,18 +1433,6 @@ def gwas_format_from_cli(args: argparse.Namespace):
     gwas_format(config)
 
 
-@register_cli(
-    name="quick_mode",
-    description="Run all the gsMap pipeline in quick mode",
-    add_args_function=add_run_all_mode_args,
-)
-def run_all_mode_from_cli(args: argparse.Namespace):
-    from gsMap.run_all_mode import run_pipeline
-
-    config = get_dataclass_from_parser(args, RunAllModeConfig)
-    run_pipeline(config)
-
-
 @register_cli(
     name="create_slice_mean",
     description="Create slice mean from multiple h5ad files",
gsMap/create_slice_mean.py CHANGED
@@ -5,8 +5,9 @@ import anndata
 import numpy as np
 import pandas as pd
 import scanpy as sc
+import scipy
 import zarr
-from scipy.stats import rankdata
+from scipy.stats import gmean, rankdata
 from tqdm import tqdm
 
 from gsMap.config import CreateSliceMeanConfig
@@ -22,6 +23,7 @@ def get_common_genes(h5ad_files, config: CreateSliceMeanConfig):
     common_genes = None
     for file in tqdm(h5ad_files, desc="Finding common genes"):
         adata = sc.read_h5ad(file)
+        sc.pp.filter_genes(adata, min_cells=1)
         adata.var_names_make_unique()
         if common_genes is None:
             common_genes = adata.var_names
@@ -62,22 +64,27 @@ def calculate_one_slice_mean(
 
     adata = adata[:, common_genes].copy()
     n_cells = adata.shape[0]
-    log_ranks = np.zeros((n_cells, adata.n_vars), dtype=np.float32)
-    # Compute log of ranks to avoid overflow when computing geometric mean
-    for i in tqdm(range(n_cells), desc=f"Computing log ranks for {sample_name}"):
-        data = adata.X[i, :].toarray().flatten()
-        ranks = rankdata(data, method="average")
-        log_ranks[i, :] = np.log(ranks)  # Adding small value to avoid log(0)
 
-    # Calculate geometric mean via log trick: exp(mean(log(values)))
-    gmean = (np.exp(np.mean(log_ranks, axis=0))).reshape(-1, 1)
+    if not scipy.sparse.issparse(adata.X):
+        adata_X = scipy.sparse.csr_matrix(adata.X)
+    elif isinstance(adata.X, scipy.sparse.csr_matrix):
+        adata_X = adata.X  # Avoid copying if already CSR
+    else:
+        adata_X = adata.X.tocsr()
+
+    ranks = np.zeros((n_cells, adata.n_vars), dtype=np.float16)
+    for i in tqdm(range(n_cells), desc="Computing ranks per cell"):
+        data = adata_X[i, :].toarray().flatten()
+        ranks[i, :] = rankdata(data, method="average")
+
+    gM = gmean(ranks, axis=0).reshape(-1, 1)
 
     # Calculate the expression fractio
     adata_X_bool = adata.X.astype(bool)
     frac = (np.asarray(adata_X_bool.sum(axis=0)).flatten()).reshape(-1, 1)
 
     # Save to zarr group
-    gmean_frac = np.concatenate([gmean, frac], axis=1)
+    gmean_frac = np.concatenate([gM, frac], axis=1)
     s1_zarr = gmean_zarr_group.array(sample_name, data=gmean_frac, chunks=None, dtype="f4")
     s1_zarr.attrs["spot_number"] = adata.shape[0]
 
@@ -85,34 +92,42 @@
 def merge_zarr_means(zarr_group_path, output_file, common_genes):
     """
     Merge all Zarr arrays into a weighted geometric mean and save to a Parquet file.
-    Instead of calculating the mean, it sums the logs and applies the exponential.
     """
     gmean_zarr_group = zarr.open(zarr_group_path, mode="a")
-    log_sum = None
+
+    sample_gmeans = []
+    sample_weights = []
     frac_sum = None
     total_spot_number = 0
+
+    # Collect all geometric means and their weights (spot numbers)
     for key in tqdm(gmean_zarr_group.array_keys(), desc="Merging Zarr arrays"):
         s1 = gmean_zarr_group[key]
         s1_array_gmean = s1[:][:, 0]
         s1_array_frac = s1[:][:, 1]
         n = s1.attrs["spot_number"]
 
-        if log_sum is None:
-            log_sum = np.log(s1_array_gmean) * n
+        sample_gmeans.append(s1_array_gmean)
+        sample_weights.append(n)
+
+        if frac_sum is None:
            frac_sum = s1_array_frac
        else:
-            log_sum += np.log(s1_array_gmean) * n
            frac_sum += s1_array_frac
 
        total_spot_number += n
 
-    # Apply the geometric mean via exponentiation of the averaged logs
-    final_mean = np.exp(log_sum / total_spot_number)
+    # Convert to arrays
+    sample_gmeans = np.array(sample_gmeans)
+    sample_weights = np.array(sample_weights)
+
+    final_gmean = gmean(sample_gmeans, axis=0, weights=sample_weights[:, np.newaxis])
+
     final_frac = frac_sum / total_spot_number
 
     # Save the final mean to a Parquet file
    gene_names = common_genes
-    final_df = pd.DataFrame({"gene": gene_names, "G_Mean": final_mean, "frac": final_frac})
+    final_df = pd.DataFrame({"gene": gene_names, "G_Mean": final_gmean, "frac": final_frac})
     final_df.set_index("gene", inplace=True)
     final_df.to_parquet(output_file)
     return final_df
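
The rewrite replaces the manual log-sum accumulation with `scipy.stats.gmean`, whose `weights` argument reproduces the same spot-count-weighted geometric mean. A small sketch with made-up numbers showing the equivalence:

    import numpy as np
    from scipy.stats import gmean

    slice_gmeans = np.array([[2.0, 4.0, 8.0],   # per-slice geometric means for 3 genes
                             [4.0, 4.0, 2.0]])
    spot_counts = np.array([100, 300])          # weights = spot numbers per slice

    combined = gmean(slice_gmeans, axis=0, weights=spot_counts[:, np.newaxis])
    # identical to the old log trick: exp(sum_i n_i * log(g_i) / sum_i n_i)
    expected = np.exp((spot_counts[:, None] * np.log(slice_gmeans)).sum(axis=0) / spot_counts.sum())
    assert np.allclose(combined, expected)
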
gsMap/diagnosis.py CHANGED
@@ -49,7 +49,10 @@ def compute_gene_diagnostic_info(config: DiagnosisConfig):
 
     # Align marker scores with trait LDSC results
     mk_score = mk_score.loc[trait_ldsc_result.index]
-    mk_score = mk_score.loc[:, mk_score.sum(axis=0) != 0]
+
+    # Filter out genes with no variation
+    non_zero_std_cols = mk_score.columns[mk_score.std() > 0]
+    mk_score = mk_score.loc[:, non_zero_std_cols]
 
     logger.info("Calculating correlation between gene marker scores and trait logp-values...")
     corr = mk_score.corrwith(trait_ldsc_result["logp"])
@@ -88,19 +91,6 @@ def compute_gene_diagnostic_info(config: DiagnosisConfig):
     gene_diagnostic_info.to_csv(gene_diagnostic_info_save_path, index=False)
     logger.info(f"Gene diagnostic information saved to {gene_diagnostic_info_save_path}.")
 
-    # TODO: A new script is needed to save the gene diagnostic info to adata.var and trait_ldsc_result to adata.obs when running multiple traits
-    # # Save to adata.var with the trait_name prefix
-    # logger.info('Saving gene diagnostic info to adata.var...')
-    # gene_diagnostic_info.set_index('Gene', inplace=True)  # Use 'Gene' as the index to align with adata.var
-    # adata.var[f'{config.trait_name}_Annotation'] = gene_diagnostic_info['Annotation']
-    # adata.var[f'{config.trait_name}_Median_GSS'] = gene_diagnostic_info['Median_GSS']
-    # adata.var[f'{config.trait_name}_PCC'] = gene_diagnostic_info['PCC']
-    #
-    # # Save trait_ldsc_result to adata.obs
-    # logger.info(f'Saving trait LDSC results to adata.obs as gsMap_{config.trait_name}_p_value...')
-    # adata.obs[f'gsMap_{config.trait_name}_p_value'] = trait_ldsc_result['p']
-    # adata.write(config.hdf5_with_latent_path, )
-
     return gene_diagnostic_info.reset_index()
 
 
gsMap/find_latent_representation.py CHANGED
@@ -38,7 +38,7 @@ def preprocess_data(adata, params):
 
     if params.data_layer in adata.layers.keys():
         logger.info(f"Using data layer: {params.data_layer}...")
-        adata.X = adata.layers[params.data_layer]
+        adata.X = adata.layers[params.data_layer].copy()
     elif params.data_layer == "X":
         logger.info(f"Using data layer: {params.data_layer}...")
         if adata.X.dtype == "float32" or adata.X.dtype == "float64":
@@ -50,6 +50,15 @@ def preprocess_data(adata, params):
         # HVGs based on count
         logger.info("Dealing with count data...")
         sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=params.feat_cell)
+
+        # Get the pearson residuals
+        if params.pearson_residuals:
+            sc.experimental.pp.normalize_pearson_residuals(adata, inplace=False)
+            pearson_residuals = sc.experimental.pp.normalize_pearson_residuals(
+                adata, inplace=False, clip=10
+            )
+            adata.layers["pearson_residuals"] = pearson_residuals["X"]
+
     # Normalize the data
     sc.pp.normalize_total(adata, target_sum=1e4)
     sc.pp.log1p(adata)
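
For reference, `sc.experimental.pp.normalize_pearson_residuals` with `inplace=False` returns a dict whose `"X"` entry holds the residual matrix; that is what lands in the new layer (only the clipped second call's result is kept above). A self-contained sketch with toy counts, assuming scanpy ≥ 1.9:

    import anndata
    import numpy as np
    import scanpy as sc

    counts = np.random.default_rng(0).poisson(2.0, size=(50, 20)).astype(np.float32)
    adata = anndata.AnnData(counts)
    sc.pp.filter_genes(adata, min_cells=1)  # residuals need genes observed at least once

    res = sc.experimental.pp.normalize_pearson_residuals(adata, inplace=False, clip=10)
    adata.layers["pearson_residuals"] = res["X"]  # same shape as adata.X
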
@@ -64,8 +73,13 @@ class LatentRepresentationFinder:
     def __init__(self, adata, args: FindLatentRepresentationsConfig):
         self.params = args
 
-        self.expression_array = adata[:, adata.var.highly_variable].X.copy()
-        self.expression_array = sc.pp.scale(self.expression_array, max_value=10)
+        if "pearson_residuals" in adata.layers:
+            self.expression_array = (
+                adata[:, adata.var.highly_variable].layers["pearson_residuals"].copy()
+            )
+        else:
+            self.expression_array = adata[:, adata.var.highly_variable].X.copy()
+            self.expression_array = sc.pp.scale(self.expression_array, max_value=10)
 
         # Construct the neighboring graph
         self.graph_dict = construct_adjacency_matrix(adata, self.params)
@@ -103,6 +117,8 @@ def run_find_latent_representation(args: FindLatentRepresentationsConfig):
     # Load the ST data
     logger.info(f"Loading ST data of {args.sample_name}...")
     adata = sc.read_h5ad(args.input_hdf5_path)
+    sc.pp.filter_genes(adata, min_cells=1)
+
     logger.info(f"The ST data contains {adata.shape[0]} cells, {adata.shape[1]} genes.")
 
     # Load the cell type annotation
gsMap/format_sumstats.py CHANGED
@@ -409,6 +409,12 @@ def gwas_format(config: FormatSumstatsConfig):
         compression=compression_type,
         na_values=[".", "NA"],
     )
+
+    if isinstance(config.n, int | float):
+        logger.info(f"Set the sample size of gwas data as {config.n}.")
+        gwas["N"] = config.n
+        config.n = "N"
+
     logger.info(f"Read {len(gwas)} SNPs from {config.sumstats}.")
 
     # Check name and format
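
This block lets `--n` carry a literal sample size: after `str_or_float` parsing, a numeric `config.n` fills a constant `N` column and `config.n` is rewired to that column name, so the downstream column checks are unchanged. A toy sketch with a hypothetical frame, not package code (the `int | float` union check requires Python ≥ 3.10):

    import pandas as pd

    gwas = pd.DataFrame({"SNP": ["rs1", "rs2"], "P": [0.01, 0.5]})
    n = 100000                       # what str_or_float("100000") yields
    if isinstance(n, int | float):
        gwas["N"] = n                # constant sample size for every SNP
        n = "N"                      # downstream code now reads the N column by name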