gsMap 1.71.1__py3-none-any.whl → 1.71.2__py3-none-any.whl

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -1,133 +1,139 @@
- import logging
- import random
- import numpy as np
- import scanpy as sc
- import torch
- from sklearn.decomposition import PCA
- from sklearn.preprocessing import LabelEncoder
- from gsMap.GNN.adjacency_matrix import construct_adjacency_matrix
- from gsMap.GNN.train import ModelTrainer
- from gsMap.config import FindLatentRepresentationsConfig
-
- logger = logging.getLogger(__name__)
-
-
- def set_seed(seed_value):
-     """
-     Set seed for reproducibility in PyTorch and other libraries.
-     """
-     torch.manual_seed(seed_value)
-     np.random.seed(seed_value)
-     random.seed(seed_value)
-     if torch.cuda.is_available():
-         logger.info('Using GPU for computations.')
-         torch.cuda.manual_seed(seed_value)
-         torch.cuda.manual_seed_all(seed_value)
-     else:
-         logger.info('Using CPU for computations.')
-
- def preprocess_data(adata, params):
-     """
-     Preprocess the AnnData
-     """
-     logger.info('Preprocessing data...')
-     adata.var_names_make_unique()
-
-     sc.pp.filter_genes(adata, min_cells=30)
-     if params.data_layer in adata.layers.keys():
-         logger.info(f'Using data layer: {params.data_layer}...')
-         adata.X = adata.layers[params.data_layer]
-     else:
-         raise ValueError(f'Invalid data layer: {params.data_layer}, please check the input data.')
-
-     if params.data_layer in ['count', 'counts']:
-         # HVGs based on count
-         sc.pp.highly_variable_genes(adata,flavor="seurat_v3",n_top_genes=params.feat_cell)
-         # Normalize the data
-         sc.pp.normalize_total(adata, target_sum=1e4)
-         sc.pp.log1p(adata)
-
-     elif params.data_layer in adata.layers.keys():
-         sc.pp.highly_variable_genes(adata,flavor="seurat",n_top_genes=params.feat_cell)
-
-     return adata
-
-
- class LatentRepresentationFinder:
-     def __init__(self, adata, args: FindLatentRepresentationsConfig):
-         self.params = args
-
-         self.expression_array = adata[:, adata.var.highly_variable].X.copy()
-         self.expression_array = sc.pp.scale(self.expression_array, max_value=10)
-
-         # Construct the neighboring graph
-         self.graph_dict = construct_adjacency_matrix(adata, self.params)
-
-     def compute_pca(self):
-         self.latent_pca = PCA(n_components=self.params.n_comps).fit_transform(self.expression_array)
-         return self.latent_pca
-
-     def run_gnn_vae(self, label, verbose='whole ST data'):
-
-         # Use PCA if specified
-         if self.params.input_pca:
-             node_X = self.compute_pca()
-         else:
-             node_X = self.expression_array
-
-         # Update the input shape
-         self.params.n_nodes = node_X.shape[0]
-         self.params.feat_cell = node_X.shape[1]
-
-         # Run GNN
-         logger.info(f'Finding latent representations for {verbose}...')
-         gvae = ModelTrainer(node_X, self.graph_dict, self.params, label)
-         gvae.run_train()
-
-         del self.graph_dict
-
-         return gvae.get_latent()
-
-
- def run_find_latent_representation(args: FindLatentRepresentationsConfig):
-     set_seed(2024)
-
-     # Load the ST data
-     logger.info(f'Loading ST data of {args.sample_name}...')
-     adata = sc.read_h5ad(args.input_hdf5_path)
-     logger.info(f'The ST data contains {adata.shape[0]} cells, {adata.shape[1]} genes.')
-
-     # Load the cell type annotation
-     if args.annotation is not None:
-         # Remove cells without enough annotations
-         adata = adata[~adata.obs[args.annotation].isnull()]
-         num = adata.obs[args.annotation].value_counts()
-         valid_annotations = num[num >= 30].index.to_list()
-         adata = adata[adata.obs[args.annotation].isin(valid_annotations)]
-
-         le = LabelEncoder()
-         label = le.fit_transform(adata.obs[args.annotation])
-     else:
-         label = None
-
-     # Preprocess data
-     adata = preprocess_data(adata, args)
-
-     latent_rep = LatentRepresentationFinder(adata, args)
-     latent_gvae = latent_rep.run_gnn_vae(label)
-     latent_pca = latent_rep.latent_pca
-
-     # Add latent representations to the AnnData object
-     logger.info('Adding latent representations...')
-     adata.obsm["latent_GVAE"] = latent_gvae
-     adata.obsm["latent_PCA"] = latent_pca
-
-     # Run UMAP based on latent representations
-     #for name in ['latent_GVAE', 'latent_PCA']:
-     #    sc.pp.neighbors(adata, n_neighbors=10, use_rep=name)
-     #    sc.tl.umap(adata)
-     #    adata.obsm['X_umap_' + name] = adata.obsm['X_umap']
-
-     # Save the AnnData object
-     logger.info('Saving ST data...')
-     adata.write(args.hdf5_with_latent_path)
+ import logging
+ import random
+ import numpy as np
+ import scanpy as sc
+ import torch
+ from sklearn.decomposition import PCA
+ from sklearn.preprocessing import LabelEncoder
+ from gsMap.GNN.adjacency_matrix import construct_adjacency_matrix
+ from gsMap.GNN.train import ModelTrainer
+ from gsMap.config import FindLatentRepresentationsConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ def set_seed(seed_value):
+     """
+     Set seed for reproducibility in PyTorch and other libraries.
+     """
+     torch.manual_seed(seed_value)
+     np.random.seed(seed_value)
+     random.seed(seed_value)
+     if torch.cuda.is_available():
+         logger.info('Using GPU for computations.')
+         torch.cuda.manual_seed(seed_value)
+         torch.cuda.manual_seed_all(seed_value)
+     else:
+         logger.info('Using CPU for computations.')
+
+ def preprocess_data(adata, params):
+     """
+     Preprocess the AnnData
+     """
+     logger.info('Preprocessing data...')
+     adata.var_names_make_unique()
+
+     if params.data_layer in adata.layers.keys():
+         logger.info(f'Using data layer: {params.data_layer}...')
+         adata.X = adata.layers[params.data_layer]
+         sc.pp.filter_genes(adata, min_cells=30)
+     elif params.data_layer == 'X':
+         logger.info(f'Using data layer: {params.data_layer}...')
+         if adata.X.dtype == 'float32' or adata.X.dtype == 'float64':
+             logger.warning(f'The data layer should be raw count data')
+         sc.pp.filter_genes(adata, min_cells=30)
+     else:
+         raise ValueError(f'Invalid data layer: {params.data_layer}, please check the input data.')
+
+     if params.data_layer in ['count', 'counts', 'X']:
+         # HVGs based on count
+         logger.info('Dealing with count data...')
+         sc.pp.highly_variable_genes(adata,flavor="seurat_v3",n_top_genes=params.feat_cell)
+         # Normalize the data
+         sc.pp.normalize_total(adata, target_sum=1e4)
+         sc.pp.log1p(adata)
+
+     elif params.data_layer in adata.layers.keys():
+         sc.pp.highly_variable_genes(adata,flavor="seurat",n_top_genes=params.feat_cell)
+
+     return adata
+
+
+ class LatentRepresentationFinder:
+     def __init__(self, adata, args: FindLatentRepresentationsConfig):
+         self.params = args
+
+         self.expression_array = adata[:, adata.var.highly_variable].X.copy()
+         self.expression_array = sc.pp.scale(self.expression_array, max_value=10)
+
+         # Construct the neighboring graph
+         self.graph_dict = construct_adjacency_matrix(adata, self.params)
+
+     def compute_pca(self):
+         self.latent_pca = PCA(n_components=self.params.n_comps).fit_transform(self.expression_array)
+         return self.latent_pca
+
+     def run_gnn_vae(self, label, verbose='whole ST data'):
+
+         # Use PCA if specified
+         if self.params.input_pca:
+             node_X = self.compute_pca()
+         else:
+             node_X = self.expression_array
+
+         # Update the input shape
+         self.params.n_nodes = node_X.shape[0]
+         self.params.feat_cell = node_X.shape[1]
+
+         # Run GNN
+         logger.info(f'Finding latent representations for {verbose}...')
+         gvae = ModelTrainer(node_X, self.graph_dict, self.params, label)
+         gvae.run_train()
+
+         del self.graph_dict
+
+         return gvae.get_latent()
+
+
+ def run_find_latent_representation(args: FindLatentRepresentationsConfig):
+     set_seed(2024)
+
+     # Load the ST data
+     logger.info(f'Loading ST data of {args.sample_name}...')
+     adata = sc.read_h5ad(args.input_hdf5_path)
+     logger.info(f'The ST data contains {adata.shape[0]} cells, {adata.shape[1]} genes.')
+
+     # Load the cell type annotation
+     if args.annotation is not None:
+         # Remove cells without enough annotations
+         adata = adata[~adata.obs[args.annotation].isnull()]
+         num = adata.obs[args.annotation].value_counts()
+         valid_annotations = num[num >= 30].index.to_list()
+         adata = adata[adata.obs[args.annotation].isin(valid_annotations)]
+
+         le = LabelEncoder()
+         label = le.fit_transform(adata.obs[args.annotation])
+     else:
+         label = None
+
+     # Preprocess data
+     adata = preprocess_data(adata, args)
+
+     latent_rep = LatentRepresentationFinder(adata, args)
+     latent_gvae = latent_rep.run_gnn_vae(label)
+     latent_pca = latent_rep.latent_pca
+
+     # Add latent representations to the AnnData object
+     logger.info('Adding latent representations...')
+     adata.obsm["latent_GVAE"] = latent_gvae
+     adata.obsm["latent_PCA"] = latent_pca
+
+     # Run UMAP based on latent representations
+     #for name in ['latent_GVAE', 'latent_PCA']:
+     #    sc.pp.neighbors(adata, n_neighbors=10, use_rep=name)
+     #    sc.tl.umap(adata)
+     #    adata.obsm['X_umap_' + name] = adata.obsm['X_umap']
+
+     # Save the AnnData object
+     logger.info('Saving ST data...')
+     adata.write(args.hdf5_with_latent_path)
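
The functional change in this hunk is confined to preprocess_data: gene filtering moves inside the layer-selection branches, and a new data_layer == 'X' branch lets adata.X be used directly, logging a warning (instead of raising, as 1.71.1 did for any value not present in adata.layers) when the matrix is float-typed rather than raw counts; 'X' is then grouped with 'count'/'counts' for HVG selection, normalization, and log1p. Below is a minimal, self-contained sketch of that new path on a toy AnnData. It uses plain scanpy/anndata calls, not gsMap's API, and the names data_layer and n_top_genes are stand-ins for the config fields params.data_layer and params.feat_cell.

# Sketch only, assuming raw counts are stored in adata.X (the new 'X' path).
import numpy as np
import scanpy as sc
import anndata as ad

rng = np.random.default_rng(0)
# Toy ST matrix: 100 cells x 500 genes of Poisson counts, stored as float32
# so the new dtype warning path is exercised.
adata = ad.AnnData(X=rng.poisson(1.0, size=(100, 500)).astype(np.float32))
adata.var_names = [f"gene_{i}" for i in range(adata.n_vars)]

data_layer = 'X'     # stand-in for params.data_layer
n_top_genes = 200    # stand-in for params.feat_cell

if data_layer == 'X':
    # 1.71.2 accepts 'X' and only warns when adata.X is float-typed,
    # then filters genes just like the layer-based branch.
    if adata.X.dtype == 'float32' or adata.X.dtype == 'float64':
        print('warning: the data layer should be raw count data')
    sc.pp.filter_genes(adata, min_cells=30)

# 'X' is handled like 'count'/'counts': HVGs on counts (flavor="seurat_v3"
# requires the scikit-misc package), then total-count normalization and log1p.
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=n_top_genes)
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
print(int(adata.var.highly_variable.sum()), 'highly variable genes selected')

In the released pipeline these steps run inside run_find_latent_representation before the GNN-VAE step, so the preprocessed AnnData then feeds LatentRepresentationFinder unchanged.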