gsMap 1.71.1__py3-none-any.whl → 1.72.3__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
@@ -1,133 +1,141 @@
- import logging
- import random
- import numpy as np
- import scanpy as sc
- import torch
- from sklearn.decomposition import PCA
- from sklearn.preprocessing import LabelEncoder
- from gsMap.GNN.adjacency_matrix import construct_adjacency_matrix
- from gsMap.GNN.train import ModelTrainer
- from gsMap.config import FindLatentRepresentationsConfig
-
- logger = logging.getLogger(__name__)
-
-
- def set_seed(seed_value):
-     """
-     Set seed for reproducibility in PyTorch and other libraries.
-     """
-     torch.manual_seed(seed_value)
-     np.random.seed(seed_value)
-     random.seed(seed_value)
-     if torch.cuda.is_available():
-         logger.info('Using GPU for computations.')
-         torch.cuda.manual_seed(seed_value)
-         torch.cuda.manual_seed_all(seed_value)
-     else:
-         logger.info('Using CPU for computations.')
-
- def preprocess_data(adata, params):
-     """
-     Preprocess the AnnData
-     """
-     logger.info('Preprocessing data...')
-     adata.var_names_make_unique()
-
-     sc.pp.filter_genes(adata, min_cells=30)
-     if params.data_layer in adata.layers.keys():
-         logger.info(f'Using data layer: {params.data_layer}...')
-         adata.X = adata.layers[params.data_layer]
-     else:
-         raise ValueError(f'Invalid data layer: {params.data_layer}, please check the input data.')
-
-     if params.data_layer in ['count', 'counts']:
-         # HVGs based on count
-         sc.pp.highly_variable_genes(adata,flavor="seurat_v3",n_top_genes=params.feat_cell)
-         # Normalize the data
-         sc.pp.normalize_total(adata, target_sum=1e4)
-         sc.pp.log1p(adata)
-
-     elif params.data_layer in adata.layers.keys():
-         sc.pp.highly_variable_genes(adata,flavor="seurat",n_top_genes=params.feat_cell)
-
-     return adata
-
-
- class LatentRepresentationFinder:
-     def __init__(self, adata, args: FindLatentRepresentationsConfig):
-         self.params = args
-
-         self.expression_array = adata[:, adata.var.highly_variable].X.copy()
-         self.expression_array = sc.pp.scale(self.expression_array, max_value=10)
-
-         # Construct the neighboring graph
-         self.graph_dict = construct_adjacency_matrix(adata, self.params)
-
-     def compute_pca(self):
-         self.latent_pca = PCA(n_components=self.params.n_comps).fit_transform(self.expression_array)
-         return self.latent_pca
-
-     def run_gnn_vae(self, label, verbose='whole ST data'):
-
-         # Use PCA if specified
-         if self.params.input_pca:
-             node_X = self.compute_pca()
-         else:
-             node_X = self.expression_array
-
-         # Update the input shape
-         self.params.n_nodes = node_X.shape[0]
-         self.params.feat_cell = node_X.shape[1]
-
-         # Run GNN
-         logger.info(f'Finding latent representations for {verbose}...')
-         gvae = ModelTrainer(node_X, self.graph_dict, self.params, label)
-         gvae.run_train()
-
-         del self.graph_dict
-
-         return gvae.get_latent()
-
-
- def run_find_latent_representation(args: FindLatentRepresentationsConfig):
-     set_seed(2024)
-
-     # Load the ST data
-     logger.info(f'Loading ST data of {args.sample_name}...')
-     adata = sc.read_h5ad(args.input_hdf5_path)
-     logger.info(f'The ST data contains {adata.shape[0]} cells, {adata.shape[1]} genes.')
-
-     # Load the cell type annotation
-     if args.annotation is not None:
-         # Remove cells without enough annotations
-         adata = adata[~adata.obs[args.annotation].isnull()]
-         num = adata.obs[args.annotation].value_counts()
-         valid_annotations = num[num >= 30].index.to_list()
-         adata = adata[adata.obs[args.annotation].isin(valid_annotations)]
-
-         le = LabelEncoder()
-         label = le.fit_transform(adata.obs[args.annotation])
-     else:
-         label = None
-
-     # Preprocess data
-     adata = preprocess_data(adata, args)
-
-     latent_rep = LatentRepresentationFinder(adata, args)
-     latent_gvae = latent_rep.run_gnn_vae(label)
-     latent_pca = latent_rep.latent_pca
-
-     # Add latent representations to the AnnData object
-     logger.info('Adding latent representations...')
-     adata.obsm["latent_GVAE"] = latent_gvae
-     adata.obsm["latent_PCA"] = latent_pca
-
-     # Run UMAP based on latent representations
-     #for name in ['latent_GVAE', 'latent_PCA']:
-     #    sc.pp.neighbors(adata, n_neighbors=10, use_rep=name)
-     #    sc.tl.umap(adata)
-     #    adata.obsm['X_umap_' + name] = adata.obsm['X_umap']
-
-     # Save the AnnData object
-     logger.info('Saving ST data...')
-     adata.write(args.hdf5_with_latent_path)
+ import logging
+ import random
+
+ import numpy as np
+ import scanpy as sc
+ import torch
+ from sklearn.decomposition import PCA
+ from sklearn.preprocessing import LabelEncoder
+
+ from gsMap.config import FindLatentRepresentationsConfig
+ from gsMap.GNN.adjacency_matrix import construct_adjacency_matrix
+ from gsMap.GNN.train import ModelTrainer
+
+ logger = logging.getLogger(__name__)
+
+
+ def set_seed(seed_value):
+     """
+     Set seed for reproducibility in PyTorch and other libraries.
+     """
+     torch.manual_seed(seed_value)
+     np.random.seed(seed_value)
+     random.seed(seed_value)
+     if torch.cuda.is_available():
+         logger.info("Using GPU for computations.")
+         torch.cuda.manual_seed(seed_value)
+         torch.cuda.manual_seed_all(seed_value)
+     else:
+         logger.info("Using CPU for computations.")
+
+
+ def preprocess_data(adata, params):
+     """
+     Preprocess the AnnData
+     """
+     logger.info("Preprocessing data...")
+     adata.var_names_make_unique()
+
+     if params.data_layer in adata.layers.keys():
+         logger.info(f"Using data layer: {params.data_layer}...")
+         adata.X = adata.layers[params.data_layer]
+     elif params.data_layer == "X":
+         logger.info(f"Using data layer: {params.data_layer}...")
+         if adata.X.dtype == "float32" or adata.X.dtype == "float64":
+             logger.warning("The data layer should be raw count data")
+     else:
+         raise ValueError(f"Invalid data layer: {params.data_layer}, please check the input data.")
+
+     if params.data_layer in ["count", "counts", "X"]:
+         # HVGs based on count
+         logger.info("Dealing with count data...")
+         sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=params.feat_cell)
+         # Normalize the data
+         sc.pp.normalize_total(adata, target_sum=1e4)
+         sc.pp.log1p(adata)
+
+     elif params.data_layer in adata.layers.keys():
+         sc.pp.highly_variable_genes(adata, flavor="seurat", n_top_genes=params.feat_cell)
+
+     return adata
+
+
+ class LatentRepresentationFinder:
+     def __init__(self, adata, args: FindLatentRepresentationsConfig):
+         self.params = args
+
+         self.expression_array = adata[:, adata.var.highly_variable].X.copy()
+         self.expression_array = sc.pp.scale(self.expression_array, max_value=10)
+
+         # Construct the neighboring graph
+         self.graph_dict = construct_adjacency_matrix(adata, self.params)
+
+     def compute_pca(self):
+         self.latent_pca = PCA(n_components=self.params.n_comps).fit_transform(
+             self.expression_array
+         )
+         return self.latent_pca
+
+     def run_gnn_vae(self, label, verbose="whole ST data"):
+         # Use PCA if specified
+         if self.params.input_pca:
+             node_X = self.compute_pca()
+         else:
+             node_X = self.expression_array
+
+         # Update the input shape
+         self.params.n_nodes = node_X.shape[0]
+         self.params.feat_cell = node_X.shape[1]
+
+         # Run GNN
+         logger.info(f"Finding latent representations for {verbose}...")
+         gvae = ModelTrainer(node_X, self.graph_dict, self.params, label)
+         gvae.run_train()
+
+         del self.graph_dict
+
+         return gvae.get_latent()
+
+
+ def run_find_latent_representation(args: FindLatentRepresentationsConfig):
+     set_seed(2024)
+
+     # Load the ST data
+     logger.info(f"Loading ST data of {args.sample_name}...")
+     adata = sc.read_h5ad(args.input_hdf5_path)
+     logger.info(f"The ST data contains {adata.shape[0]} cells, {adata.shape[1]} genes.")
+
+     # Load the cell type annotation
+     if args.annotation is not None:
+         # Remove cells without enough annotations
+         adata = adata[~adata.obs[args.annotation].isnull()]
+         num = adata.obs[args.annotation].value_counts()
+         valid_annotations = num[num >= 30].index.to_list()
+         adata = adata[adata.obs[args.annotation].isin(valid_annotations)]
+
+         le = LabelEncoder()
+         label = le.fit_transform(adata.obs[args.annotation])
+     else:
+         label = None
+
+     # Preprocess data
+     adata = preprocess_data(adata, args)
+
+     latent_rep = LatentRepresentationFinder(adata, args)
+     latent_gvae = latent_rep.run_gnn_vae(label)
+     latent_pca = latent_rep.latent_pca
+
+     # Add latent representations to the AnnData object
+     logger.info("Adding latent representations...")
+     adata.obsm["latent_GVAE"] = latent_gvae
+     adata.obsm["latent_PCA"] = latent_pca
+
+     # Run UMAP based on latent representations
+     # for name in ['latent_GVAE', 'latent_PCA']:
+     #     sc.pp.neighbors(adata, n_neighbors=10, use_rep=name)
+     #     sc.tl.umap(adata)
+     #     adata.obsm['X_umap_' + name] = adata.obsm['X_umap']
+
+     # Save the AnnData object
+     logger.info("Saving ST data...")
+     adata.write(args.hdf5_with_latent_path)
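
The functional change in this file is in preprocess_data: 1.72.x accepts data_layer="X" (reading counts directly from adata.X and warning if the matrix is stored as float rather than raw counts) and drops the sc.pp.filter_genes(adata, min_cells=30) step; the remaining edits are import reordering and quote/formatting normalization. A minimal usage sketch of the updated entry point follows. The module path and the exact FindLatentRepresentationsConfig fields are assumptions inferred from the attributes referenced in the diff (input_hdf5_path, sample_name, data_layer, annotation, hdf5_with_latent_path); check gsMap.config for the real signature before use.

# Minimal usage sketch (not part of the diff); field names and module path are assumptions.
from gsMap.config import FindLatentRepresentationsConfig
from gsMap.find_latent_representation import run_find_latent_representation  # assumed module path

config = FindLatentRepresentationsConfig(
    input_hdf5_path="sample.h5ad",                    # hypothetical input file
    sample_name="sample",                             # hypothetical sample name
    data_layer="X",                                   # new in 1.72.x: read counts from adata.X
    annotation=None,                                  # or an obs column holding cell-type labels
    hdf5_with_latent_path="sample_with_latent.h5ad",  # hypothetical output file
)

run_find_latent_representation(config)

When data_layer is "X", preprocess_data only warns if adata.X is stored as float32/float64, so raw integer counts in adata.X remain the expected input.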