gsMap 1.66__py3-none-any.whl → 1.70__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,145 +1,133 @@
1
- import logging
2
- import random
3
-
4
- import numpy as np
5
- import pandas as pd
6
- import scanpy as sc
7
- import torch
8
- from sklearn import preprocessing
9
-
10
- from gsMap.GNN_VAE.adjacency_matrix import Construct_Adjacency_Matrix
11
- from gsMap.GNN_VAE.train import Model_Train
12
- from gsMap.config import FindLatentRepresentationsConfig
13
-
14
- logger = logging.getLogger(__name__)
15
-
16
def set_seed(seed_value):
    """Seed PyTorch, NumPy and Python's ``random`` module with ``seed_value``.

    When CUDA is available the CUDA generators are seeded as well; the device
    type that will be used is logged either way.
    """
    torch.manual_seed(seed_value)
    np.random.seed(seed_value)
    random.seed(seed_value)

    use_cuda = torch.cuda.is_available()
    if not use_cuda:
        logger.info('Running use CPU')
        return

    logger.info('Running use GPU')
    # Seed the current CUDA device, then every visible CUDA device.
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
29
-
30
-
31
# The class for finding latent representations
class Latent_Representation_Finder:
    """Preprocess a copy of an AnnData object and derive latent embeddings from it.

    ``__init__`` selects the expression matrix named by ``args.data_layer`` and
    flags ``args.feat_cell`` highly variable genes; count layers are also
    normalised, log1p-transformed and scaled. ``Run_GNN_VAE`` and ``Run_PCA``
    then compute embeddings from the highly variable genes.
    """

    def __init__(self, adata, args: FindLatentRepresentationsConfig):
        # Work on a copy so the caller's AnnData is left untouched.
        self.adata = adata.copy()
        self.Params = args

        # Standard process
        if self.Params.data_layer in ('count', 'counts'):
            # seurat_v3 HVG selection expects raw counts, so it runs before normalisation.
            self.adata.X = self.adata.layers[self.Params.data_layer]
            sc.pp.highly_variable_genes(self.adata, flavor="seurat_v3", n_top_genes=self.Params.feat_cell)
            sc.pp.normalize_total(self.adata, target_sum=1e4)
            sc.pp.log1p(self.adata)
            sc.pp.scale(self.adata)
        else:
            # 'X' means "use adata.X as-is"; any other name selects that layer.
            if self.Params.data_layer != 'X':
                self.adata.X = self.adata.layers[self.Params.data_layer]
            sc.pp.highly_variable_genes(self.adata, n_top_genes=self.Params.feat_cell)

    def Run_GNN_VAE(self, label, verbose='whole ST data'):
        """Train the GNN-VAE on the highly variable genes and return its latent embedding.

        Args:
            label: per-cell labels passed through to ``Model_Train``, or None.
            verbose: description of the data, used only in the progress log message.

        Note:
            Overwrites ``self.Params.n_nodes`` and ``self.Params.feat_cell`` on the
            shared config object with the shape of the model input; callers that
            reuse the config afterwards must restore ``feat_cell`` themselves.
        """
        # Construct the neighbouring graph
        graph_dict = Construct_Adjacency_Matrix(self.adata, self.Params)

        # Process the feature matrix
        node_X = self.adata[:, self.adata.var.highly_variable].X
        logger.info(f'The shape of feature matrix is {node_X.shape}.')
        if self.Params.input_pca:
            node_X = sc.pp.pca(node_X, n_comps=self.Params.n_comps)

        # Update the input shape
        self.Params.n_nodes = node_X.shape[0]
        self.Params.feat_cell = node_X.shape[1]

        # Run GNN-VAE
        logger.info(f'------Finding latent representations for {verbose}...')
        gvae = Model_Train(node_X, graph_dict, self.Params, label)
        gvae.run_train()

        return gvae.get_latent()

    def Run_PCA(self):
        """Return up to ``Params.n_comps`` principal components of the highly variable genes.

        Fewer columns are returned when the data has too few cells or highly
        variable genes to support ``Params.n_comps`` components.
        """
        # Without an explicit n_comps, scanpy computes at most 50 components, so
        # larger Params.n_comps values were silently truncated by the slice below.
        # Clamp to what the data supports (scanpy's default arpack solver needs
        # n_comps < min(n_obs, n_features of the HVG subset)).
        n_hvg = int(self.adata.var['highly_variable'].sum())
        n_comps = min(self.Params.n_comps, self.adata.n_obs - 1, n_hvg - 1)
        sc.tl.pca(self.adata, n_comps=n_comps)
        return self.adata.obsm['X_pca'][:, 0:self.Params.n_comps]
