gsMap 1.72.3__py3-none-any.whl → 1.73.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gsMap/__init__.py CHANGED
@@ -2,4 +2,4 @@
 Genetics-informed pathogenic spatial mapping
 """
 
-__version__ = "1.72.3"
+__version__ = "1.73.0"
gsMap/cauchy_combination_test.py CHANGED
@@ -48,16 +48,16 @@ def acat_test(pvalues, weights=None):
     elif any(i < 0 for i in weights):
         raise Exception("All weights must be positive.")
     else:
-        weights = [i / len(weights) for i in weights]
+        weights = [i / np.sum(weights) for i in weights]
 
     pvalues = np.array(pvalues)
     weights = np.array(weights)
 
-    if not any(i < 1e-16 for i in pvalues):
+    if not any(i < 1e-15 for i in pvalues):
         cct_stat = sum(weights * np.tan((0.5 - pvalues) * np.pi))
     else:
-        is_small = [i < (1e-16) for i in pvalues]
-        is_large = [i >= (1e-16) for i in pvalues]
+        is_small = [i < (1e-15) for i in pvalues]
+        is_large = [i >= (1e-15) for i in pvalues]
         cct_stat = sum((weights[is_small] / pvalues[is_small]) / np.pi)
         cct_stat += sum(weights[is_large] * np.tan((0.5 - pvalues[is_large]) * np.pi))
 
@@ -118,7 +118,7 @@ def run_Cauchy_combination(config: CauchyCombinationConfig):
     n_removed = len(p_values) - len(p_values_filtered)
 
     # Remove outliers if the number is reasonable
-    if 0 < n_removed < 20:
+    if 0 < n_removed < max(len(p_values) * 0.01, 20):
         logger.info(f"Removed {n_removed}/{len(p_values)} outliers (median + 3IQR) for {ct}.")
         p_cauchy_temp = acat_test(p_values_filtered)
     else:
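The two `acat_test` fixes normalize user-supplied weights by their sum (so non-uniform weights total 1) and relax the small-p cutoff to 1e-15, while the outlier guard in `run_Cauchy_combination` now scales with the number of spots instead of a fixed 20. A self-contained sketch of the Cauchy combination (ACAT) statistic with the new semantics (illustrative only, not the packaged function):

```python
import numpy as np
from scipy.stats import cauchy

def acat_sketch(pvalues, weights=None):
    # Illustrative ACAT: combine p-values via the Cauchy distribution.
    pvalues = np.asarray(pvalues, dtype=float)
    if weights is None:
        weights = np.full(pvalues.shape, 1.0 / len(pvalues))
    else:
        weights = np.asarray(weights, dtype=float)
        weights = weights / weights.sum()  # 1.73.0 fix: divide by the sum, not the count

    small = pvalues < 1e-15  # for tiny p, tan((0.5 - p) * pi) ~= 1 / (p * pi)
    stat = np.sum(weights[small] / pvalues[small]) / np.pi
    stat += np.sum(weights[~small] * np.tan((0.5 - pvalues[~small]) * np.pi))
    return cauchy.sf(stat)  # combined p-value from the standard Cauchy tail

print(acat_sketch([0.01, 0.5, 1e-20], weights=[1.0, 2.0, 1.0]))
```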
gsMap/config.py CHANGED
@@ -1,7 +1,10 @@
 import argparse
 import dataclasses
 import logging
+import os
 import sys
+import threading
+import time
 from collections import OrderedDict, namedtuple
 from collections.abc import Callable
 from dataclasses import dataclass
@@ -10,6 +13,7 @@ from pathlib import Path
 from pprint import pprint
 from typing import Literal
 
+import psutil
 import pyfiglet
 import yaml
 
@@ -34,9 +38,109 @@ def get_gsMap_logger(logger_name):
 logger = get_gsMap_logger("gsMap")
 
 
+def track_resource_usage(func):
+    """
+    Decorator to track resource usage during function execution.
+    Logs memory usage, CPU time, and wall clock time at the end of the function.
+    """
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # Get the current process
+        process = psutil.Process(os.getpid())
+
+        # Initialize tracking variables
+        peak_memory = 0
+        cpu_percent_samples = []
+        stop_thread = False
+
+        # Function to monitor resource usage
+        def resource_monitor():
+            nonlocal peak_memory, cpu_percent_samples
+            while not stop_thread:
+                try:
+                    # Get current memory usage in MB
+                    current_memory = process.memory_info().rss / (1024 * 1024)
+                    peak_memory = max(peak_memory, current_memory)
+
+                    # Get CPU usage percentage
+                    cpu_percent = process.cpu_percent(interval=None)
+                    if cpu_percent > 0:  # Skip initial zero readings
+                        cpu_percent_samples.append(cpu_percent)
+
+                    time.sleep(0.5)
+                except Exception:  # Catching all exceptions here because... # noqa: BLE001
+                    pass
+
+        # Start resource monitoring in a separate thread
+        monitor_thread = threading.Thread(target=resource_monitor)
+        monitor_thread.daemon = True
+        monitor_thread.start()
+
+        # Get start times
+        start_wall_time = time.time()
+        start_cpu_time = process.cpu_times().user + process.cpu_times().system
+
+        try:
+            # Run the actual function
+            result = func(*args, **kwargs)
+            return result
+        finally:
+            # Stop the monitoring thread
+            stop_thread = True
+            monitor_thread.join(timeout=1.0)
+
+            # Calculate elapsed times
+            end_wall_time = time.time()
+            end_cpu_time = process.cpu_times().user + process.cpu_times().system
+
+            wall_time = end_wall_time - start_wall_time
+            cpu_time = end_cpu_time - start_cpu_time
+
+            # Calculate average CPU percentage
+            avg_cpu_percent = (
+                sum(cpu_percent_samples) / len(cpu_percent_samples) if cpu_percent_samples else 0
+            )
+
+            # Format memory for display
+            if peak_memory < 1024:
+                memory_str = f"{peak_memory:.2f} MB"
+            else:
+                memory_str = f"{peak_memory / 1024:.2f} GB"
+
+            # Format times for display
+            if wall_time < 60:
+                wall_time_str = f"{wall_time:.2f} seconds"
+            elif wall_time < 3600:
+                wall_time_str = f"{wall_time / 60:.2f} minutes"
+            else:
+                wall_time_str = f"{wall_time / 3600:.2f} hours"
+
+            if cpu_time < 60:
+                cpu_time_str = f"{cpu_time:.2f} seconds"
+            elif cpu_time < 3600:
+                cpu_time_str = f"{cpu_time / 60:.2f} minutes"
+            else:
+                cpu_time_str = f"{cpu_time / 3600:.2f} hours"
+
+            # Log the resource usage
+            import logging
+
+            logger = logging.getLogger("gsMap")
+            logger.info("Resource usage summary:")
+            logger.info(f"  • Wall clock time: {wall_time_str}")
+            logger.info(f"  • CPU time: {cpu_time_str}")
+            logger.info(f"  • Average CPU utilization: {avg_cpu_percent:.1f}%")
+            logger.info(f"  • Peak memory usage: {memory_str}")
+
+    return wrapper
+
+
 # Decorator to register functions for cli parsing
 def register_cli(name: str, description: str, add_args_function: Callable) -> Callable:
     def decorator(func: Callable) -> Callable:
+        @track_resource_usage  # Use enhanced resource tracking
+        @wraps(func)
         def wrapper(*args, **kwargs):
             name.replace("_", " ")
             gsMap_main_logo = pyfiglet.figlet_format(
@@ -50,8 +154,16 @@ def register_cli(name: str, description: str, add_args_function: Callable) -> Ca
             print(version_number.center(80), flush=True)
             print("=" * 80, flush=True)
             logger.info(f"Running {name}...")
+
+            # Record start time for the log message
+            start_time = time.strftime("%Y-%m-%d %H:%M:%S")
+            logger.info(f"Started at: {start_time}")
+
             func(*args, **kwargs)
-            logger.info(f"Finished running {name}.")
+
+            # Record end time for the log message
+            end_time = time.strftime("%Y-%m-%d %H:%M:%S")
+            logger.info(f"Finished running {name} at: {end_time}.")
 
         cli_function_registry[name] = subcommand(
             name=name, func=wrapper, add_args_function=add_args_function, description=description
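Every registered subcommand is now wrapped in `track_resource_usage`, which samples the process via psutil from a 0.5 s daemon thread and logs wall time, CPU time, average CPU utilization, and peak RSS when the command finishes. A minimal usage sketch (the workload function is hypothetical):

```python
import logging
import time

from gsMap.config import track_resource_usage  # helper added in 1.73.0

logging.basicConfig(level=logging.INFO)

@track_resource_usage
def demo_workload():
    # Hypothetical workload: allocate some memory and burn a little time
    data = [i * i for i in range(5_000_000)]
    time.sleep(1.0)  # give the 0.5 s sampler a few readings
    return sum(data)

demo_workload()  # on return, logs wall/CPU time, average CPU %, and peak memory
```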
@@ -61,6 +173,13 @@ def register_cli(name: str, description: str, add_args_function: Callable) -> Ca
     return decorator
 
 
+def str_or_float(value):
+    try:
+        return int(value)
+    except ValueError:
+        return value
+
+
 def add_shared_args(parser):
     parser.add_argument(
         "--workdir", type=str, required=True, help="Path to the working directory."
@@ -429,7 +548,7 @@ def add_format_sumstats_args(parser):
     parser.add_argument(
         "--n",
         default=None,
-        type=str,
+        type=str_or_float,
         help="Name of sample size column (if not a name that gsMap understands)",
     )
     parser.add_argument(
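Despite its name, `str_or_float` returns an int when the value parses as one and otherwise falls back to the raw string (a value like "100000.5" also stays a string, since `int()` rejects it). This lets `--n` accept either a column name or a literal sample size:

```python
import argparse

def str_or_float(value):
    # Same behavior as the helper added to gsMap/config.py above
    try:
        return int(value)
    except ValueError:
        return value

parser = argparse.ArgumentParser()
parser.add_argument("--n", default=None, type=str_or_float)

print(parser.parse_args(["--n", "123456"]).n)  # 123456 (int) -> literal sample size
print(parser.parse_args(["--n", "N_eff"]).n)   # 'N_eff' (str) -> a column name
```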
@@ -1037,7 +1156,7 @@ class CauchyCombinationConfig(ConfigWithAutoPaths):
 
     def __post_init__(self):
        if self.sample_name is not None:
-            if len(self.sample_name_list) > 0:
+            if self.sample_name_list and len(self.sample_name_list) > 0:
                raise ValueError("Only one of sample_name and sample_name_list must be provided.")
            else:
                self.sample_name_list = [self.sample_name]
@@ -1106,6 +1225,9 @@ class RunAllModeConfig(ConfigWithAutoPaths):
     annotation: str
     data_layer: str = "X"
 
+    # == Find Latent Representation PARAMETERS ==
+    n_comps: int = 300
+
     # == latent 2 Gene PARAMETERS ==
     gM_slices: str | None = None
     latent_representation: str = None
@@ -1124,9 +1246,7 @@ class RunAllModeConfig(ConfigWithAutoPaths):
 
     def __post_init__(self):
         super().__post_init__()
-        self.gtffile = (
-            f"{self.gsMap_resource_dir}/genome_annotation/gtf/gencode.v39lift37.annotation.gtf"
-        )
+        self.gtffile = f"{self.gsMap_resource_dir}/genome_annotation/gtf/gencode.v46lift37.basic.annotation.gtf"
         self.bfile_root = (
             f"{self.gsMap_resource_dir}/LD_Reference_Panel/1000G_EUR_Phase3_plink/1000G.EUR.QC"
         )
@@ -1191,7 +1311,7 @@ class FormatSumstatsConfig:
     se: str = None
     p: str = None
     frq: str = None
-    n: str = None
+    n: str | int = None
     z: str = None
     OR: str = None
     se_OR: str = None
@@ -1204,9 +1324,21 @@ class FormatSumstatsConfig:
     keep_chr_pos: bool = False
 
 
+@register_cli(
+    name="quick_mode",
+    description="Run the entire gsMap pipeline in quick mode, utilizing pre-computed weights for faster execution.",
+    add_args_function=add_run_all_mode_args,
+)
+def run_all_mode_from_cli(args: argparse.Namespace):
+    from gsMap.run_all_mode import run_pipeline
+
+    config = get_dataclass_from_parser(args, RunAllModeConfig)
+    run_pipeline(config)
+
+
 @register_cli(
     name="run_find_latent_representations",
-    description="Run Find_latent_representations \nFind the latent representations of each spot by running GNN-VAE",
+    description="Run Find_latent_representations \nFind the latent representations of each spot by running GNN",
     add_args_function=add_find_latent_representations_args,
 )
 def run_find_latent_representation_from_cli(args: argparse.Namespace):
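Relocating the quick_mode registration is not purely cosmetic: `cli_function_registry` is an OrderedDict, so if the CLI builder iterates it when creating subparsers, quick_mode now appears before the step-by-step commands in the help listing. A toy sketch of that mechanism (names simplified, not gsMap's actual decorator):

```python
from collections import OrderedDict

registry = OrderedDict()

def register(name):
    # Simplified stand-in for gsMap's register_cli decorator
    def decorator(func):
        registry[name] = func
        return func
    return decorator

@register("quick_mode")  # registered first in 1.73.0
def quick_mode():
    pass

@register("format_sumstats")
def format_sumstats():
    pass

print(list(registry))  # ['quick_mode', 'format_sumstats']
```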
@@ -1278,7 +1410,7 @@ def run_Report_from_cli(args: argparse.Namespace):
 
 @register_cli(
     name="format_sumstats",
-    description="Format gwas summary statistics",
+    description="Format GWAS summary statistics",
     add_args_function=add_format_sumstats_args,
 )
 def gwas_format_from_cli(args: argparse.Namespace):
@@ -1288,18 +1420,6 @@ def gwas_format_from_cli(args: argparse.Namespace):
     gwas_format(config)
 
 
-@register_cli(
-    name="quick_mode",
-    description="Run all the gsMap pipeline in quick mode",
-    add_args_function=add_run_all_mode_args,
-)
-def run_all_mode_from_cli(args: argparse.Namespace):
-    from gsMap.run_all_mode import run_pipeline
-
-    config = get_dataclass_from_parser(args, RunAllModeConfig)
-    run_pipeline(config)
-
-
 @register_cli(
     name="create_slice_mean",
     description="Create slice mean from multiple h5ad files",
gsMap/create_slice_mean.py CHANGED
@@ -5,8 +5,9 @@ import anndata
 import numpy as np
 import pandas as pd
 import scanpy as sc
+import scipy
 import zarr
-from scipy.stats import rankdata
+from scipy.stats import gmean, rankdata
 from tqdm import tqdm
 
 from gsMap.config import CreateSliceMeanConfig
@@ -62,22 +63,27 @@ def calculate_one_slice_mean(
 
     adata = adata[:, common_genes].copy()
     n_cells = adata.shape[0]
-    log_ranks = np.zeros((n_cells, adata.n_vars), dtype=np.float32)
-    # Compute log of ranks to avoid overflow when computing geometric mean
-    for i in tqdm(range(n_cells), desc=f"Computing log ranks for {sample_name}"):
-        data = adata.X[i, :].toarray().flatten()
-        ranks = rankdata(data, method="average")
-        log_ranks[i, :] = np.log(ranks)  # Adding small value to avoid log(0)
 
-    # Calculate geometric mean via log trick: exp(mean(log(values)))
-    gmean = (np.exp(np.mean(log_ranks, axis=0))).reshape(-1, 1)
+    if not scipy.sparse.issparse(adata.X):
+        adata_X = scipy.sparse.csr_matrix(adata.X)
+    elif isinstance(adata.X, scipy.sparse.csr_matrix):
+        adata_X = adata.X  # Avoid copying if already CSR
+    else:
+        adata_X = adata.X.tocsr()
+
+    ranks = np.zeros((n_cells, adata.n_vars), dtype=np.float16)
+    for i in tqdm(range(n_cells), desc="Computing ranks per cell"):
+        data = adata_X[i, :].toarray().flatten()
+        ranks[i, :] = rankdata(data, method="average")
+
+    gM = gmean(ranks, axis=0).reshape(-1, 1)
 
     # Calculate the expression fractio
     adata_X_bool = adata.X.astype(bool)
     frac = (np.asarray(adata_X_bool.sum(axis=0)).flatten()).reshape(-1, 1)
 
     # Save to zarr group
-    gmean_frac = np.concatenate([gmean, frac], axis=1)
+    gmean_frac = np.concatenate([gM, frac], axis=1)
     s1_zarr = gmean_zarr_group.array(sample_name, data=gmean_frac, chunks=None, dtype="f4")
     s1_zarr.attrs["spot_number"] = adata.shape[0]
@@ -85,34 +91,42 @@ def calculate_one_slice_mean(
 def merge_zarr_means(zarr_group_path, output_file, common_genes):
     """
     Merge all Zarr arrays into a weighted geometric mean and save to a Parquet file.
-    Instead of calculating the mean, it sums the logs and applies the exponential.
     """
     gmean_zarr_group = zarr.open(zarr_group_path, mode="a")
-    log_sum = None
+
+    sample_gmeans = []
+    sample_weights = []
     frac_sum = None
     total_spot_number = 0
+
+    # Collect all geometric means and their weights (spot numbers)
     for key in tqdm(gmean_zarr_group.array_keys(), desc="Merging Zarr arrays"):
         s1 = gmean_zarr_group[key]
         s1_array_gmean = s1[:][:, 0]
         s1_array_frac = s1[:][:, 1]
         n = s1.attrs["spot_number"]
 
-        if log_sum is None:
-            log_sum = np.log(s1_array_gmean) * n
+        sample_gmeans.append(s1_array_gmean)
+        sample_weights.append(n)
+
+        if frac_sum is None:
             frac_sum = s1_array_frac
         else:
-            log_sum += np.log(s1_array_gmean) * n
             frac_sum += s1_array_frac
 
         total_spot_number += n
 
-    # Apply the geometric mean via exponentiation of the averaged logs
-    final_mean = np.exp(log_sum / total_spot_number)
+    # Convert to arrays
+    sample_gmeans = np.array(sample_gmeans)
+    sample_weights = np.array(sample_weights)
+
+    final_gmean = gmean(sample_gmeans, axis=0, weights=sample_weights[:, np.newaxis])
+
     final_frac = frac_sum / total_spot_number
 
     # Save the final mean to a Parquet file
     gene_names = common_genes
-    final_df = pd.DataFrame({"gene": gene_names, "G_Mean": final_mean, "frac": final_frac})
+    final_df = pd.DataFrame({"gene": gene_names, "G_Mean": final_gmean, "frac": final_frac})
     final_df.set_index("gene", inplace=True)
     final_df.to_parquet(output_file)
     return final_df
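The merge now computes the across-slice mean with `scipy.stats.gmean` and its `weights` argument (available in recent SciPy releases) instead of accumulating n·log(gmean) by hand; both routes compute exp(sum(w_i * log(x_i)) / sum(w_i)). A quick equivalence check with toy numbers:

```python
import numpy as np
from scipy.stats import gmean

rng = np.random.default_rng(0)
slice_gmeans = rng.uniform(1, 100, size=(3, 5))  # per-slice geometric means (3 slices, 5 genes)
spot_counts = np.array([120.0, 80.0, 200.0])     # weights: spots per slice

# New route: scipy.stats.gmean with weights broadcast across the gene axis
w = np.broadcast_to(spot_counts[:, np.newaxis], slice_gmeans.shape)
new = gmean(slice_gmeans, axis=0, weights=w)

# Old route: sum n * log(gmean) over slices, divide by total spots, exponentiate
old = np.exp((np.log(slice_gmeans) * spot_counts[:, np.newaxis]).sum(axis=0) / spot_counts.sum())

print(np.allclose(new, old))  # True
```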
gsMap/find_latent_representation.py CHANGED
@@ -38,7 +38,7 @@ def preprocess_data(adata, params):
 
     if params.data_layer in adata.layers.keys():
         logger.info(f"Using data layer: {params.data_layer}...")
-        adata.X = adata.layers[params.data_layer]
+        adata.X = adata.layers[params.data_layer].copy()
     elif params.data_layer == "X":
         logger.info(f"Using data layer: {params.data_layer}...")
         if adata.X.dtype == "float32" or adata.X.dtype == "float64":
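Taking a `.copy()` here matters because assigning a layer to `adata.X` without copying can leave both referencing the same buffer, so later in-place normalization of X would silently corrupt the stored raw layer. A plain-NumPy illustration of the aliasing hazard:

```python
import numpy as np

layer = np.ones((2, 2))
X_alias = layer        # assignment without copy: both names share one buffer
X_alias *= 10          # in-place scaling...
print(layer[0, 0])     # 10.0 -> the "raw" layer was mutated too

layer = np.ones((2, 2))
X_copy = layer.copy()  # the 1.73.0 fix: copy first
X_copy *= 10
print(layer[0, 0])     # 1.0 -> the stored layer is untouched
```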
gsMap/format_sumstats.py CHANGED
@@ -409,6 +409,12 @@ def gwas_format(config: FormatSumstatsConfig):
         compression=compression_type,
         na_values=[".", "NA"],
     )
+
+    if isinstance(config.n, int | float):
+        logger.info(f"Set the sample size of gwas data as {config.n}.")
+        gwas["N"] = config.n
+        config.n = "N"
+
     logger.info(f"Read {len(gwas)} SNPs from {config.sumstats}.")
 
     # Check name and format
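Combined with the `str_or_float` argument type, a numeric `--n` is materialized as a constant N column and `config.n` is re-pointed at that column name. A minimal sketch of the flow (toy DataFrame; note that `isinstance()` with the `int | float` union requires Python 3.10+):

```python
import pandas as pd

gwas = pd.DataFrame({"SNP": ["rs1", "rs2"], "P": [0.01, 0.2]})
n = 54321  # as produced by str_or_float for `--n 54321`

if isinstance(n, int | float):
    gwas["N"] = n  # every SNP gets the same sample size
    n = "N"        # downstream code now sees a column name, as before

print(gwas)
```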
gsMap/generate_ldscore.py CHANGED
@@ -10,7 +10,7 @@ from scipy.sparse import csr_matrix
 from tqdm import trange
 
 from gsMap.config import GenerateLDScoreConfig
-from gsMap.utils.generate_r2_matrix import ID_List_Factory, PlinkBEDFileWithR2Cache, getBlockLefts
+from gsMap.utils.generate_r2_matrix import getBlockLefts, load_bfile
 
 warnings.filterwarnings("ignore", category=FutureWarning)
 logger = logging.getLogger(__name__)
@@ -71,9 +71,6 @@ def load_marker_score(mk_score_file):
     return mk_score
 
 
-# %%
-# load mkscore get common gene
-# %%
 # load bim
 def load_bim(bfile_root, chrom):
     """
@@ -146,52 +143,23 @@ def get_snp_pass_maf(bfile_root, chrom, maf_min=0.05):
     """
     Get the dummy matrix of SNP-gene pairs.
     """
-    # Load the bim file
-    PlinkBIMFile = ID_List_Factory(
-        ["CHR", "SNP", "CM", "BP", "A1", "A2"], 1, ".bim", usecols=[0, 1, 2, 3, 4, 5]
-    )
-    PlinkFAMFile = ID_List_Factory(["IID"], 0, ".fam", usecols=[1])
+    array_snps, array_indivs, geno_array = load_bfile(bfile_chr_prefix=f"{bfile_root}.{chrom}")
 
-    bfile = f"{bfile_root}.{chrom}"
-    snp_file, snp_obj = bfile + ".bim", PlinkBIMFile
-    array_snps = snp_obj(snp_file)
-    # m = len(array_snps.IDList)
-
-    # Load fam
-    ind_file, ind_obj = bfile + ".fam", PlinkFAMFile
-    array_indivs = ind_obj(ind_file)
+    m = len(array_snps.IDList)
     n = len(array_indivs.IDList)
-    array_file, array_obj = bfile + ".bed", PlinkBEDFileWithR2Cache
-    geno_array = array_obj(
-        array_file, n, array_snps, keep_snps=None, keep_indivs=None, mafMin=None
+    logger.info(
+        f"Loading genotype data for {m} SNPs and {n} individuals from {bfile_root}.{chrom}"
     )
+
     ii = geno_array.maf > maf_min
     snp_pass_maf = array_snps.IDList[ii]
-    print(f"After filtering SNPs with MAF < {maf_min}, {len(snp_pass_maf)} SNPs remain.")
+    logger.info(f"After filtering SNPs with MAF < {maf_min}, {len(snp_pass_maf)} SNPs remain.")
     return snp_pass_maf.SNP.to_list()
 
 
 def get_ldscore(bfile_root, chrom, annot_matrix, ld_wind, ld_unit="CM"):
-    PlinkBIMFile = ID_List_Factory(
-        ["CHR", "SNP", "CM", "BP", "A1", "A2"], 1, ".bim", usecols=[0, 1, 2, 3, 4, 5]
-    )
-    PlinkFAMFile = ID_List_Factory(["IID"], 0, ".fam", usecols=[1])
-
-    bfile = f"{bfile_root}.{chrom}"
-    snp_file, snp_obj = bfile + ".bim", PlinkBIMFile
-    array_snps = snp_obj(snp_file)
-    m = len(array_snps.IDList)
-    print(f"Read list of {m} SNPs from {snp_file}")
+    array_snps, array_indivs, geno_array = load_bfile(bfile_chr_prefix=f"{bfile_root}.{chrom}")
 
-    # Load fam
-    ind_file, ind_obj = bfile + ".fam", PlinkFAMFile
-    array_indivs = ind_obj(ind_file)
-    n = len(array_indivs.IDList)
-    print(f"Read list of {n} individuals from {ind_file}")
-    array_file, array_obj = bfile + ".bed", PlinkBEDFileWithR2Cache
-    geno_array = array_obj(
-        array_file, n, array_snps, keep_snps=None, keep_indivs=None, mafMin=None
-    )
     # Load the annotations of the baseline
     if ld_unit == "SNP":
         max_dist = ld_wind
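The duplicated PLINK-loading boilerplate collapses into a single `load_bfile` helper, while `getBlockLefts` still defines the sliding LD window: for each SNP it finds the index of the leftmost SNP within `ld_wind` of it. A pure-NumPy sketch of that windowing semantics (illustrative, not the packaged implementation):

```python
import numpy as np

def block_lefts(coords, max_dist):
    """For each position j, index of the leftmost i with coords[j] - coords[i] <= max_dist."""
    coords = np.asarray(coords)
    lefts = np.empty(len(coords), dtype=int)
    i = 0
    for j in range(len(coords)):
        while coords[j] - coords[i] > max_dist:
            i += 1  # slide the left edge of the window forward
        lefts[j] = i
    return lefts

cm = np.array([0.0, 0.4, 0.9, 1.6, 2.0])  # SNP positions in centimorgans
print(block_lefts(cm, max_dist=1.0))      # [0 0 0 2 3] with a 1 cM window
```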
gsMap/latent_to_gene.py CHANGED
@@ -147,22 +147,50 @@ def run_latent_to_gene(config: LatentToGeneConfig):
     )
 
     # Homologs transformation
-    if config.homolog_file is not None:
-        logger.info(f"------Transforming the {config.species} to HUMAN_GENE_SYM...")
-        homologs = pd.read_csv(config.homolog_file, sep="\t")
-        if homologs.shape[1] != 2:
-            raise ValueError(
-                "Homologs file must have two columns: one for the species and one for the human gene symbol."
+    if config.homolog_file is not None and config.species is not None:
+        species_col_name = f"{config.species}_homolog"
+
+        # Check if homolog conversion has already been performed
+        if species_col_name in adata.var.columns:
+            logger.warning(
+                f"Column '{species_col_name}' already exists in adata.var. "
+                f"It appears gene names have already been converted to human gene symbols. "
+                f"Skipping homolog transformation."
             )
+        else:
+            logger.info(f"------Transforming the {config.species} to HUMAN_GENE_SYM...")
+            homologs = pd.read_csv(config.homolog_file, sep="\t")
+            if homologs.shape[1] != 2:
+                raise ValueError(
+                    "Homologs file must have two columns: one for the species and one for the human gene symbol."
+                )
+
+            homologs.columns = [config.species, "HUMAN_GENE_SYM"]
+            homologs.set_index(config.species, inplace=True)
+
+            # original_gene_names = adata.var_names.copy()
+
+            # Filter genes present in homolog file
+            adata = adata[:, adata.var_names.isin(homologs.index)]
+            logger.info(f"{adata.shape[1]} genes retained after homolog transformation.")
+            if adata.shape[1] < 100:
+                raise ValueError("Too few genes retained in ST data (<100).")
+
+            # Create mapping table of original to human gene names
+            gene_mapping = pd.Series(
+                homologs.loc[adata.var_names, "HUMAN_GENE_SYM"].values, index=adata.var_names
+            )
+
+            # Store original species gene names in var dataframe with the suffixed column name
+            adata.var[species_col_name] = adata.var_names.values
+
+            # Convert var_names to human gene symbols
+            adata.var_names = gene_mapping.values
+            adata.var.index.name = "HUMAN_GENE_SYM"
 
-        homologs.columns = [config.species, "HUMAN_GENE_SYM"]
-        homologs.set_index(config.species, inplace=True)
-        adata = adata[:, adata.var_names.isin(homologs.index)]
-        logger.info(f"{adata.shape[1]} genes retained after homolog transformation.")
-        if adata.shape[1] < 100:
-            raise ValueError("Too few genes retained in ST data (<100).")
-        adata.var_names = homologs.loc[adata.var_names, "HUMAN_GENE_SYM"].values
-        adata = adata[:, ~adata.var_names.duplicated()]
+            # Remove duplicated genes after conversion
+            adata = adata[:, ~adata.var_names.duplicated()]
+            logger.info(f"{adata.shape[1]} genes retained after removing duplicates.")
 
     if config.annotation is not None:
         cell_annotations = adata.obs[config.annotation].values
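The reworked homolog block records the original species symbols in `adata.var` under a `<species>_homolog` column before renaming, which is how a rerun detects that conversion already happened and skips it. The core rename reduces to a pandas lookup; a toy sketch (gene names invented):

```python
import pandas as pd

# Two-column homolog table, indexed by the source species' gene symbols
homologs = pd.DataFrame(
    {"mouse": ["Trp53", "Sox2", "Actb"], "HUMAN_GENE_SYM": ["TP53", "SOX2", "ACTB"]}
).set_index("mouse")

var_names = pd.Index(["Sox2", "Actb", "Xist"])    # genes present in the ST data
kept = var_names[var_names.isin(homologs.index)]  # drop genes without a homolog
gene_mapping = pd.Series(homologs.loc[kept, "HUMAN_GENE_SYM"].values, index=kept)

print(gene_mapping)
# Sox2    SOX2
# Actb    ACTB
```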
gsMap/run_all_mode.py CHANGED
@@ -67,6 +67,7 @@ def run_pipeline(config: RunAllModeConfig):
         sample_name=config.sample_name,
         annotation=config.annotation,
         data_layer=config.data_layer,
+        n_comps=config.n_comps,
     )
 
     # Step 1: Find latent representations