gsMap 1.62-py3-none-any.whl → 1.63-py3-none-any.whl

gsMap/diagnosis.py ADDED
@@ -0,0 +1,273 @@
+import logging
+import warnings
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import scanpy as sc
+from scipy.stats import norm
+
+from gsMap.config import DiagnosisConfig
+from gsMap.utils.manhattan_plot import ManhattanPlot
+from gsMap.visualize import draw_scatter, load_st_coord, estimate_point_size_for_plot
+
+
+warnings.filterwarnings("ignore", category=FutureWarning)
+logger = logging.getLogger(__name__)
+
+
+def convert_z_to_p(gwas_data):
+    """Convert Z-scores to P-values."""
+    gwas_data['P'] = norm.sf(abs(gwas_data['Z'])) * 2
+    min_p_value = 1e-300
+    gwas_data['P'] = gwas_data['P'].clip(lower=min_p_value)
+    return gwas_data
+
+
+def load_ldsc(ldsc_input_file):
+    """Load LDSC data and calculate logp."""
+    ldsc = pd.read_csv(ldsc_input_file, compression='gzip')
+    ldsc['spot'] = ldsc['spot'].astype(str).replace('\.0', '', regex=True)
+    ldsc.set_index('spot', inplace=True)
+    ldsc['logp'] = -np.log10(ldsc['p'])
+    return ldsc
+
+
+def load_gene_diagnostic_info(config:DiagnosisConfig):
+    """Load or compute gene diagnostic info."""
+    gene_diagnostic_info_save_path = config.get_gene_diagnostic_info_save_path(config.trait_name)
+    if gene_diagnostic_info_save_path.exists():
+        logger.info(f'Loading gene diagnostic information from {gene_diagnostic_info_save_path}...')
+        return pd.read_csv(gene_diagnostic_info_save_path)
+    else:
+        logger.info('Gene diagnostic information not found. Calculating gene diagnostic information...')
+        return compute_gene_diagnostic_info(config)
+
+
+def compute_gene_diagnostic_info(config: DiagnosisConfig):
+    """Calculate gene diagnostic info and save it to adata."""
+    logger.info('Loading ST data and LDSC results...')
+    # adata = sc.read_h5ad(config.hdf5_with_latent_path, backed='r')
+    mk_score = pd.read_feather(config.mkscore_feather_path)
+    mk_score.set_index('HUMAN_GENE_SYM', inplace=True)
+    mk_score = mk_score.T
+    trait_ldsc_result = load_ldsc(config.get_ldsc_result_file(config.trait_name))
+
+    # Align marker scores with trait LDSC results
+    mk_score = mk_score.loc[trait_ldsc_result.index]
+    mk_score = mk_score.loc[:, mk_score.sum(axis=0) != 0]
+
+    logger.info('Calculating correlation between gene marker scores and trait logp-values...')
+    corr = mk_score.corrwith(trait_ldsc_result['logp'])
+    corr.name = 'PCC'
+
+    grouped_mk_score = mk_score.groupby(adata.obs[config.annotation]).median()
+    max_annotations = grouped_mk_score.idxmax()
+
+    high_GSS_Gene_annotation_pair = pd.DataFrame({
+        'Gene': max_annotations.index,
+        'Annotation': max_annotations.values,
+        'Median_GSS': grouped_mk_score.max().values
+    })
+
+    # Filter based on median GSS score
+    high_GSS_Gene_annotation_pair = high_GSS_Gene_annotation_pair[high_GSS_Gene_annotation_pair['Median_GSS'] >= 1.0]
+    high_GSS_Gene_annotation_pair = high_GSS_Gene_annotation_pair.merge(corr, left_on='Gene', right_index=True)
+
+    # Prepare the final gene diagnostic info dataframe
+    gene_diagnostic_info_cols = ['Gene', 'Annotation', 'Median_GSS', 'PCC']
+    gene_diagnostic_info = high_GSS_Gene_annotation_pair[gene_diagnostic_info_cols].drop_duplicates().dropna(
+        subset=['Gene'])
+    gene_diagnostic_info.sort_values('PCC', ascending=False, inplace=True)
+
+    # Save gene diagnostic info to a file
+    gene_diagnostic_info_save_path = config.get_gene_diagnostic_info_save_path(config.trait_name)
+    gene_diagnostic_info.to_csv(gene_diagnostic_info_save_path, index=False)
+    logger.info(f'Gene diagnostic information saved to {gene_diagnostic_info_save_path}.')
+
+    # Save to adata.var with the trait_name prefix
+    logger.info('Saving gene diagnostic info to adata.var...')
+    gene_diagnostic_info.set_index('Gene', inplace=True)  # Use 'Gene' as the index to align with adata.var
+    adata.var[f'{config.trait_name}_Annotation'] = gene_diagnostic_info['Annotation']
+    adata.var[f'{config.trait_name}_Median_GSS'] = gene_diagnostic_info['Median_GSS']
+    adata.var[f'{config.trait_name}_PCC'] = gene_diagnostic_info['PCC']
+
+    # Save trait_ldsc_result to adata.obs
+    logger.info(f'Saving trait LDSC results to adata.obs as gsMap_{config.trait_name}_p_value...')
+    adata.obs[f'gsMap_{config.trait_name}_p_value'] = trait_ldsc_result['p']
+    adata.write(config.hdf5_with_latent_path, )
+
+    return gene_diagnostic_info.reset_index()
+
+
+def load_gwas_data(config:DiagnosisConfig):
+    """Load and process GWAS data."""
+    logger.info('Loading and processing GWAS data...')
+    gwas_data = pd.read_csv(config.sumstats_file, compression='gzip', sep='\t')
+    return convert_z_to_p(gwas_data)
+
+
+def load_snp_gene_pairs(config:DiagnosisConfig):
+    """Load SNP-gene pairs from multiple chromosomes."""
+    ldscore_save_dir = Path(config.ldscore_save_dir)
+    return pd.concat([
+        pd.read_feather(ldscore_save_dir / f'SNP_gene_pair/SNP_gene_pair_chr{chrom}.feather')
+        for chrom in range(1, 23)
+    ])
+
+
+def filter_snps(gwas_data_with_gene_annotation_sort, SUBSAMPLE_SNP_NUMBER):
+    """Filter the SNPs based on significance levels."""
+    pass_suggestive_line_mask = gwas_data_with_gene_annotation_sort['P'] < 1e-5
+    pass_suggestive_line_number = pass_suggestive_line_mask.sum()
+
+    if pass_suggestive_line_number > SUBSAMPLE_SNP_NUMBER:
+        snps2plot = gwas_data_with_gene_annotation_sort[pass_suggestive_line_mask].SNP
+        logger.info(f'To reduce the number of SNPs to plot, only {snps2plot.shape[0]} SNPs with P < 1e-5 are plotted.')
+    else:
+        snps2plot = gwas_data_with_gene_annotation_sort.head(SUBSAMPLE_SNP_NUMBER).SNP
+        logger.info(
+            f'To reduce the number of SNPs to plot, only {SUBSAMPLE_SNP_NUMBER} SNPs with the smallest P-values are plotted.')
+
+    return snps2plot
+
+
+def generate_manhattan_plot(config: DiagnosisConfig):
+    """Generate Manhattan plot."""
+    report_save_dir = config.get_report_dir(config.trait_name)
+    gwas_data = load_gwas_data(config)
+    snp_gene_pair = load_snp_gene_pairs(config)
+    gwas_data_with_gene = snp_gene_pair.merge(gwas_data, on='SNP', how='inner').rename(columns={'gene_name': 'GENE'})
+    gene_diagnostic_info = load_gene_diagnostic_info(config)
+    gwas_data_with_gene_annotation = gwas_data_with_gene.merge(gene_diagnostic_info, left_on='GENE', right_on='Gene',
+                                                               how='left')
+
+    gwas_data_with_gene_annotation = gwas_data_with_gene_annotation[
+        ~gwas_data_with_gene_annotation['Annotation'].isna()]
+    gwas_data_with_gene_annotation_sort = gwas_data_with_gene_annotation.sort_values('P')
+
+    snps2plot = filter_snps(gwas_data_with_gene_annotation_sort, SUBSAMPLE_SNP_NUMBER=100_000)
+    gwas_data_to_plot = gwas_data_with_gene_annotation[
+        gwas_data_with_gene_annotation['SNP'].isin(snps2plot)].reset_index(drop=True)
+    gwas_data_to_plot['Annotation_text'] = 'PCC: ' + gwas_data_to_plot['PCC'].round(2).astype(
+        str) + '<br>' + 'Annotation: ' + gwas_data_to_plot['Annotation'].astype(str)
+
+    fig = ManhattanPlot(
+        dataframe=gwas_data_to_plot,
+        title='gsMap Diagnosis Manhattan Plot',
+        point_size=3,
+        highlight_gene_list=config.selected_genes or gene_diagnostic_info.Gene.iloc[:config.top_corr_genes].tolist(),
+        suggestiveline_value=-np.log10(1e-5),
+        annotation='Annotation_text',
+    )
+
+    save_manhattan_plot_path = config.get_manhattan_html_plot_path(config.trait_name)
+    fig.write_html(save_manhattan_plot_path)
+    logger.info(f'Diagnostic Manhattan Plot saved to {save_manhattan_plot_path}.')
+
+
+def generate_GSS_distribution(config: DiagnosisConfig):
+    """Generate GSS distribution plots."""
+    # logger.info('Loading ST data...')
+    # adata = sc.read_h5ad(config.hdf5_with_latent_path)
+    mk_score = pd.read_feather(config.mkscore_feather_path).set_index('HUMAN_GENE_SYM').T
+
+    plot_genes = config.selected_genes or load_gene_diagnostic_info(config).Gene.iloc[:config.top_corr_genes].tolist()
+    if config.selected_genes is not None:
+        logger.info(f'Generating GSS & Expression distribution plot for selected genes: {plot_genes}...')
+    else:
+        logger.info(f'Generating GSS & Expression distribution plot for top {config.top_corr_genes} correlated genes...')
+
+    if config.customize_fig:
+        pixel_width, pixel_height, point_size = config.fig_width, config.fig_height, config.point_size
+    else:
+        (pixel_width, pixel_height), point_size = estimate_point_size_for_plot(adata.obsm['spatial'])
+    sub_fig_save_dir = config.get_GSS_plot_dir(config.trait_name)
+
+    # save plot gene list
+    config.get_GSS_plot_select_gene_file(config.trait_name).write_text('\n'.join(plot_genes))
+
+    for selected_gene in plot_genes:
+        expression_series = pd.Series(adata[:, selected_gene].X.toarray().flatten(), index=adata.obs.index, name='Expression')
+        threshold = np.quantile(expression_series, 0.9999)
+        expression_series[expression_series > threshold] = threshold
+        generate_and_save_plots(adata, mk_score, expression_series, selected_gene, point_size, pixel_width,
+                                pixel_height, sub_fig_save_dir, config.sample_name, config.annotation)
+
+
+def generate_and_save_plots(adata, mk_score, expression_series, selected_gene, point_size, pixel_width, pixel_height,
+                            sub_fig_save_dir, sample_name, annotation):
+    """Generate and save the plots."""
+    select_gene_expression_with_space_coord = load_st_coord(adata, expression_series, annotation)
+    sub_fig_1 = draw_scatter(select_gene_expression_with_space_coord, title=f'{selected_gene} (Expression)',
+                             annotation='annotation', color_by='Expression', point_size=point_size, width=pixel_width,
+                             height=pixel_height)
+    save_plot(sub_fig_1, sub_fig_save_dir, sample_name, selected_gene, 'Expression')
+
+    select_gene_GSS_with_space_coord = load_st_coord(adata, mk_score[selected_gene].rename('GSS'), annotation)
+    sub_fig_2 = draw_scatter(select_gene_GSS_with_space_coord, title=f'{selected_gene} (GSS)', annotation='annotation',
+                             color_by='GSS', point_size=point_size, width=pixel_width, height=pixel_height)
+    save_plot(sub_fig_2, sub_fig_save_dir, sample_name, selected_gene, 'GSS')
+
+    # combined_fig = make_subplots(rows=1, cols=2,
+    #                              subplot_titles=(f'{selected_gene} (Expression)', f'{selected_gene} (GSS)'))
+    # for trace in sub_fig_1.data:
+    #     combined_fig.add_trace(trace, row=1, col=1)
+    # for trace in sub_fig_2.data:
+    #     combined_fig.add_trace(trace, row=1, col=2)
+    #
+
+def save_plot(sub_fig, sub_fig_save_dir, sample_name, selected_gene, plot_type):
+    """Save the plot to HTML and PNG."""
+    save_sub_fig_path = sub_fig_save_dir / f'{sample_name}_{selected_gene}_{plot_type}_Distribution.html'
+    # sub_fig.write_html(str(save_sub_fig_path))
+    sub_fig.update_layout(showlegend=False)
+    sub_fig.write_image(str(save_sub_fig_path).replace('.html', '.png'))
+
+
+def generate_gsMap_plot(config: DiagnosisConfig):
+    """Generate gsMap plot."""
+    logger.info('Creating gsMap plot...')
+
+    trait_ldsc_result = load_ldsc(config.get_ldsc_result_file(config.trait_name))
+    space_coord_concat = load_st_coord(adata, trait_ldsc_result, annotation=config.annotation)
+
+    if config.customize_fig:
+        pixel_width, pixel_height, point_size = config.fig_width, config.fig_height, config.point_size
+    else:
+        (pixel_width, pixel_height), point_size = estimate_point_size_for_plot(adata.obsm['spatial'])
+    fig = draw_scatter(space_coord_concat,
+                       title=f'{config.trait_name} (gsMap)',
+                       point_size=point_size,
+                       width=pixel_width,
+                       height=pixel_height,
+                       annotation=config.annotation
+                       )
+
+    output_dir = config.get_gsMap_plot_save_dir(config.trait_name)
+    output_file_html = config.get_gsMap_html_plot_save_path(config.trait_name)
+    output_file_png = output_file_html.with_suffix('.png')
+    output_file_csv = output_file_html.with_suffix('.csv')
+
+    fig.write_html(output_file_html)
+    fig.write_image(output_file_png)
+    space_coord_concat.to_csv(output_file_csv)
+
+    logger.info(f'gsMap plot created and saved in {output_dir}.')
+
+
+def run_Diagnosis(config: DiagnosisConfig):
+    """Main function to run the diagnostic plot generation."""
+    global adata
+    adata = sc.read_h5ad(config.hdf5_with_latent_path)
+    if 'log1p' not in adata.uns.keys() and adata.X.max() > 14:
+        sc.pp.normalize_total(adata, target_sum=1e4)
+        sc.pp.log1p(adata)
+
+    if config.plot_type in ['manhattan', 'all']:
+        generate_manhattan_plot(config)
+    if config.plot_type in ['GSS', 'all']:
+        generate_GSS_distribution(config)
+    if config.plot_type in ['gsMap', 'all']:
+        generate_gsMap_plot(config)
+
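The core statistical step of the new diagnosis module is small enough to sketch standalone: GWAS Z-scores are converted to two-sided P-values (as in convert_z_to_p above), and each gene's specificity score (GSS) is correlated with the per-spot -log10 LDSC p-values to produce the 'PCC' column of the gene diagnostic table (as in compute_gene_diagnostic_info). The sketch below uses synthetic toy data and illustrative column names; it is not the gsMap API itself.

```python
import numpy as np
import pandas as pd
from scipy.stats import norm

# Toy GWAS table with Z-scores (synthetic data, for illustration only).
gwas = pd.DataFrame({'SNP': ['rs1', 'rs2', 'rs3'], 'Z': [1.2, -4.8, 9.7]})

# Two-sided P-value from |Z|, clipped to avoid underflow,
# mirroring convert_z_to_p() in the added diagnosis.py.
gwas['P'] = norm.sf(gwas['Z'].abs()) * 2
gwas['P'] = gwas['P'].clip(lower=1e-300)

# Toy marker-score matrix (spots x genes) and per-spot -log10(p) values,
# standing in for the marker-score feather file and the LDSC results.
rng = np.random.default_rng(0)
mk_score = pd.DataFrame(rng.random((50, 3)), columns=['GeneA', 'GeneB', 'GeneC'])
logp = pd.Series(rng.random(50), name='logp')

# Correlate each gene's specificity score with the spot-level logp,
# the same quantity reported as 'PCC' in the gene diagnostic table.
pcc = mk_score.corrwith(logp).rename('PCC')
print(gwas)
print(pcc.sort_values(ascending=False))
```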
@@ -1,9 +1,5 @@
-import argparse
 import logging
-import pprint
 import random
-import time
-from pathlib import Path
 
 import numpy as np
 import pandas as pd
@@ -13,17 +9,9 @@ from sklearn import preprocessing
 
 from gsMap.GNN_VAE.adjacency_matrix import Construct_Adjacency_Matrix
 from gsMap.GNN_VAE.train import Model_Train
-from gsMap.config import add_find_latent_representations_args, FindLatentRepresentationsConfig
-
-# seed all
+from gsMap.config import FindLatentRepresentationsConfig
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
-handler = logging.StreamHandler()
-handler.setFormatter(logging.Formatter(
-    '[{asctime}] {levelname:8s} {filename} {message}', style='{'))
-logger.addHandler(handler)
-
 
 def set_seed(seed_value):
     """
@@ -33,30 +21,30 @@ def set_seed(seed_value):
     np.random.seed(seed_value) # Set the seed for NumPy
     random.seed(seed_value) # Set the seed for Python random module
     if torch.cuda.is_available():
-        print('Running use GPU')
+        logger.info('Running use GPU')
         torch.cuda.manual_seed(seed_value) # Set seed for all CUDA devices
        torch.cuda.manual_seed_all(seed_value) # Set seed for all CUDA devices
     else:
-        print('Running use CPU')
+        logger.info('Running use CPU')
 
-set_seed(2024)
 
 # The class for finding latent representations
 class Latent_Representation_Finder:
 
-    def __init__(self, adata, Params):
+    def __init__(self, adata, args:FindLatentRepresentationsConfig):
         self.adata = adata.copy()
-        self.Params = Params
+        self.Params = args
 
         # Standard process
-        if self.Params.type == 'count' or self.Params.type == 'counts':
-            self.adata.X = self.adata.layers[self.Params.type]
+        if self.Params.data_layer == 'count' or self.Params.data_layer == 'counts':
+            self.adata.X = self.adata.layers[self.Params.data_layer]
             sc.pp.highly_variable_genes(self.adata, flavor="seurat_v3", n_top_genes=self.Params.feat_cell)
             sc.pp.normalize_total(self.adata, target_sum=1e4)
             sc.pp.log1p(self.adata)
             sc.pp.scale(self.adata)
         else:
-            self.adata.X = self.adata.layers[self.Params.type]
+            if self.Params.data_layer != 'X':
+                self.adata.X = self.adata.layers[self.Params.data_layer]
             sc.pp.highly_variable_genes(self.adata, n_top_genes=self.Params.feat_cell)
 
     def Run_GNN_VAE(self, label, verbose='whole ST data'):
@@ -66,7 +54,7 @@ class Latent_Representation_Finder:
 
         # Process the feature matrix
         node_X = self.adata[:, self.adata.var.highly_variable].X
-        print(f'The shape of feature matrix is {node_X.shape}.')
+        logger.info(f'The shape of feature matrix is {node_X.shape}.')
         if self.Params.input_pca:
             node_X = sc.pp.pca(node_X, n_comps=self.Params.n_comps)
 
@@ -75,7 +63,7 @@
         self.Params.feat_cell = node_X.shape[1]
 
         # Run GNN-VAE
-        print(f'------Finding latent representations for {verbose}...')
+        logger.info(f'------Finding latent representations for {verbose}...')
         gvae = Model_Train(node_X, graph_dict, self.Params, label)
         gvae.run_train()
 
@@ -87,20 +75,21 @@
 
 
 def run_find_latent_representation(args:FindLatentRepresentationsConfig):
+    set_seed(2024)
     num_features = args.feat_cell
-    args.output_dir = Path(args.output_hdf5_path).parent
-    args.output_dir.mkdir(parents=True, exist_ok=True,mode=0o755)
+    args.hdf5_with_latent_path.parent.mkdir(parents=True, exist_ok=True,mode=0o755)
     # Load the ST data
-    print(f'------Loading ST data of {args.sample_name}...')
+    logger.info(f'------Loading ST data of {args.sample_name}...')
     adata = sc.read_h5ad(f'{args.input_hdf5_path}')
     adata.var_names_make_unique()
-    print('The ST data contains %d cells, %d genes.' % (adata.shape[0], adata.shape[1]))
+    adata.X = adata.layers[args.data_layer] if args.data_layer in adata.layers.keys() else adata.X
+    logger.info('The ST data contains %d cells, %d genes.' % (adata.shape[0], adata.shape[1]))
     # Load the cell type annotation
     if not args.annotation is None:
         # remove cells without enough annotations
         adata = adata[~pd.isnull(adata.obs[args.annotation]), :]
         num = adata.obs[args.annotation].value_counts()
-        adata = adata[adata.obs[args.annotation].isin(num[num >= 30].index.to_list()),]
+        adata = adata[adata.obs[args.annotation].isin(num[num >= 30].index.to_list())]
 
         le = preprocessing.LabelEncoder()
         le.fit(adata.obs[args.annotation])
@@ -113,7 +102,7 @@ def run_find_latent_representation(args:FindLatentRepresentationsConfig):
     latent_GVAE = latent_rep.Run_GNN_VAE(label)
     latent_PCA = latent_rep.Run_PCA()
     # Add latent representations to the spe data
-    print(f'------Adding latent representations...')
+    logger.info(f'------Adding latent representations...')
     adata.obsm["latent_GVAE"] = latent_GVAE
     adata.obsm["latent_PCA"] = latent_PCA
     # Run umap based on latent representations
@@ -124,13 +113,13 @@ def run_find_latent_representation(args:FindLatentRepresentationsConfig):
 
     # Find the latent representations hierarchically (optionally)
     if not args.annotation is None and args.hierarchically:
-        print(f'------Finding latent representations hierarchically...')
+        logger.info(f'------Finding latent representations hierarchically...')
         PCA_all = pd.DataFrame()
         GVAE_all = pd.DataFrame()
 
         for ct in adata.obs[args.annotation].unique():
             adata_part = adata[adata.obs[args.annotation] == ct, :]
-            print(adata_part.shape)
+            logger.info(adata_part.shape)
 
             # Find latent representations for the selected ct
             latent_rep = Latent_Representation_Finder(adata_part, args)
@@ -151,59 +140,6 @@ def run_find_latent_representation(args:FindLatentRepresentationsConfig):
 
         adata.obsm["latent_GVAE_hierarchy"] = np.array(GVAE_all.loc[adata.obs_names,])
         adata.obsm["latent_PCA_hierarchy"] = np.array(PCA_all.loc[adata.obs_names,])
-    print(f'------Saving ST data...')
-    adata.write(args.output_hdf5_path)
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description="This script is designed to find latent representations in spatial transcriptomics data using a Graph Neural Network Variational Autoencoder (GNN-VAE). It processes input data, constructs a neighboring graph, and runs GNN-VAE to output latent representations.")
-    add_find_latent_representations_args(parser)
-    TEST=True
-    if TEST:
-        test_dir = '/storage/yangjianLab/chenwenhao/projects/202312_gsMap/data/gsMap_test/Nature_Neuroscience_2021'
-        name = 'Cortex_151507'
+    logger.info(f'------Saving ST data...')
+    adata.write(args.hdf5_with_latent_path)
 
-
-        args = parser.parse_args(
-            [
-                '--input_hdf5_path','/storage/yangjianLab/songliyang/SpatialData/Data/Brain/Human/Nature_Neuroscience_2021/processed/h5ad/Cortex_151507.h5ad',
-                '--output_hdf5_path',f'{test_dir}/{name}/hdf5/{name}_add_latent.h5ad',
-                '--sample_name', name,
-                '--annotation','layer_guess',
-                '--type','count',
-            ]
-
-        )
-
-    else:
-        args = parser.parse_args()
-    config=FindLatentRepresentationsConfig(**{'annotation': 'SubClass',
-        'convergence_threshold': 0.0001,
-        'epochs': 300,
-        'feat_cell': 3000,
-        'feat_hidden1': 256,
-        'feat_hidden2': 128,
-        'gcn_decay': 0.01,
-        'gcn_hidden1': 64,
-        'gcn_hidden2': 30,
-        'gcn_lr': 0.001,
-        'hierarchically': False,
-        'input_hdf5_path': '/storage/yangjianLab/songliyang/SpatialData/Data/Brain/macaque/Cell/processed/h5ad/T862_macaque3.h5ad',
-        'label_w': 1.0,
-        'n_comps': 300,
-        'n_neighbors': 11,
-        'nheads': 3,
-        'output_hdf5_path': 'T862_macaque3/find_latent_representations/T862_macaque3_add_latent.h5ad',
-        'p_drop': 0.1,
-        'rec_w': 1.0,
-        'sample_name': 'T862_macaque3',
-        'type': 'SCT',
-        'var': False,
-        'weighted_adj': False})
-    # config=FindLatentRepresentationsConfig(**vars(args))
-    start_time = time.time()
-    logger.info(f'Find latent representations for {config.sample_name}...')
-    pprint.pprint(config)
-    run_find_latent_representation(config)
-    end_time = time.time()
-    logger.info(f'Find latent representations for {config.sample_name} finished. Time spent: {(end_time - start_time) / 60:.2f} min.')
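Beyond the removal of the CLI/test harness and the switch from print to module-level logging, the behavioral changes in this file are that RNG seeding now happens inside run_find_latent_representation() rather than at import time, and the config field `type` is replaced by `data_layer`, with the input matrix falling back to adata.X when the named layer is absent. The sketch below illustrates that fallback pattern only; the AnnData object and the `data_layer` value are hypothetical stand-ins, not the gsMap API.

```python
import anndata as ad
import numpy as np

# Hypothetical AnnData with a raw-count layer, to illustrate the
# `data_layer` handling introduced in this release.
rng = np.random.default_rng(0)
counts = rng.poisson(1.0, size=(100, 50)).astype(np.float32)
adata = ad.AnnData(X=np.log1p(counts), layers={'counts': counts})

# In gsMap this value comes from FindLatentRepresentationsConfig.data_layer.
data_layer = 'counts'

# Same pattern as the updated run_find_latent_representation():
# use the named layer when present, otherwise keep adata.X as-is.
adata.X = adata.layers[data_layer] if data_layer in adata.layers.keys() else adata.X
print(adata)
```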