gsMap 1.67__py3-none-any.whl → 1.71__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,145 +1,133 @@
1
- import logging
2
- import random
3
- import numpy as np
4
- import scanpy as sc
5
- import torch
6
- from sklearn.decomposition import PCA
7
- from sklearn.preprocessing import LabelEncoder
8
- from gsMap.GNN_VAE.adjacency_matrix import construct_adjacency_matrix
9
- from gsMap.GNN_VAE.train import ModelTrainer
10
- from gsMap.config import FindLatentRepresentationsConfig
11
-
12
logger = logging.getLogger(__name__)


def set_seed(seed_value):
    """
    Seed every random number generator used by the pipeline.

    Seeds PyTorch, NumPy and Python's ``random`` module with ``seed_value``;
    when CUDA is available the CUDA generators are seeded as well. Also logs
    whether computations will run on the GPU or on the CPU.

    Args:
        seed_value: Seed passed to every generator.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed_value)

    gpu_available = torch.cuda.is_available()
    logger.info('Using GPU for computations.' if gpu_available else 'Using CPU for computations.')
    if gpu_available:
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
28
-
29
def preprocess_data(adata, params):
    """
    Preprocess the AnnData object in place before latent-representation learning.

    Steps:
      1. Make ``var_names`` unique and drop genes detected in fewer than 30 cells.
      2. Copy ``adata.layers[params.data_layer]`` into ``adata.X``.
      3. Select ``params.feat_cell`` highly variable genes:
         - raw counts (``'count'``/``'counts'``): ``flavor="seurat_v3"`` on the
           counts, then normalize to 1e4 per cell and ``log1p`` ``adata.X``;
         - any other layer: ``flavor="seurat"`` on the data as-is.

    Args:
        adata: AnnData object; modified in place.
        params: Config object providing ``data_layer`` and ``feat_cell``.

    Returns:
        The same (modified) AnnData object.

    Raises:
        ValueError: If ``params.data_layer`` is not a key of ``adata.layers``.
    """
    logger.info('Preprocessing data...')
    adata.var_names_make_unique()

    sc.pp.filter_genes(adata, min_cells=30)
    if params.data_layer not in adata.layers.keys():
        raise ValueError(f'Invalid data layer: {params.data_layer}, please check the input data.')
    # Copy: assigning a layer to .X does not copy it, so the in-place
    # normalize_total/log1p below would otherwise overwrite the stored layer.
    adata.X = adata.layers[params.data_layer].copy()

    if params.data_layer in ['count', 'counts']:
        # 'seurat_v3' expects raw counts, so HVGs must be selected BEFORE
        # normalization/log-transformation.
        sc.pp.highly_variable_genes(
            adata,
            flavor="seurat_v3",
            n_top_genes=params.feat_cell,
        )
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
    else:
        # The layer was validated above, so this is every non-count layer.
        logger.info(f'Using {params.data_layer} data...')
        sc.pp.highly_variable_genes(
            adata,
            flavor="seurat",
            n_top_genes=params.feat_cell,
        )

    return adata
63
-
64
-
65
class LatentRepresentationFinder:
    """
    Compute latent embeddings (graph VAE and PCA) for spatial transcriptomics data.

    The constructor keeps a copy of the highly-variable-gene expression
    matrix, scaled with ``sc.pp.scale(max_value=10)`` only when the input
    layer is raw counts. It also builds the neighbour graph with
    ``construct_adjacency_matrix``.
    """

    def __init__(self, adata, args: FindLatentRepresentationsConfig):
        self.params = args

        hvg_expression = adata[:, adata.var.highly_variable].X.copy()
        # Only raw-count input is standardized here.
        is_count_layer = self.params.data_layer in ['count', 'counts']
        self.expression_array = (
            sc.pp.scale(hvg_expression, max_value=10) if is_count_layer else hvg_expression
        )

        # Neighbour graph handed to the GNN trainer.
        self.graph_dict = construct_adjacency_matrix(adata, self.params)

    def compute_pca(self):
        """
        Project the expression matrix onto ``params.feat_cell`` principal components.

        The projection is stored on ``self.latent_pca`` and returned.
        """
        pca_model = PCA(n_components=self.params.feat_cell)
        self.latent_pca = pca_model.fit_transform(self.expression_array)
        return self.latent_pca

    def run_gnn_vae(self, label, verbose='whole ST data'):
        """
        Train the graph VAE and return its latent embedding.

        Node features are the PCA projection when ``params.input_pca`` is set,
        otherwise the expression matrix. ``params.n_nodes`` and
        ``params.feat_cell`` are overwritten with that matrix's shape before
        training. The neighbour graph is deleted afterwards, so a second call
        on the same instance fails.

        Args:
            label: Per-cell labels forwarded to ``ModelTrainer`` (or None).
            verbose: Description inserted into the progress log message.

        Returns:
            Whatever ``ModelTrainer.get_latent()`` returns.
        """
        node_X = self.compute_pca() if self.params.input_pca else self.expression_array

        # Presumably read by ModelTrainer to size its input layer — confirm.
        node_count = node_X.shape[0]
        feature_count = node_X.shape[1]
        self.params.n_nodes = node_count
        self.params.feat_cell = feature_count

        logger.info(f'Finding latent representations for {verbose}...')
        trainer = ModelTrainer(node_X, self.graph_dict, self.params, label)
        trainer.run_train()

        # Drop the graph once training is done (likely to release memory).
        del self.graph_dict

        return trainer.get_latent()
101
-
102
-
103
def run_find_latent_representation(args: FindLatentRepresentationsConfig):
    """
    Learn latent representations for an ST dataset and save them to disk.

    Pipeline:
      1. Seed all RNGs with 2024.
      2. Read ``args.input_hdf5_path``.
      3. If ``args.annotation`` is set, drop cells whose annotation is missing
         or whose category has fewer than 30 cells, and label-encode the
         remaining annotations (stored in ``obs['categorical_label']``).
      4. Preprocess, train the graph VAE and obtain the PCA embedding.
      5. Store them in ``obsm['latent_GVAE']`` / ``obsm['latent_PCA']``,
         compute a UMAP for each, and write the AnnData to
         ``args.hdf5_with_latent_path``.

    Args:
        args: Configuration for the find-latent-representations step.
    """
    set_seed(2024)

    # Load the ST data
    logger.info(f'Loading ST data of {args.sample_name}...')
    adata = sc.read_h5ad(args.input_hdf5_path)
    logger.info(f'The ST data contains {adata.shape[0]} cells, {adata.shape[1]} genes.')

    # Load the cell type annotation
    if args.annotation is not None:
        # Remove cells without enough annotations
        adata = adata[~adata.obs[args.annotation].isnull()]
        num = adata.obs[args.annotation].value_counts()
        valid_annotations = num[num >= 30].index.to_list()
        # Materialize the view: obs is written below and preprocessing
        # modifies the object in place.
        adata = adata[adata.obs[args.annotation].isin(valid_annotations)].copy()

        le = LabelEncoder()
        adata.obs['categorical_label'] = le.fit_transform(adata.obs[args.annotation])
        label = adata.obs['categorical_label'].to_numpy()
    else:
        label = None

    # Preprocess data
    adata = preprocess_data(adata, args)

    latent_rep = LatentRepresentationFinder(adata, args)
    latent_gvae = latent_rep.run_gnn_vae(label)
    # Reuse the PCA that was fed to the GNN (input_pca=True) instead of
    # refitting it after run_gnn_vae has overwritten params.feat_cell; only
    # compute it when it does not exist yet.
    latent_pca = getattr(latent_rep, 'latent_pca', None)
    if latent_pca is None:
        latent_pca = latent_rep.compute_pca()

    # Add latent representations to the AnnData object
    logger.info('Adding latent representations...')
    adata.obsm["latent_GVAE"] = latent_gvae
    adata.obsm["latent_PCA"] = latent_pca

    # Run UMAP based on latent representations
    for name in ['latent_GVAE', 'latent_PCA']:
        sc.pp.neighbors(adata, n_neighbors=10, use_rep=name)
        sc.tl.umap(adata)
        adata.obsm['X_umap_' + name] = adata.obsm['X_umap']

    # Save the AnnData object
    logger.info('Saving ST data...')
    adata.write(args.hdf5_with_latent_path)
1
+ import logging
2
+ import random
3
+ import numpy as np
4
+ import scanpy as sc
5
+ import torch
6
+ from sklearn.decomposition import PCA
7
+ from sklearn.preprocessing import LabelEncoder
8
+ from gsMap.GNN.adjacency_matrix import construct_adjacency_matrix
9
+ from gsMap.GNN.train import ModelTrainer
10
+ from gsMap.config import FindLatentRepresentationsConfig
11
+
12
logger = logging.getLogger(__name__)


def set_seed(seed_value):
    """
    Make runs reproducible by seeding all random number generators.

    PyTorch, NumPy and the ``random`` module are always seeded. The CUDA
    generators are seeded only when a GPU is available. The chosen device
    (GPU or CPU) is logged.

    Args:
        seed_value: Seed applied to every generator.
    """
    torch.manual_seed(seed_value)
    np.random.seed(seed_value)
    random.seed(seed_value)

    if not torch.cuda.is_available():
        logger.info('Using CPU for computations.')
        return

    logger.info('Using GPU for computations.')
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
28
+
29
def preprocess_data(adata, params):
    """
    Preprocess the AnnData object in place before latent-representation learning.

    Steps:
      1. Make ``var_names`` unique and drop genes detected in fewer than 30 cells.
      2. Copy ``adata.layers[params.data_layer]`` into ``adata.X``.
      3. Select ``params.feat_cell`` highly variable genes:
         - raw counts (``'count'``/``'counts'``): ``flavor="seurat_v3"`` on the
           counts, then normalize to 1e4 per cell and ``log1p`` ``adata.X``;
         - any other layer: ``flavor="seurat"`` on the data as-is.

    Args:
        adata: AnnData object; modified in place.
        params: Config object providing ``data_layer`` and ``feat_cell``.

    Returns:
        The same (modified) AnnData object.

    Raises:
        ValueError: If ``params.data_layer`` is not a key of ``adata.layers``.
    """
    logger.info('Preprocessing data...')
    adata.var_names_make_unique()

    sc.pp.filter_genes(adata, min_cells=30)
    if params.data_layer not in adata.layers.keys():
        raise ValueError(f'Invalid data layer: {params.data_layer}, please check the input data.')
    logger.info(f'Using data layer: {params.data_layer}...')
    # Copy: assigning a layer to .X does not copy it, so the in-place
    # normalize_total/log1p below would otherwise overwrite the raw layer
    # that is later saved together with the latent representations.
    adata.X = adata.layers[params.data_layer].copy()

    if params.data_layer in ['count', 'counts']:
        # 'seurat_v3' expects raw counts, so HVGs are chosen before normalizing.
        sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=params.feat_cell)
        # Normalize the data
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
    else:
        # The layer was validated above, so this is every non-count layer.
        sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=params.feat_cell)

    return adata
54
+
55
+
56
class LatentRepresentationFinder:
    """
    Compute latent embeddings (graph VAE and PCA) for spatial transcriptomics data.

    The constructor extracts the highly-variable-gene expression matrix,
    scales it with ``sc.pp.scale(max_value=10)``, and builds the neighbour
    graph with ``construct_adjacency_matrix``.

    ``latent_pca`` is computed lazily. Reading it before ``compute_pca`` has
    run (e.g. when ``params.input_pca`` is False) triggers the PCA instead of
    raising ``AttributeError``.
    """

    def __init__(self, adata, args: FindLatentRepresentationsConfig):
        self.params = args
        # Cached PCA embedding; filled by compute_pca().
        self._latent_pca = None

        self.expression_array = adata[:, adata.var.highly_variable].X.copy()
        self.expression_array = sc.pp.scale(self.expression_array, max_value=10)

        # Construct the neighboring graph
        self.graph_dict = construct_adjacency_matrix(adata, self.params)

    @property
    def latent_pca(self):
        """PCA embedding of ``expression_array``; computed on first access if needed."""
        if self._latent_pca is None:
            self.compute_pca()
        return self._latent_pca

    @latent_pca.setter
    def latent_pca(self, value):
        self._latent_pca = value

    def compute_pca(self):
        """
        Project the scaled expression matrix onto ``params.n_comps`` principal components.

        Always refits. The result is cached on ``latent_pca`` and returned.
        """
        embedding = PCA(n_components=self.params.n_comps).fit_transform(self.expression_array)
        self.latent_pca = embedding
        return embedding

    def run_gnn_vae(self, label, verbose='whole ST data'):
        """
        Train the graph VAE and return its latent embedding.

        Node features are the PCA projection when ``params.input_pca`` is set,
        otherwise the scaled expression matrix. ``params.n_nodes`` and
        ``params.feat_cell`` are overwritten with that matrix's shape before
        training. The neighbour graph is deleted afterwards, so a second call
        on the same instance fails.

        Args:
            label: Per-cell labels forwarded to ``ModelTrainer`` (or None).
            verbose: Description inserted into the progress log message.

        Returns:
            Whatever ``ModelTrainer.get_latent()`` returns.
        """
        # Use PCA if specified
        if self.params.input_pca:
            node_X = self.compute_pca()
        else:
            node_X = self.expression_array

        # Update the input shape (presumably read by ModelTrainer — confirm)
        self.params.n_nodes = node_X.shape[0]
        self.params.feat_cell = node_X.shape[1]

        # Run GNN
        logger.info(f'Finding latent representations for {verbose}...')
        gvae = ModelTrainer(node_X, self.graph_dict, self.params, label)
        gvae.run_train()

        # Drop the graph once training is done (likely to release memory).
        del self.graph_dict

        return gvae.get_latent()
90
+
91
+
92
def run_find_latent_representation(args: FindLatentRepresentationsConfig):
    """
    Learn latent representations for an ST dataset and save them to disk.

    Pipeline:
      1. Seed all RNGs with 2024.
      2. Read ``args.input_hdf5_path``.
      3. If ``args.annotation`` is set, drop cells whose annotation is missing
         or whose category has fewer than 30 cells, and label-encode the
         remaining annotations for the GNN.
      4. Preprocess, train the graph VAE and obtain the PCA embedding.
      5. Store them in ``obsm['latent_GVAE']`` / ``obsm['latent_PCA']`` and
         write the AnnData to ``args.hdf5_with_latent_path``.

    Args:
        args: Configuration for the find-latent-representations step.
    """
    set_seed(2024)

    # Load the ST data
    logger.info(f'Loading ST data of {args.sample_name}...')
    adata = sc.read_h5ad(args.input_hdf5_path)
    logger.info(f'The ST data contains {adata.shape[0]} cells, {adata.shape[1]} genes.')

    # Load the cell type annotation
    if args.annotation is not None:
        # Remove cells without enough annotations
        adata = adata[~adata.obs[args.annotation].isnull()]
        num = adata.obs[args.annotation].value_counts()
        valid_annotations = num[num >= 30].index.to_list()
        # Materialize the view: preprocessing modifies the object in place.
        adata = adata[adata.obs[args.annotation].isin(valid_annotations)].copy()

        le = LabelEncoder()
        label = le.fit_transform(adata.obs[args.annotation])
    else:
        label = None

    # Preprocess data
    adata = preprocess_data(adata, args)

    latent_rep = LatentRepresentationFinder(adata, args)
    latent_gvae = latent_rep.run_gnn_vae(label)
    # The PCA only exists already when it was used as GNN input
    # (params.input_pca); otherwise compute it now instead of crashing
    # with AttributeError.
    latent_pca = getattr(latent_rep, 'latent_pca', None)
    if latent_pca is None:
        latent_pca = latent_rep.compute_pca()

    # Add latent representations to the AnnData object
    logger.info('Adding latent representations...')
    adata.obsm["latent_GVAE"] = latent_gvae
    adata.obsm["latent_PCA"] = latent_pca

    # UMAP embeddings of the latent representations are not computed in this step.

    # Save the AnnData object
    logger.info('Saving ST data...')
    adata.write(args.hdf5_with_latent_path)