75
-
76
-
77
def run_find_latent_representation(args: FindLatentRepresentationsConfig):
    """Compute GNN-VAE and PCA latent representations and save them to disk.

    Loads ``args.input_hdf5_path``. When ``args.annotation`` is set, it drops
    cells whose annotation is missing or occurs in fewer than 30 cells, and it
    passes the label-encoded annotation to training. Stores ``latent_GVAE``,
    ``latent_PCA`` and a UMAP of each in ``adata.obsm``. With
    ``args.hierarchically`` it also computes per-annotation embeddings. The
    result is written to ``args.hdf5_with_latent_path``.

    ``Run_GNN_VAE`` overwrites ``args.feat_cell`` with the model input width,
    so the configured value is restored after every training run. Without
    this, per-annotation HVG selection would use the wrong gene count and the
    caller would get back a modified config.
    """
    set_seed(2024)
    num_features = args.feat_cell
    args.hdf5_with_latent_path.parent.mkdir(parents=True, exist_ok=True, mode=0o755)

    # Load the ST data
    logger.info(f'------Loading ST data of {args.sample_name}...')
    adata = sc.read_h5ad(f'{args.input_hdf5_path}')
    adata.var_names_make_unique()
    adata.X = adata.layers[args.data_layer] if args.data_layer in adata.layers.keys() else adata.X
    logger.info('The ST data contains %d cells, %d genes.' % (adata.shape[0], adata.shape[1]))

    # Load the cell type annotation
    if args.annotation is not None:
        # Remove cells without an annotation or in annotation groups with < 30 cells.
        adata = adata[~pd.isnull(adata.obs[args.annotation]), :]
        num = adata.obs[args.annotation].value_counts()
        # .copy() turns the filtered view into a real AnnData before obs/obsm are written.
        adata = adata[adata.obs[args.annotation].isin(num[num >= 30].index.to_list())].copy()

        le = preprocessing.LabelEncoder()
        adata.obs['categorical_label'] = le.fit_transform(adata.obs[args.annotation])
        label = adata.obs['categorical_label'].to_list()
    else:
        label = None

    # Find latent representations
    latent_rep = Latent_Representation_Finder(adata, args)
    latent_GVAE = latent_rep.Run_GNN_VAE(label)
    args.feat_cell = num_features  # undo the overwrite done by Run_GNN_VAE
    latent_PCA = latent_rep.Run_PCA()

    # Add latent representations to the spe data
    logger.info('------Adding latent representations...')
    adata.obsm["latent_GVAE"] = latent_GVAE
    adata.obsm["latent_PCA"] = latent_PCA

    # Run umap based on latent representations
    for name in ['latent_GVAE', 'latent_PCA']:
        sc.pp.neighbors(adata, n_neighbors=10, use_rep=name)
        sc.tl.umap(adata)
        adata.obsm['X_umap_' + name] = adata.obsm['X_umap']

    # Find the latent representations hierarchically (optionally)
    if args.annotation is not None and args.hierarchically:
        logger.info('------Finding latent representations hierarchically...')
        gvae_parts = []
        pca_parts = []

        for ct in adata.obs[args.annotation].unique():
            adata_part = adata[adata.obs[args.annotation] == ct, :]
            logger.info(adata_part.shape)

            # Find latent representations for the selected ct
            latent_rep = Latent_Representation_Finder(adata_part, args)

            latent_PCA_part = pd.DataFrame(latent_rep.Run_PCA())
            if adata_part.shape[0] <= args.n_comps:
                # Too few cells to train on; fall back to the PCA embedding.
                latent_GVAE_part = latent_PCA_part
            else:
                latent_GVAE_part = pd.DataFrame(latent_rep.Run_GNN_VAE(label=None, verbose=ct))
                args.feat_cell = num_features  # undo the overwrite done by Run_GNN_VAE

            latent_GVAE_part.index = adata_part.obs_names
            latent_PCA_part.index = adata_part.obs_names

            gvae_parts.append(latent_GVAE_part)
            pca_parts.append(latent_PCA_part)

        # Concatenate once (instead of inside the loop) and restore the cell order of adata.
        GVAE_all = pd.concat(gvae_parts, axis=0)
        PCA_all = pd.concat(pca_parts, axis=0)
        adata.obsm["latent_GVAE_hierarchy"] = np.array(GVAE_all.loc[adata.obs_names,])
        adata.obsm["latent_PCA_hierarchy"] = np.array(PCA_all.loc[adata.obs_names,])

    logger.info('------Saving ST data...')
    adata.write(args.hdf5_with_latent_path)
