gsMap 1.71.2__py3-none-any.whl → 1.72.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,13 +1,15 @@
1
1
  import logging
2
2
  import random
3
+
3
4
  import numpy as np
4
5
  import scanpy as sc
5
6
  import torch
6
7
  from sklearn.decomposition import PCA
7
8
  from sklearn.preprocessing import LabelEncoder
9
+
10
+ from gsMap.config import FindLatentRepresentationsConfig
8
11
  from gsMap.GNN.adjacency_matrix import construct_adjacency_matrix
9
12
  from gsMap.GNN.train import ModelTrainer
10
- from gsMap.config import FindLatentRepresentationsConfig
11
13
 
12
14
  logger = logging.getLogger(__name__)
13
15
 
@@ -20,41 +22,40 @@ def set_seed(seed_value):
20
22
  np.random.seed(seed_value)
21
23
  random.seed(seed_value)
22
24
  if torch.cuda.is_available():
23
- logger.info('Using GPU for computations.')
25
+ logger.info("Using GPU for computations.")
24
26
  torch.cuda.manual_seed(seed_value)
25
27
  torch.cuda.manual_seed_all(seed_value)
26
28
  else:
27
- logger.info('Using CPU for computations.')
29
+ logger.info("Using CPU for computations.")
30
+
28
31
 
29
32
  def preprocess_data(adata, params):
30
33
  """
31
34
  Preprocess the AnnData
32
35
  """
33
- logger.info('Preprocessing data...')
36
+ logger.info("Preprocessing data...")
34
37
  adata.var_names_make_unique()
35
38
 
36
39
  if params.data_layer in adata.layers.keys():
37
- logger.info(f'Using data layer: {params.data_layer}...')
40
+ logger.info(f"Using data layer: {params.data_layer}...")
38
41
  adata.X = adata.layers[params.data_layer]
39
- sc.pp.filter_genes(adata, min_cells=30)
40
- elif params.data_layer == 'X':
41
- logger.info(f'Using data layer: {params.data_layer}...')
42
- if adata.X.dtype == 'float32' or adata.X.dtype == 'float64':
43
- logger.warning(f'The data layer should be raw count data')
44
- sc.pp.filter_genes(adata, min_cells=30)
42
+ elif params.data_layer == "X":
43
+ logger.info(f"Using data layer: {params.data_layer}...")
44
+ if adata.X.dtype == "float32" or adata.X.dtype == "float64":
45
+ logger.warning("The data layer should be raw count data")
45
46
  else:
46
- raise ValueError(f'Invalid data layer: {params.data_layer}, please check the input data.')
47
+ raise ValueError(f"Invalid data layer: {params.data_layer}, please check the input data.")
47
48
 
48
- if params.data_layer in ['count', 'counts', 'X']:
49
+ if params.data_layer in ["count", "counts", "X"]:
49
50
  # HVGs based on count
50
- logger.info('Dealing with count data...')
51
- sc.pp.highly_variable_genes(adata,flavor="seurat_v3",n_top_genes=params.feat_cell)
51
+ logger.info("Dealing with count data...")
52
+ sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=params.feat_cell)
52
53
  # Normalize the data
53
54
  sc.pp.normalize_total(adata, target_sum=1e4)
54
55
  sc.pp.log1p(adata)
55
56
 
56
57
  elif params.data_layer in adata.layers.keys():
57
- sc.pp.highly_variable_genes(adata,flavor="seurat",n_top_genes=params.feat_cell)
58
+ sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=params.feat_cell)
58
59
 
59
60
  return adata
60
61
 
@@ -70,11 +71,12 @@ class LatentRepresentationFinder:
70
71
  self.graph_dict = construct_adjacency_matrix(adata, self.params)
71
72
 
72
73
  def compute_pca(self):
73
- self.latent_pca = PCA(n_components=self.params.n_comps).fit_transform(self.expression_array)
74
+ self.latent_pca = PCA(n_components=self.params.n_comps).fit_transform(
75
+ self.expression_array
76
+ )
74
77
  return self.latent_pca
75
78
 
76
- def run_gnn_vae(self, label, verbose='whole ST data'):
77
-
79
+ def run_gnn_vae(self, label, verbose="whole ST data"):
78
80
  # Use PCA if specified
79
81
  if self.params.input_pca:
80
82
  node_X = self.compute_pca()
@@ -86,7 +88,7 @@ class LatentRepresentationFinder:
86
88
  self.params.feat_cell = node_X.shape[1]
87
89
 
88
90
  # Run GNN
89
- logger.info(f'Finding latent representations for {verbose}...')
91
+ logger.info(f"Finding latent representations for {verbose}...")
90
92
  gvae = ModelTrainer(node_X, self.graph_dict, self.params, label)
91
93
  gvae.run_train()
92
94
 
@@ -99,9 +101,9 @@ def run_find_latent_representation(args: FindLatentRepresentationsConfig):
99
101
  set_seed(2024)
100
102
 
101
103
  # Load the ST data
102
- logger.info(f'Loading ST data of {args.sample_name}...')
104
+ logger.info(f"Loading ST data of {args.sample_name}...")
103
105
  adata = sc.read_h5ad(args.input_hdf5_path)
104
- logger.info(f'The ST data contains {adata.shape[0]} cells, {adata.shape[1]} genes.')
106
+ logger.info(f"The ST data contains {adata.shape[0]} cells, {adata.shape[1]} genes.")
105
107
 
106
108
  # Load the cell type annotation
107
109
  if args.annotation is not None:
@@ -112,7 +114,7 @@ def run_find_latent_representation(args: FindLatentRepresentationsConfig):
112
114
  adata = adata[adata.obs[args.annotation].isin(valid_annotations)]
113
115
 
114
116
  le = LabelEncoder()
115
- label = le.fit_transform(adata.obs[args.annotation])
117
+ label = le.fit_transform(adata.obs[args.annotation])
116
118
  else:
117
119
  label = None
118
120
 
@@ -124,16 +126,16 @@ def run_find_latent_representation(args: FindLatentRepresentationsConfig):
124
126
  latent_pca = latent_rep.latent_pca
125
127
 
126
128
  # Add latent representations to the AnnData object
127
- logger.info('Adding latent representations...')
129
+ logger.info("Adding latent representations...")
128
130
  adata.obsm["latent_GVAE"] = latent_gvae
129
131
  adata.obsm["latent_PCA"] = latent_pca
130
132
 
131
133
  # Run UMAP based on latent representations
132
- #for name in ['latent_GVAE', 'latent_PCA']:
134
+ # for name in ['latent_GVAE', 'latent_PCA']:
133
135
  # sc.pp.neighbors(adata, n_neighbors=10, use_rep=name)
134
136
  # sc.tl.umap(adata)
135
137
  # adata.obsm['X_umap_' + name] = adata.obsm['X_umap']
136
138
 
137
139
  # Save the AnnData object
138
- logger.info('Saving ST data...')
140
+ logger.info("Saving ST data...")
139
141
  adata.write(args.hdf5_with_latent_path)