gsMap-1.65-py3-none-any.whl → gsMap-1.67-py3-none-any.whl

gsMap/find_latent_representation.py CHANGED
@@ -1,145 +1,145 @@
  import logging
  import random
-
  import numpy as np
- import pandas as pd
  import scanpy as sc
  import torch
- from sklearn import preprocessing
-
- from gsMap.GNN_VAE.adjacency_matrix import Construct_Adjacency_Matrix
- from gsMap.GNN_VAE.train import Model_Train
+ from sklearn.decomposition import PCA
+ from sklearn.preprocessing import LabelEncoder
+ from gsMap.GNN_VAE.adjacency_matrix import construct_adjacency_matrix
+ from gsMap.GNN_VAE.train import ModelTrainer
  from gsMap.config import FindLatentRepresentationsConfig

  logger = logging.getLogger(__name__)

+
  def set_seed(seed_value):
      """
-     Set seed for reproducibility in PyTorch.
+     Set seed for reproducibility in PyTorch and other libraries.
      """
-     torch.manual_seed(seed_value) # Set the seed for PyTorch
-     np.random.seed(seed_value) # Set the seed for NumPy
-     random.seed(seed_value) # Set the seed for Python random module
+     torch.manual_seed(seed_value)
+     np.random.seed(seed_value)
+     random.seed(seed_value)
      if torch.cuda.is_available():
-         logger.info('Running use GPU')
-         torch.cuda.manual_seed(seed_value) # Set seed for all CUDA devices
-         torch.cuda.manual_seed_all(seed_value) # Set seed for all CUDA devices
+         logger.info('Using GPU for computations.')
+         torch.cuda.manual_seed(seed_value)
+         torch.cuda.manual_seed_all(seed_value)
      else:
-         logger.info('Running use CPU')
+         logger.info('Using CPU for computations.')

+ def preprocess_data(adata, params):
+     """
+     Preprocess the AnnData
+     """
+     logger.info('Preprocessing data...')
+     adata.var_names_make_unique()
+
+     sc.pp.filter_genes(adata, min_cells=30)
+     if params.data_layer in adata.layers.keys():
+         adata.X = adata.layers[params.data_layer]
+     else:
+         raise ValueError(f'Invalid data layer: {params.data_layer}, please check the input data.')

- # The class for finding latent representations
- class Latent_Representation_Finder:
+     if params.data_layer in ['count', 'counts']:

-     def __init__(self, adata, args:FindLatentRepresentationsConfig):
-         self.adata = adata.copy()
-         self.Params = args
+         sc.pp.normalize_total(adata, target_sum=1e4)
+         sc.pp.log1p(adata)

-         # Standard process
-         if self.Params.data_layer == 'count' or self.Params.data_layer == 'counts':
-             self.adata.X = self.adata.layers[self.Params.data_layer]
-             sc.pp.highly_variable_genes(self.adata, flavor="seurat_v3", n_top_genes=self.Params.feat_cell)
-             sc.pp.normalize_total(self.adata, target_sum=1e4)
-             sc.pp.log1p(self.adata)
-             sc.pp.scale(self.adata)
-         else:
-             if self.Params.data_layer != 'X':
-                 self.adata.X = self.adata.layers[self.Params.data_layer]
-             sc.pp.highly_variable_genes(self.adata, n_top_genes=self.Params.feat_cell)
+         # Identify highly variable genes
+         sc.pp.highly_variable_genes(
+             adata,
+             flavor="seurat_v3",
+             n_top_genes=params.feat_cell,
+         )

-     def Run_GNN_VAE(self, label, verbose='whole ST data'):
+     elif params.data_layer in adata.layers.keys():
+         logger.info(f'Using {params.data_layer} data...')
+         sc.pp.highly_variable_genes(
+             adata,
+             flavor="seurat",
+             n_top_genes=params.feat_cell,
+         )

-         # Construct the neighbouring graph
-         graph_dict = Construct_Adjacency_Matrix(self.adata, self.Params)
+     return adata

-         # Process the feature matrix
-         node_X = self.adata[:, self.adata.var.highly_variable].X
-         logger.info(f'The shape of feature matrix is {node_X.shape}.')
-         if self.Params.input_pca:
-             node_X = sc.pp.pca(node_X, n_comps=self.Params.n_comps)
+
+ class LatentRepresentationFinder:
+     def __init__(self, adata, args: FindLatentRepresentationsConfig):
+         self.params = args
+
+         self.expression_array = adata[:, adata.var.highly_variable].X.copy()
+
+         if self.params.data_layer in ['count', 'counts']:
+             self.expression_array = sc.pp.scale(self.expression_array, max_value=10)
+
+         # Construct the neighboring graph
+         self.graph_dict = construct_adjacency_matrix(adata, self.params)
+
+     def compute_pca(self):
+         self.latent_pca = PCA(n_components=self.params.feat_cell).fit_transform(self.expression_array)
+         return self.latent_pca
+
+     def run_gnn_vae(self, label, verbose='whole ST data'):
+
+         # Use PCA if specified
+         if self.params.input_pca:
+             node_X = self.compute_pca()
+         else:
+             node_X = self.expression_array

          # Update the input shape
-         self.Params.n_nodes = node_X.shape[0]
-         self.Params.feat_cell = node_X.shape[1]
+         self.params.n_nodes = node_X.shape[0]
+         self.params.feat_cell = node_X.shape[1]

-         # Run GNN-VAE
-         logger.info(f'------Finding latent representations for {verbose}...')
-         gvae = Model_Train(node_X, graph_dict, self.Params, label)
+         # Run GNN
+         logger.info(f'Finding latent representations for {verbose}...')
+         gvae = ModelTrainer(node_X, self.graph_dict, self.params, label)
          gvae.run_train()

-         return gvae.get_latent()
+         del self.graph_dict

-     def Run_PCA(self):
-         sc.tl.pca(self.adata)
-         return self.adata.obsm['X_pca'][:, 0:self.Params.n_comps]
+         return gvae.get_latent()


- def run_find_latent_representation(args:FindLatentRepresentationsConfig):
+ def run_find_latent_representation(args: FindLatentRepresentationsConfig):
      set_seed(2024)
-     num_features = args.feat_cell
-     args.hdf5_with_latent_path.parent.mkdir(parents=True, exist_ok=True,mode=0o755)
+
      # Load the ST data
-     logger.info(f'------Loading ST data of {args.sample_name}...')
-     adata = sc.read_h5ad(f'{args.input_hdf5_path}')
-     adata.var_names_make_unique()
-     adata.X = adata.layers[args.data_layer] if args.data_layer in adata.layers.keys() else adata.X
-     logger.info('The ST data contains %d cells, %d genes.' % (adata.shape[0], adata.shape[1]))
+     logger.info(f'Loading ST data of {args.sample_name}...')
+     adata = sc.read_h5ad(args.input_hdf5_path)
+     logger.info(f'The ST data contains {adata.shape[0]} cells, {adata.shape[1]} genes.')
+
      # Load the cell type annotation
-     if not args.annotation is None:
-         # remove cells without enough annotations
-         adata = adata[~pd.isnull(adata.obs[args.annotation]), :]
+     if args.annotation is not None:
+         # Remove cells without enough annotations
+         adata = adata[~adata.obs[args.annotation].isnull()]
          num = adata.obs[args.annotation].value_counts()
-         adata = adata[adata.obs[args.annotation].isin(num[num >= 30].index.to_list())]
+         valid_annotations = num[num >= 30].index.to_list()
+         adata = adata[adata.obs[args.annotation].isin(valid_annotations)]

-         le = preprocessing.LabelEncoder()
-         le.fit(adata.obs[args.annotation])
-         adata.obs['categorical_label'] = le.transform(adata.obs[args.annotation])
-         label = adata.obs['categorical_label'].to_list()
+         le = LabelEncoder()
+         adata.obs['categorical_label'] = le.fit_transform(adata.obs[args.annotation])
+         label = adata.obs['categorical_label'].to_numpy()
      else:
          label = None
-     # Find latent representations
-     latent_rep = Latent_Representation_Finder(adata, args)
-     latent_GVAE = latent_rep.Run_GNN_VAE(label)
-     latent_PCA = latent_rep.Run_PCA()
-     # Add latent representations to the spe data
-     logger.info(f'------Adding latent representations...')
-     adata.obsm["latent_GVAE"] = latent_GVAE
-     adata.obsm["latent_PCA"] = latent_PCA
-     # Run umap based on latent representations
-     for name in ['latent_GVAE', 'latent_PCA']:
-         sc.pp.neighbors(adata, n_neighbors=10, use_rep=name)
-         sc.tl.umap(adata)
-         adata.obsm['X_umap_' + name] = adata.obsm['X_umap']

-     # Find the latent representations hierarchically (optionally)
-     if not args.annotation is None and args.hierarchically:
-         logger.info(f'------Finding latent representations hierarchically...')
-         PCA_all = pd.DataFrame()
-         GVAE_all = pd.DataFrame()
+     # Preprocess data
+     adata = preprocess_data(adata, args)

-         for ct in adata.obs[args.annotation].unique():
-             adata_part = adata[adata.obs[args.annotation] == ct, :]
-             logger.info(adata_part.shape)
+     latent_rep = LatentRepresentationFinder(adata, args)
+     latent_gvae = latent_rep.run_gnn_vae(label)
+     latent_pca = latent_rep.compute_pca()

-             # Find latent representations for the selected ct
-             latent_rep = Latent_Representation_Finder(adata_part, args)
+     # Add latent representations to the AnnData object
+     logger.info('Adding latent representations...')
+     adata.obsm["latent_GVAE"] = latent_gvae
+     adata.obsm["latent_PCA"] = latent_pca

-             latent_PCA_part = pd.DataFrame(latent_rep.Run_PCA())
-             if adata_part.shape[0] <= args.n_comps:
-                 latent_GVAE_part = latent_PCA_part
-             else:
-                 latent_GVAE_part = pd.DataFrame(latent_rep.Run_GNN_VAE(label=None, verbose=ct))
-
-             latent_GVAE_part.index = adata_part.obs_names
-             latent_PCA_part.index = adata_part.obs_names
-
-             GVAE_all = pd.concat((GVAE_all, latent_GVAE_part), axis=0)
-             PCA_all = pd.concat((PCA_all, latent_PCA_part), axis=0)
-
-             args.feat_cell = num_features
+     # Run UMAP based on latent representations
+     for name in ['latent_GVAE', 'latent_PCA']:
+         sc.pp.neighbors(adata, n_neighbors=10, use_rep=name)
+         sc.tl.umap(adata)
+         adata.obsm['X_umap_' + name] = adata.obsm['X_umap']

-         adata.obsm["latent_GVAE_hierarchy"] = np.array(GVAE_all.loc[adata.obs_names,])
-         adata.obsm["latent_PCA_hierarchy"] = np.array(PCA_all.loc[adata.obs_names,])
-     logger.info(f'------Saving ST data...')
+     # Save the AnnData object
+     logger.info('Saving ST data...')
      adata.write(args.hdf5_with_latent_path)
-
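Taken together, this file's changes replace the old Latent_Representation_Finder class with a standalone preprocess_data() step and a slimmer LatentRepresentationFinder. A minimal sketch of how the refactored pieces fit together, using only the names visible in this diff (config construction and logging setup are elided, so treat this as an illustration rather than the package's documented API):

    # Sketch of the 1.67 flow, mirroring run_find_latent_representation above.
    import scanpy as sc

    set_seed(2024)
    adata = sc.read_h5ad(args.input_hdf5_path)        # load the ST data
    adata = preprocess_data(adata, args)              # filter genes, normalize, select HVGs

    finder = LatentRepresentationFinder(adata, args)  # scales counts, builds the graph
    latent_gvae = finder.run_gnn_vae(label=None)      # GNN-VAE latent space
    latent_pca = finder.compute_pca()                 # PCA latent space

    adata.obsm["latent_GVAE"] = latent_gvae
    adata.obsm["latent_PCA"] = latent_pca
    adata.write(args.hdf5_with_latent_path)

Note that run_gnn_vae() deletes self.graph_dict after training, so a finder instance supports only one GNN-VAE run.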
gsMap/format_sumstats.py CHANGED
@@ -150,10 +150,10 @@ def gwas_checkname(gwas, config):
          'Pos': 'SNP positions.'
      }

-     print(f'\nIterpreting column names as follows:')
+     logger.info(f'\nIterpreting column names as follows:')
      for key, value in interpreting.items():
          if key in new_name:
-             print(f'{name_dict[key]}: {interpreting[key]}')
+             logger.info(f'{name_dict[key]}: {interpreting[key]}')

      return gwas

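The print-to-logger migration throughout this file presumes a module-level logger analogous to the one find_latent_representation.py defines; that line sits outside the hunks shown, so the setup below is an assumption rather than part of this diff:

    # Assumed module-level setup in gsMap/format_sumstats.py (not visible in these hunks).
    import logging

    logger = logging.getLogger(__name__)

    # Callers would surface the messages with standard logging configuration, e.g.:
    # logging.basicConfig(level=logging.INFO, format='%(name)s - %(message)s')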
@@ -242,7 +242,7 @@ def gwas_qc(gwas, config):
      Filter out SNPs based on INFO, FRQ, MAF, N, and Genotypes.
      '''
      old = len(gwas)
-     print(f'\nFiltering SNPs as follows:')
+     logger.info(f'\nFiltering SNPs as follows:')
      # filter: SNPs with missing values
      drops = {'NA': 0, 'P': 0, 'INFO': 0, 'FRQ': 0, 'A': 0, 'SNP': 0, 'Dup': 0, 'N': 0}

@@ -250,28 +250,28 @@ def gwas_qc(gwas, config):
          lambda x: x != 'INFO', gwas.columns)).reset_index(drop=True)

      drops['NA'] = old - len(gwas)
-     print(f'Removed {drops["NA"]} SNPs with missing values.')
+     logger.info(f'Removed {drops["NA"]} SNPs with missing values.')

      # filter: SNPs with Info < 0.9
      if 'INFO' in gwas.columns:
          old = len(gwas)
          gwas = gwas.loc[filter_info(gwas['INFO'], config)]
          drops['INFO'] = old - len(gwas)
-         print(f'Removed {drops["INFO"]} SNPs with INFO <= 0.9.')
+         logger.info(f'Removed {drops["INFO"]} SNPs with INFO <= 0.9.')

      # filter: SNPs with MAF <= 0.01
      if 'FRQ' in gwas.columns:
          old = len(gwas)
          gwas = gwas.loc[filter_frq(gwas['FRQ'], config)]
          drops['FRQ'] += old - len(gwas)
-         print(f'Removed {drops["FRQ"]} SNPs with MAF <= 0.01.')
+         logger.info(f'Removed {drops["FRQ"]} SNPs with MAF <= 0.01.')

      # filter: P-value that out-of-bounds [0,1]
      if 'P' in gwas.columns:
          old = len(gwas)
          gwas = gwas.loc[filter_pvals(gwas['P'], config)]
          drops['P'] += old - len(gwas)
-         print(f'Removed {drops["P"]} SNPs with out-of-bounds p-values.')
+         logger.info(f'Removed {drops["P"]} SNPs with out-of-bounds p-values.')

      # filter: Variants that are strand-ambiguous
      if 'A1' in gwas.columns and 'A2' in gwas.columns:
@@ -279,21 +279,21 @@ def gwas_qc(gwas, config):
          gwas.A2 = gwas.A2.str.upper()
          gwas = gwas.loc[filter_alleles(gwas.A1 + gwas.A2)]
          drops['A'] += old - len(gwas)
-         print(f'Removed {drops["A"]} variants that were not SNPs or were strand-ambiguous.')
+         logger.info(f'Removed {drops["A"]} variants that were not SNPs or were strand-ambiguous.')

      # filter: Duplicated rs numbers
      if 'SNP' in gwas.columns:
          old = len(gwas)
          gwas = gwas.drop_duplicates(subset='SNP').reset_index(drop=True)
          drops['Dup'] += old - len(gwas)
-         print(f'Removed {drops["Dup"]} SNPs with duplicated rs numbers.')
+         logger.info(f'Removed {drops["Dup"]} SNPs with duplicated rs numbers.')

      # filter:Sample size
      n_min = gwas.N.quantile(0.9) / 1.5
      old = len(gwas)
      gwas = gwas[gwas.N >= n_min].reset_index(drop=True)
      drops['N'] += old - len(gwas)
-     print(f'Removed {drops["N"]} SNPs with N < {n_min}.')
+     logger.info(f'Removed {drops["N"]} SNPs with N < {n_min}.')

      return gwas

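For context on the strand-ambiguity step above: filter_alleles (defined outside these hunks) keeps biallelic SNPs whose A1/A2 pair cannot be confused across strands. A self-contained sketch of the conventional rule, assuming gsMap follows the usual LDSC-style filter (its actual implementation may differ):

    import pandas as pd

    # A/T and C/G pairs read identically on both strands, so they are strand-ambiguous
    # and excluded; so are non-SNP allele codes such as indels.
    VALID_PAIRS = {'AC', 'AG', 'CA', 'CT', 'GA', 'GT', 'TC', 'TG'}

    def filter_alleles_sketch(alleles: pd.Series) -> pd.Series:
        """Boolean mask keeping unambiguous SNPs; `alleles` is A1 + A2, e.g. 'AG'."""
        return alleles.isin(VALID_PAIRS)

    # Usage mirroring the hunk above:
    # gwas = gwas.loc[filter_alleles_sketch(gwas.A1 + gwas.A2)]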
@@ -302,7 +302,7 @@ def variant_to_rsid(gwas, config):
      '''
      Convert variant id (Chr, Pos) to rsid
      '''
-     print("\nConverting the SNP position to rsid. This process may take some time.")
+     logger.info("\nConverting the SNP position to rsid. This process may take some time.")
      unique_ids = set(gwas['id'])
      chr_format = gwas['Chr'].unique().astype(str)
      chr_format = [re.sub(r'\d+', '', value) for value in chr_format][1]
@@ -347,7 +347,7 @@ def clean_SNP_id(gwas, config):
      gwas = gwas.loc[matching_id.id]
      gwas['SNP'] = matching_id.dbsnp
      num_fail = old - len(gwas)
-     print(f'Removed {num_fail} SNPs that did not convert to rsid.')
+     logger.info(f'Removed {num_fail} SNPs that did not convert to rsid.')

      return gwas

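variant_to_rsid and clean_SNP_id together map (Chr, Pos) coordinates to dbSNP rs numbers; the id and dbsnp columns in the hunk above come from such a lookup table. A simplified sketch of the idea with a hypothetical lookup DataFrame (the real implementation reads per-chromosome dbSNP reference files and differs in detail):

    import pandas as pd

    def positions_to_rsid(gwas: pd.DataFrame, dbsnp: pd.DataFrame) -> pd.DataFrame:
        """Attach rsids by joining on a 'Chr:Pos' key; unmatched variants are dropped.

        `dbsnp` is assumed to have columns ['id', 'dbsnp'],
        e.g. ('1:752566', 'rs3094315').
        """
        gwas = gwas.copy()
        gwas['id'] = gwas['Chr'].astype(str) + ':' + gwas['Pos'].astype(str)
        matched = gwas.merge(dbsnp, on='id', how='inner')
        matched['SNP'] = matched['dbsnp']
        return matched.drop(columns=['id', 'dbsnp'])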
@@ -356,27 +356,27 @@ def gwas_metadata(gwas, config):
      '''
      Report key features of GWAS data
      '''
-     print('\nMetadata:')
+     logger.info('\nSummary of GWAS data:')
      CHISQ = (gwas.Z ** 2)
      mean_chisq = CHISQ.mean()
-     print('Mean chi^2 = ' + str(round(mean_chisq, 3)))
+     logger.info('Mean chi^2 = ' + str(round(mean_chisq, 3)))
      if mean_chisq < 1.02:
          logger.warning("Mean chi^2 may be too small.")

-     print('Lambda GC = ' + str(round(CHISQ.median() / 0.4549, 3)))
-     print('Max chi^2 = ' + str(round(CHISQ.max(), 3)))
-     print('{N} Genome-wide significant SNPs (some may have been removed by filtering).'.format(N=(CHISQ > 29).sum()))
+     logger.info('Lambda GC = ' + str(round(CHISQ.median() / 0.4549, 3)))
+     logger.info('Max chi^2 = ' + str(round(CHISQ.max(), 3)))
+     logger.info('{N} Genome-wide significant SNPs (some may have been removed by filtering).'.format(N=(CHISQ > 29).sum()))


  def gwas_format(config: FormatSumstatsConfig):
      '''
      Format GWAS data
      '''
-     print(f'------Formating gwas data for {config.sumstats}...')
+     logger.info(f'------Formating gwas data for {config.sumstats}...')
      compression_type = get_compression(config.sumstats)
      gwas = pd.read_csv(config.sumstats, delim_whitespace=True, header=0, compression=compression_type,
                         na_values=['.', 'NA'])
-     print(f'Read {len(gwas)} SNPs from {config.sumstats}.')
+     logger.info(f'Read {len(gwas)} SNPs from {config.sumstats}.')

      # Check name and format
      gwas = gwas_checkname(gwas, config)
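Two constants in the gwas_metadata hunk are worth unpacking. Lambda GC divides the median observed chi-square by the median of a chi-square distribution with 1 degree of freedom, which is about 0.4549; values well above 1 suggest genomic inflation. The CHISQ > 29 cutoff approximates the genome-wide significance threshold, since p = 5e-8 corresponds to a chi-square of roughly 29.7. A quick check of the arithmetic:

    from scipy.stats import chi2

    print(chi2.ppf(0.5, df=1))   # ~0.4549: median of chi-square(1), the Lambda GC denominator
    print(chi2.isf(5e-8, df=1))  # ~29.7: so CHISQ > 29 roughly marks genome-wide significance

    # Lambda GC for an array of Z-scores z would be:
    # lambda_gc = np.median(z ** 2) / chi2.ppf(0.5, df=1)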
@@ -402,6 +402,6 @@ def gwas_format(config: FormatSumstatsConfig):
      gwas = gwas[keep]
      out_name = config.out + appendix + '.gz'

-     print(f'\nWriting summary statistics for {len(gwas)} SNPs to {out_name}.')
+     logger.info(f'\nWriting summary statistics for {len(gwas)} SNPs to {out_name}.')
      gwas.to_csv(out_name, sep="\t", index=False,
                  float_format='%.3f', compression='gzip')
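Finally, a hypothetical invocation of the formatter. Only the sumstats and out fields appear in the hunks above; any other required FormatSumstatsConfig fields are elided here, so treat this as an illustration rather than a complete call:

    from gsMap.config import FormatSumstatsConfig
    from gsMap.format_sumstats import gwas_format

    config = FormatSumstatsConfig(
        sumstats='my_gwas.txt.gz',  # hypothetical input; whitespace-delimited, optionally gzipped
        out='my_gwas_formatted',    # output prefix; gwas_format appends a suffix plus '.gz'
        # ...remaining fields omitted; see gsMap.config for the full set
    )
    gwas_format(config)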