145
-
1
+ import logging
2
+ import random
3
+ import numpy as np
4
+ import scanpy as sc
5
+ import torch
6
+ from sklearn.decomposition import PCA
7
+ from sklearn.preprocessing import LabelEncoder
8
+ from gsMap.GNN.adjacency_matrix import construct_adjacency_matrix
9
+ from gsMap.GNN.train import ModelTrainer
10
+ from gsMap.config import FindLatentRepresentationsConfig
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
def set_seed(seed_value):
    """Make runs reproducible by seeding every random number generator in use.

    Seeds PyTorch, NumPy and Python's ``random`` module, plus the CUDA
    generators when CUDA is available, and logs which device type will be used.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed_value)

    if not torch.cuda.is_available():
        logger.info('Using CPU for computations.')
        return

    logger.info('Using GPU for computations.')
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
28
+
29
def preprocess_data(adata, params):
    """
    Preprocess the AnnData object in place and return it.

    Steps:
      1. Make gene names unique.
      2. Put the matrix named by ``params.data_layer`` into ``adata.X``;
         ``'X'`` (when no layer of that name exists) keeps the current matrix.
      3. Drop genes detected in fewer than 30 cells of that matrix.
      4. Flag ``params.feat_cell`` highly variable genes. Count layers
         (``'count'``/``'counts'``) use the seurat_v3 flavor on the raw counts and
         are then library-size normalised (1e4) and log1p-transformed; any other
         layer uses the seurat flavor directly.

    Raises:
        ValueError: if ``params.data_layer`` is neither ``'X'`` nor a key of
            ``adata.layers``.
    """
    logger.info('Preprocessing data...')
    adata.var_names_make_unique()

    if params.data_layer in adata.layers.keys():
        logger.info(f'Using data layer: {params.data_layer}...')
        adata.X = adata.layers[params.data_layer]
    elif params.data_layer != 'X':
        # 'X' was accepted by earlier versions to mean "use adata.X as-is".
        raise ValueError(f'Invalid data layer: {params.data_layer}, please check the input data.')

    # Filter after selecting the layer so the min_cells criterion is evaluated
    # on the matrix that is actually analysed, not on whatever X held before.
    sc.pp.filter_genes(adata, min_cells=30)

    if params.data_layer in ['count', 'counts']:
        # HVGs based on count
        sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=params.feat_cell)
        # Normalize the data
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
    else:
        sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=params.feat_cell)

    return adata
54
+
55
+
56
class LatentRepresentationFinder:
    """Build the model inputs for one AnnData object and run the GNN on them.

    On construction, the highly variable genes of ``adata`` are extracted and
    scaled (``sc.pp.scale`` with ``max_value=10``) into ``expression_array``,
    and the neighbourhood graph is built with ``construct_adjacency_matrix``.
    """

    def __init__(self, adata, args: FindLatentRepresentationsConfig):
        self.params = args
        # Set by compute_pca(); stays None until PCA has been run.
        self.latent_pca = None

        self.expression_array = adata[:, adata.var.highly_variable].X.copy()
        self.expression_array = sc.pp.scale(self.expression_array, max_value=10)

        # Construct the neighboring graph
        self.graph_dict = construct_adjacency_matrix(adata, self.params)

    def compute_pca(self):
        """Project ``expression_array`` onto its first ``params.n_comps`` principal components.

        The result is stored on ``self.latent_pca`` and returned.
        """
        # n_comps is the configured PCA width; feat_cell is the HVG count, and
        # using it here produced a full-width rotation instead of a reduction.
        self.latent_pca = PCA(n_components=self.params.n_comps).fit_transform(self.expression_array)
        return self.latent_pca

    def run_gnn_vae(self, label, verbose='whole ST data'):
        """Train the GNN on the expression features and return its latent embedding.

        Args:
            label: per-cell labels passed through to ``ModelTrainer``, or None.
            verbose: description of the data, used only in the progress log message.

        Raises:
            RuntimeError: if called a second time (the graph is released after training).

        Note:
            Overwrites ``params.n_nodes`` and ``params.feat_cell`` on the shared
            config object with the shape of the model input.
        """
        if self.graph_dict is None:
            raise RuntimeError(
                'run_gnn_vae() can only be called once per LatentRepresentationFinder: '
                'the neighbourhood graph was released after the previous run.'
            )

        # Use PCA if specified
        if self.params.input_pca:
            node_X = self.compute_pca()
        else:
            node_X = self.expression_array

        # Update the input shape
        self.params.n_nodes = node_X.shape[0]
        self.params.feat_cell = node_X.shape[1]

        # Run GNN
        logger.info(f'Finding latent representations for {verbose}...')
        gvae = ModelTrainer(node_X, self.graph_dict, self.params, label)
        gvae.run_train()

        # Drop the reference to the graph once training is done (presumably to
        # free memory); the guard above reports any later call clearly.
        self.graph_dict = None

        return gvae.get_latent()
90
+
91
+
92
def run_find_latent_representation(args: FindLatentRepresentationsConfig):
    """Compute GNN and PCA latent representations for an ST dataset and save them.

    Loads ``args.input_hdf5_path``. When ``args.annotation`` is set, it drops
    cells whose annotation is missing or occurs in fewer than 30 cells, and it
    uses the label-encoded annotation for training. After preprocessing, it
    stores the embeddings in ``adata.obsm["latent_GVAE"]`` and
    ``adata.obsm["latent_PCA"]``. Finally it writes the AnnData to
    ``args.hdf5_with_latent_path``, creating the parent directory if needed.
    """
    from pathlib import Path

    set_seed(2024)
    # Create the output directory up front so a bad path fails before training.
    Path(args.hdf5_with_latent_path).parent.mkdir(parents=True, exist_ok=True, mode=0o755)

    # Load the ST data
    logger.info(f'Loading ST data of {args.sample_name}...')
    adata = sc.read_h5ad(args.input_hdf5_path)
    logger.info(f'The ST data contains {adata.shape[0]} cells, {adata.shape[1]} genes.')

    # Load the cell type annotation
    if args.annotation is not None:
        # Remove cells without enough annotations
        adata = adata[~adata.obs[args.annotation].isnull()]
        num = adata.obs[args.annotation].value_counts()
        valid_annotations = num[num >= 30].index.to_list()
        # Materialise the subset so the in-place preprocessing below works on a
        # real AnnData rather than a view.
        adata = adata[adata.obs[args.annotation].isin(valid_annotations)].copy()

        le = LabelEncoder()
        label = le.fit_transform(adata.obs[args.annotation])
    else:
        label = None

    # Preprocess data
    adata = preprocess_data(adata, args)

    latent_rep = LatentRepresentationFinder(adata, args)
    latent_gvae = latent_rep.run_gnn_vae(label)
    # PCA is only computed during training when args.input_pca is set; compute
    # it explicitly otherwise so latent_PCA is always available.
    latent_pca = getattr(latent_rep, 'latent_pca', None)
    if latent_pca is None:
        latent_pca = latent_rep.compute_pca()

    # Add latent representations to the AnnData object
    logger.info('Adding latent representations...')
    adata.obsm["latent_GVAE"] = latent_gvae
    adata.obsm["latent_PCA"] = latent_pca

    # Save the AnnData object
    logger.info('Saving ST data...')
    adata.write(args.hdf5_with_latent_path)