gsMap3D 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,312 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import anndata as ad
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pandas as pd
|
|
8
|
+
import pyvista as pv
|
|
9
|
+
import statsmodels.api as sm
|
|
10
|
+
import statsmodels.stats.multitest as smm
|
|
11
|
+
from scipy.stats import fisher_exact
|
|
12
|
+
|
|
13
|
+
from gsMap.cauchy_combination_test import _acat_test
|
|
14
|
+
from gsMap.config import ThreeDCombineConfig
|
|
15
|
+
|
|
16
|
+
from .three_d_plot.three_d_plots import three_d_plot, three_d_plot_save
|
|
17
|
+
|
|
18
|
+
# Start a virtual X framebuffer as soon as this module is imported, presumably so the
# PyVista plotting below can render on headless machines.
# NOTE(review): this is an import-time side effect for every importer of this module and
# presumably requires an Xvfb binary on the host — confirm for non-Linux platforms.
pv.start_xvfb()
# Module-level logger shared by all functions in this file.
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def combine_ldsc(args):
    """Merge the per-section spatial LDSC results of one trait into a single table.

    Every entry ``<st_id>`` of ``<project_dir>/spatial_ldsc`` that contains a
    ``<st_id>_<trait_name>.csv.gz`` file contributes its rows, tagged with the
    section name in an ``ST_id`` column. The merged table is cached at
    ``<project_dir>/3D_combine/spatial_ldsc/<trait_name>.csv.gz``; when that file
    already exists it is loaded instead of being rebuilt.

    Parameters
    ----------
    args
        Object exposing ``project_dir`` and ``trait_name``.

    Returns
    -------
    pandas.DataFrame
        The merged results plus a ``spot_index`` column: ``"<ST_id>_<spot>"``
        when spot names collide across sections, otherwise the ``spot`` value.

    Raises
    ------
    FileNotFoundError
        If no cached merge exists and no per-section result file is found.
    """
    ldsc_root = Path(args.project_dir) / "3D_combine" / "spatial_ldsc"
    ldsc_root.mkdir(parents=True, exist_ok=True)
    name = ldsc_root / f"{args.trait_name}.csv.gz"

    if name.exists():
        logger.info(f"The merged gsMap results already exist, loading the merged results from {name}...")
        return pd.read_csv(name, compression="gzip")

    # Collect the per-section result files. The listing is sorted so the merge
    # order (and hence the cached file) is deterministic across file systems.
    pth = Path(args.project_dir) / "spatial_ldsc"
    sldsc_files = []
    for st_id in sorted(os.listdir(pth)):
        candidate = pth / st_id / f"{st_id}_{args.trait_name}.csv.gz"
        if candidate.exists():
            sldsc_files.append((st_id, candidate))

    logger.info(f"Find {len(sldsc_files)} ST sections for {args.trait_name}, start to merge the results...")
    if not sldsc_files:
        raise FileNotFoundError(f"No spatial LDSC results found for {args.trait_name} under {pth}.")

    frames = []
    for st_id, file in sldsc_files:
        ldsc_temp = pd.read_csv(file, compression="gzip")
        # Use the directory name as the section id: splitting the file name on
        # "_<trait_name>" breaks when the section name itself contains that text.
        ldsc_temp["ST_id"] = st_id
        frames.append(ldsc_temp)
    # One concat instead of concatenating inside the loop (which is quadratic).
    ldsc_merge = pd.concat(frames, axis=0, ignore_index=True)

    # Spot names may repeat across sections; prefix them with the section id then.
    if ldsc_merge["spot"].duplicated().any():
        logger.info('There are duplicated spot names, using the st_id + spot_id as the spot index.')
        ldsc_merge['spot_index'] = ldsc_merge['ST_id'] + '_' + ldsc_merge['spot'].astype(str)
    else:
        ldsc_merge['spot_index'] = ldsc_merge['spot']

    ldsc_merge.to_csv(name, compression="gzip", index=False)
    logger.info(f"Saving the 3D merged results to {name}")
    return ldsc_merge
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def cauchy_combination_3d(ldsc, outlier_iqr_factor=3.0, max_outlier_fraction=0.05):
    """Combine spot-level p-values into one p-value per annotation (Cauchy / ACAT test).

    For each unique value of ``ldsc["annotation"]``, spots with a missing ``p``
    are dropped. Spots whose -log10(p) is at least
    ``median + outlier_iqr_factor * IQR`` are flagged as outliers. The flagged
    spots are excluded from the Cauchy test only when they are fewer than
    ``max_outlier_fraction`` of the group.

    Parameters
    ----------
    ldsc : pandas.DataFrame
        Must contain the columns ``annotation``, ``p`` and ``z``.
    outlier_iqr_factor : float, default 3.0
        IQR multiplier of the outlier threshold on the -log10(p) scale.
    max_outlier_fraction : float, default 0.05
        Outliers are removed only if they make up less than this fraction.

    Returns
    -------
    pandas.DataFrame
        Columns ``p_cauchy``, ``p_median``, ``inflation_factor`` and
        ``annotation``; one row per annotation, sorted by ``p_cauchy``.
        Annotations with no valid p-value get NaN statistics.
    """
    annotations = np.unique(ldsc["annotation"])
    p_cauchy, p_median, gc_median = [], [], []

    for ct in annotations:
        # Drop missing p-values together with their z-scores so both stay aligned.
        subset = ldsc.loc[ldsc["annotation"] == ct, ["p", "z"]]
        subset = subset[subset["p"].notna()]
        p_temp = subset["p"]
        z_temp = subset["z"]

        if len(p_temp) == 0:
            logger.warning(f"No valid p-values for {ct}; reporting NaN statistics.")
            p_cauchy.append(np.nan)
            p_median.append(np.nan)
            gc_median.append(np.nan)
            continue

        # The Cauchy test is dominated by its smallest p-values, so extreme
        # outliers on the -log10(p) scale are candidates for removal.
        p_temp_log = -np.log10(p_temp.to_numpy())
        median_log = np.median(p_temp_log)
        iqr_log = np.percentile(p_temp_log, 75) - np.percentile(p_temp_log, 25)
        # Positional numpy mask: no index alignment, so NaN-dropped rows and
        # duplicated spot labels cannot break the selection.
        keep = p_temp_log < median_log + outlier_iqr_factor * iqr_log
        p_use = p_temp[keep]
        z_use = z_temp[keep]
        n_remove = len(p_temp) - len(p_use)

        if 0 < n_remove < len(p_temp) * max_outlier_fraction:
            logger.info(
                f"Remove {n_remove}/{len(p_temp)} outliers "
                f"(median + {outlier_iqr_factor:g}*IQR) for {ct}."
            )
            p_cauchy_temp = _acat_test(p_use)
        else:
            p_cauchy_temp = _acat_test(p_temp)

        p_cauchy.append(p_cauchy_temp)
        # The median p-value and the inflation factor are computed on the
        # outlier-filtered set regardless of the branch above (unchanged behavior).
        p_median.append(np.median(p_use))
        # 0.4549 is approximately the median of a chi-square(1) distribution.
        gc_median.append(np.median(z_use**2) / 0.4549)

    p_tissue = pd.DataFrame(
        {
            "p_cauchy": p_cauchy,
            "p_median": p_median,
            "inflation_factor": gc_median,
            "annotation": annotations,
        }
    )
    return p_tissue.sort_values("p_cauchy")
|
|
110
|
+
|
|
111
|
+
# def cauchy_combination_3d(args):
|
|
112
|
+
|
|
113
|
+
# # Load the cauchy combination results of each ST slices
|
|
114
|
+
# pth = Path(args.project_dir) / "cauchy_combination"
|
|
115
|
+
# st_file = os.listdir(pth)
|
|
116
|
+
# logger.info(f"Find {len(st_file)} sections of cauchy combination results for {args.trait_name}...")
|
|
117
|
+
|
|
118
|
+
# cauchy_all = pd.DataFrame()
|
|
119
|
+
# for slice in st_file:
|
|
120
|
+
# filtemp = pth / slice / f"{slice}_{args.trait_name}.Cauchy.csv.gz"
|
|
121
|
+
# if filtemp.exists():
|
|
122
|
+
# cauchy = pd.read_csv(filtemp, compression="gzip")
|
|
123
|
+
# cauchy_all = pd.concat([cauchy_all, cauchy], axis=0)
|
|
124
|
+
|
|
125
|
+
# cauchy_all = cauchy_all[~cauchy_all.annotation.isna()]
|
|
126
|
+
|
|
127
|
+
# # Do the cauchy combination test across slices
|
|
128
|
+
# p_cauchy = []
|
|
129
|
+
# p_median = []
|
|
130
|
+
# for ct in cauchy_all.annotation.unique():
|
|
131
|
+
# cauchy_temp = cauchy_all.loc[cauchy_all.annotation == ct]
|
|
132
|
+
# p_cauchy_temp = cauchy_temp.p_cauchy
|
|
133
|
+
# p_median_temp = cauchy_temp.p_median
|
|
134
|
+
# n_cell = cauchy_temp.n_cell
|
|
135
|
+
|
|
136
|
+
# p_cauchy_temp_log = -np.log10(p_cauchy_temp)
|
|
137
|
+
# median_log = np.median(p_cauchy_temp_log)
|
|
138
|
+
# IQR_log = np.percentile(p_cauchy_temp_log, 75) - np.percentile(p_cauchy_temp_log, 25)
|
|
139
|
+
|
|
140
|
+
# w_use = n_cell
|
|
141
|
+
# p_use = p_cauchy_temp
|
|
142
|
+
# if len(p_cauchy_temp) > 15:
|
|
143
|
+
# index = p_cauchy_temp_log < median_log + 2*IQR_log
|
|
144
|
+
# w_use = n_cell[index]
|
|
145
|
+
# p_use = p_cauchy_temp[index]
|
|
146
|
+
# n_remove = len(p_cauchy_temp) - len(p_use)
|
|
147
|
+
# if n_remove > 0:
|
|
148
|
+
# logger.info(f"Remove {n_remove} outlier (median + 2*IQR) sections for {ct}")
|
|
149
|
+
|
|
150
|
+
# p_cauchy_new = acat_test(pvalues=p_use.to_list(),weights=w_use.to_list())
|
|
151
|
+
# p_median_new = (p_median_temp * n_cell / n_cell.sum()).sum()
|
|
152
|
+
|
|
153
|
+
# p_cauchy.append(p_cauchy_new)
|
|
154
|
+
# p_median.append(p_median_new)
|
|
155
|
+
|
|
156
|
+
# data = {
|
|
157
|
+
# "p_cauchy": p_cauchy,
|
|
158
|
+
# "p_median": p_median,
|
|
159
|
+
# "annotation": cauchy_all.annotation.unique(),
|
|
160
|
+
# }
|
|
161
|
+
# p_tissue = pd.DataFrame(data)
|
|
162
|
+
# p_tissue.columns = ["p_cauchy", "p_median", "annotation"]
|
|
163
|
+
# p_tissue.sort_values("p_cauchy", inplace=True)
|
|
164
|
+
# return p_tissue
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def odds_test_3d(ldsc_merge, alpha=0.05, correction_method="fdr_bh"):
    """Test every annotation for enrichment of significant spots (Fisher's exact test).

    The spot p-values are corrected for multiple testing and stored in a
    ``p_fdr`` column of ``ldsc_merge``; this mutates the input frame. Spots
    with ``p_fdr < alpha`` count as significant. For each annotation, a 2x2
    table (focal vs. other annotations x significant vs. not) is tested.

    Parameters
    ----------
    ldsc_merge : pandas.DataFrame
        Must contain ``p`` and ``annotation`` columns.
    alpha : float, default 0.05
        Significance threshold applied to the corrected p-values.
    correction_method : str, default "fdr_bh"
        ``statsmodels.stats.multitest.multipletests`` method (Benjamini-Hochberg
        FDR by default, matching the ``p_fdr`` column name). The statsmodels
        default is "hs" (Holm-Sidak); pass it to reproduce results computed
        with that default.

    Returns
    -------
    pandas.DataFrame
        One row per annotation with ``odds_ratio``, ``95%_ci_low`` and
        ``95%_ci_high`` formatted as strings with three decimals, plus the
        Fisher ``p_odds_ratio``. If the table has an empty row or column, the
        odds ratio is undefined and 0 / (0, 0) / p = 1 are reported.
    """
    _, corrected_p_values, _, _ = smm.multipletests(ldsc_merge["p"], alpha=alpha, method=correction_method)
    ldsc_merge['p_fdr'] = corrected_p_values.tolist()
    significant = (ldsc_merge['p_fdr'] < alpha).to_numpy()

    records = []
    for focal_annotation in ldsc_merge['annotation'].unique():
        is_focal = (ldsc_merge['annotation'] == focal_annotation).to_numpy()
        # Count explicitly. value_counts() is ordered by frequency, so unpacking
        # it into (no, yes) swapped the counts whenever significant spots were
        # the majority, and it failed whenever one category was absent.
        focal_yes = int(np.count_nonzero(significant & is_focal))
        focal_no = int(np.count_nonzero(is_focal)) - focal_yes
        other_yes = int(np.count_nonzero(significant & ~is_focal))
        other_no = int(np.count_nonzero(~is_focal)) - other_yes
        contingency_table = [[focal_yes, focal_no], [other_yes, other_no]]

        has_empty_margin = (
            focal_yes + focal_no == 0
            or other_yes + other_no == 0
            or focal_yes + other_yes == 0
            or focal_no + other_no == 0
        )
        if has_empty_margin:
            # The odds ratio is undefined; keep the historical sentinel values.
            odds_ratio, p_value, conf_int = 0, 1, (0, 0)
        else:
            odds_ratio, p_value = fisher_exact(contingency_table)
            # shift_zeros adds 0.5 to every cell when any cell is zero so the CI stays finite.
            conf_int = sm.stats.Table2x2(contingency_table, shift_zeros=True).oddsratio_confint()

        records.append({
            'annotation': focal_annotation,
            'odds_ratio': f"{odds_ratio:.3f}",
            '95%_ci_low': f"{conf_int[0]:.3f}",
            '95%_ci_high': f"{conf_int[1]:.3f}",
            'p_odds_ratio': p_value
        })
    return pd.DataFrame(records)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _load_3d_metadata(args):
    """Load per-cell metadata (with sx/sy/sz columns for h5ad input) from ``args.adata_3d``.

    A ``.parquet`` file is read as-is. For an ``.h5ad`` file, the coordinates in
    ``obsm[args.spatial_key]`` become ``sx``/``sy``/``sz`` columns next to
    ``obs``, and the result is indexed by ``obs_names``.

    Raises
    ------
    ValueError
        If the path has neither a ``.parquet`` nor an ``.h5ad`` suffix.
    """
    logger.info(f"Loading {args.adata_3d}.")
    adata_3d_path = str(args.adata_3d)
    if adata_3d_path.endswith('.parquet'):
        logger.info("The input data is the metadata file of adata.")
        return pd.read_parquet(adata_3d_path)
    if adata_3d_path.endswith('.h5ad'):
        logger.info("The input data is the h5ad.")
        adata_merge = ad.read_h5ad(adata_3d_path, backed='r')
        try:
            adata_merge.obs.index.name = 'index'
            spatial = pd.DataFrame(
                adata_merge.obsm[args.spatial_key],
                columns=['sx', 'sy', 'sz'],
                index=adata_merge.obs_names,
            ).reset_index()
            meta = adata_merge.obs.copy()
            meta_merged = spatial.merge(meta, left_on='index', right_index=True, how='left')
            meta_merged.index = adata_merge.obs_names
        finally:
            # backed='r' keeps the HDF5 file open; everything needed is now in memory.
            adata_merge.file.close()
        return meta_merged
    raise ValueError(
        f"Unsupported 3D spatial data file: {adata_3d_path}. Expected a '.parquet' or '.h5ad' file."
    )


def _opacity_by_logp(logp, n_bins=4):
    """Map -log10(p) values to per-spot opacities growing exponentially over equal-width bins."""
    bins = np.linspace(logp.min(), logp.max(), n_bins + 1)
    if not np.all(np.diff(bins) > 0):
        # All values identical (or non-finite): binning is undefined, draw everything opaque.
        return [1.0] * len(logp)
    alpha = np.exp(np.linspace(0.1, 1.0, num=n_bins)) - 1
    alpha = alpha / max(alpha)
    return pd.cut(logp, bins=bins, labels=alpha, include_lowest=True).values.tolist()


def three_d_combine(args: ThreeDCombineConfig):
    """Combine the per-section spatial LDSC results of one trait over a 3D dataset.

    The function runs four steps:

    1. Merge the per-section results (``combine_ldsc``).
    2. Align them with the 3D metadata.
    3. If ``args.annotation`` is set, run per-annotation Cauchy-combination and
       odds-ratio tests and write them to
       ``<project_dir>/3D_combine/cauchy_combination``.
    4. If ``sx``/``sy``/``sz`` coordinates are present, render a 3D -log10(p)
       plot into ``<project_dir>/3D_combine/3D_plot``.

    Raises
    ------
    ValueError
        If the spatial data format is unsupported, or no cell is shared between
        the spatial data and the LDSC results.
    """
    ldsc_merge = combine_ldsc(args)
    # Spot ids that round-tripped through CSV as floats come back as e.g. "12.0".
    # Strip only that trailing ".0"; an unanchored pattern also mangled ids
    # such as "cell.05".
    ldsc_merge["spot_index"] = ldsc_merge["spot_index"].astype(str).str.replace(r"\.0$", "", regex=True)
    ldsc_merge.index = ldsc_merge["spot_index"]

    meta_merged = _load_3d_metadata(args)

    # If cell names repeat across sections and none match the LDSC ids directly,
    # build "<st_id>_<obs_name>" ids, the same form combine_ldsc uses for duplicates.
    if args.st_id is not None and meta_merged.index.has_duplicates:
        if len(np.intersect1d(ldsc_merge.index, meta_merged.index)) == 0:
            logger.info(f"Using {args.st_id} + adata.obs_names as the new cell index.")
            meta_merged.index = meta_merged[args.st_id].astype(str) + '_' + meta_merged.index.astype(str)

    common_cell = np.intersect1d(ldsc_merge.index, meta_merged.index)
    if len(common_cell) == 0:
        raise ValueError("No common cells between the spatial data and the ldsc results.")
    logger.info(f"Found {len(common_cell)} common cells between the 3D spatial data and the mapping results.")

    meta_merged = meta_merged.loc[common_cell].copy()
    ldsc_merge = ldsc_merge.loc[common_cell].copy()

    if args.annotation is not None:
        ldsc_merge['annotation'] = meta_merged[args.annotation]
        ldsc_merge = ldsc_merge[ldsc_merge['annotation'].notna()].copy()

        cauchy = cauchy_combination_3d(ldsc_merge)
        odds = odds_test_3d(ldsc_merge)
        cauchy_odds = pd.merge(odds, cauchy, on='annotation')

        cauchy_root = Path(args.project_dir) / "3D_combine" / "cauchy_combination"
        cauchy_root.mkdir(parents=True, exist_ok=True, mode=0o755)
        cauchy_name = cauchy_root / f"{args.trait_name}.{args.annotation}.Cauchy.csv.gz"
        # odds_ratio holds formatted strings; sort on the numeric value, not lexicographically.
        cauchy_odds = cauchy_odds.sort_values(
            'odds_ratio', ascending=False, key=lambda col: pd.to_numeric(col, errors='coerce')
        )
        cauchy_odds.to_csv(cauchy_name, compression="gzip", index=False)
        logger.info(f"Saving the 3D cauchy combination results to {cauchy_name}")
    else:
        logger.info("No annotation provided for the cauchy combination test.")

    p_color = ['#313695', '#4575b4', '#74add1', '#fee090', '#fdae61', '#f46d43', '#d73027', '#a50026']
    meta_merged["logp"] = -np.log10(ldsc_merge['p'])

    if not {'sx', 'sy', 'sz'}.issubset(meta_merged.columns):
        logger.info("The spatial data does not contain 3D spatial coordinates for 3D plotting.")
        return

    logger.info("Generating 3D plot...")
    legend_kwargs = dict(scalar_bar_title_size=30, scalar_bar_label_size=30, fmt="%.1e")
    text_kwargs = dict(text_font_size=15, text_loc="upper_edge")

    # Spots without a p-value (e.g. dropped by the annotation filter) get logp = 0.
    # Assign rather than fillna(inplace=True) on a column: that chained
    # assignment is silently ignored under pandas Copy-on-Write.
    meta_merged['logp'] = meta_merged['logp'].fillna(0)
    opacity_show = _opacity_by_logp(meta_merged['logp'])

    # Upper colour limit: rounded median of the 20 largest -log10(p) values.
    max_v = np.round(np.median(np.sort(meta_merged['logp'])[::-1][0:20]))

    plotter = three_d_plot(
        clim=[0, max_v],
        point_size=args.point_size,
        opacity=opacity_show,
        window_size=(1200, 1008),
        adata=meta_merged,
        spatial_key=args.spatial_key,
        keys=["logp"],
        cmaps=[args.cmap] if args.cmap is not None else [p_color],
        scalar_bar_titles=["-log10(p)"],
        texts=[args.trait_name],
        jupyter=False,
        background=args.background_color,
        show_outline=args.show_outline,
        legend_kwargs=legend_kwargs,
        text_kwargs=text_kwargs,
    )

    plot_root = Path(args.project_dir) / "3D_combine" / "3D_plot"
    plot_root.mkdir(parents=True, exist_ok=True, mode=0o755)
    plot_name = plot_root / args.trait_name

    three_d_plot_save(
        plotter,
        save_mp4=args.save_mp4,
        save_gif=args.save_gif,
        n_points=args.n_snapshot if args.n_snapshot is not None else 200,
        filename=plot_name,
    )
|
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
import matplotlib as mpl
|
|
2
|
+
import numpy as np
|
|
3
|
+
from pyvista import MultiBlock
|
|
4
|
+
|
|
5
|
+
# Legend placements listed as valid for categorical legends.
categorical_legend_loc_legal = [
    "upper right",
    "upper left",
    "lower left",
    "lower right",
    "center left",
    "center right",
    "lower center",
    "upper center",
    "center",
]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def add_model(
    plotter,
    model,
    key=None,
    colormap=None,
    clim=None,
    ambient=0.2,
    opacity=1.0,
    model_style="surface",
    point_size=3.0,
):
    """Add a PyVista/VTK model to ``plotter`` via ``plotter.add_mesh``.

    Parameters
    ----------
    plotter
        Object whose ``add_mesh`` is called.
    model
        Mesh to add; its ``array_names`` and ``active_scalars_name`` are read.
    key : str, optional
        Scalar array to color by. If ``key`` is not an array of ``model``,
        the model's active scalars are used instead.
    colormap : optional
        If None, the precomputed ``f"{key}_rgba"`` array is drawn as RGBA
        colors. Otherwise ``key`` is mapped through this ``cmap``.
    clim : optional
        Color limits forwarded to ``add_mesh``.
    ambient, opacity : float or array-like
        Lighting and opacity forwarded to ``add_mesh``.
    model_style : str, default "surface"
        ``add_mesh`` style; "wireframe" additionally renders lines as tubes.
    point_size : float, default 3.0
        Used as both the point size and the line width.
    """
    # Only tube rendering depends on the style. Points are always drawn as
    # spheres and smooth shading is always on (the previously computed
    # style-dependent values for those two settings were never used).
    render_tubes = model_style == "wireframe"

    mesh_kwargs = dict(
        style=model_style,
        render_points_as_spheres=True,
        render_lines_as_tubes=render_tubes,
        point_size=point_size,
        line_width=point_size,
        ambient=ambient,
        opacity=opacity,
        smooth_shading=True,
        clim=clim,
        show_scalar_bar=False,
    )

    has_key = key in model.array_names
    if colormap is None:
        # Categorical coloring: use the precomputed per-point RGBA array.
        mesh_kwargs["scalars"] = f"{key}_rgba" if has_key else model.active_scalars_name
        mesh_kwargs["rgba"] = True
    else:
        mesh_kwargs["scalars"] = key if has_key else model.active_scalars_name
        mesh_kwargs["cmap"] = colormap

    plotter.add_mesh(model, **mesh_kwargs)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def add_str_legend(
    plotter,
    labels,
    colors,
    font_family='arial',
    legend_size=None,
    legend_loc="center right"
):
    """Add a categorical (label -> color) legend to ``plotter``.

    Entries labelled ``"mask"`` are dropped. The rest are sorted by label and
    then by color, and passed to ``plotter.add_legend``.

    Parameters
    ----------
    labels, colors : numpy.ndarray
        Parallel arrays of legend labels and color strings.
    font_family : str, default 'arial'
        Legend font.
    legend_size : tuple of float, optional
        (width, height) of the legend. By default it grows with the number of
        entries, capped at 10 entries.
    legend_loc : str, default "center right"
        Forwarded to ``add_legend`` as ``loc``.

    Raises
    ------
    ValueError
        If no entry remains after removing ``"mask"`` labels.
    """
    legend_data = np.concatenate(
        [labels.reshape(-1, 1).astype(object), colors.reshape(-1, 1).astype(object)], axis=1)
    legend_data = legend_data[legend_data[:, 0] != "mask", :]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(legend_data) == 0:
        raise ValueError("No legend can be added, please set `show_legend=False`.")

    # np.lexsort treats the last key as primary: sort by label, then by color.
    legend_entries = legend_data[np.lexsort(legend_data[:, ::-1].T)]
    if legend_size is None:
        legend_num = min(len(legend_entries), 10)
        legend_size = (0.1 + 0.01 * legend_num, 0.1 + 0.012 * legend_num)

    plotter.add_legend(
        legend_entries.tolist(),
        face="none",
        font_family=font_family,
        bcolor=None,
        loc=legend_loc,
        size=legend_size
    )
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def add_num_legend(
    plotter,
    title="",
    n_labels=5,
    title_font_size=None,
    label_font_size=None,
    font_color="black",
    font_family="arial",
    legend_size=(0.1, 0.4),
    legend_loc=(0.85, 0.3),
    vertical=True,
    fmt="%.2e",
):
    """Attach a continuous scalar bar to ``plotter``.

    ``legend_size`` supplies (width, height) and ``legend_loc`` supplies
    (position_x, position_y). The other options are forwarded unchanged to
    ``plotter.add_scalar_bar``, with opacity mapping enabled.
    """
    width, height = legend_size[0], legend_size[1]
    position_x, position_y = legend_loc[0], legend_loc[1]

    bar_options = {
        "title": title,
        "n_labels": n_labels,
        "title_font_size": title_font_size,
        "label_font_size": label_font_size,
        "color": font_color,
        "font_family": font_family,
        "use_opacity": True,
        "width": width,
        "height": height,
        "position_x": position_x,
        "position_y": position_y,
        "vertical": vertical,
        "fmt": fmt,
    }
    plotter.add_scalar_bar(**bar_options)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def add_legend(
    plotter,
    model,
    key=None,
    colormap=None,
    categorical_legend_size=None,
    categorical_legend_loc=None,
    scalar_bar_title="",
    scalar_bar_size=None,
    scalar_bar_loc=None,
    scalar_bar_title_size=None,
    scalar_bar_label_size=None,
    scalar_bar_font_color="black",
    font_family="arial",
    fmt="%.2e",
    scalar_bar_n_labels=5,
    vertical=True,
):
    """Add either a categorical legend or a continuous scalar bar to ``plotter``.

    If ``colormap`` is None, a categorical legend is built. The labels come
    from ``model[key]`` and the colors from ``model[f"{key}_rgba"]``. For a
    ``MultiBlock``, one key per block is used (a single key is repeated).
    Otherwise a scalar bar is added.

    Parameters
    ----------
    key : str or list of str, optional
        Required when ``colormap`` is None.
    categorical_legend_loc : str, optional
        Defaults to "center right". Other values are passed through unchanged
        (not validated against ``categorical_legend_loc_legal``).
    scalar_bar_size, scalar_bar_loc : tuple of float, optional
        Default to (0.1, 0.4) and (0.85, 0.3).

    Raises
    ------
    ValueError
        If ``colormap`` and ``key`` are both None.
    """
    if colormap is None:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if key is None:
            raise ValueError("When colormap is None, key cannot be None at the same time.")
        if categorical_legend_loc is None:
            categorical_legend_loc = 'center right'

        if isinstance(model, MultiBlock):
            keys = key if isinstance(key, list) else [key] * len(model)
            label_parts, color_parts = [], []
            for m, k in zip(model, keys, strict=False):
                label_parts.append(np.asarray(m[k]).flatten())
                color_parts.append(np.asarray(
                    [mpl.colors.to_hex(i) for i in m[f"{k}_rgba"]]).flatten())
            legend_label_data = np.concatenate(label_parts, axis=0)
            legend_color_data = np.concatenate(color_parts, axis=0)
        else:
            legend_label_data = np.asarray(model[key]).flatten()
            legend_color_data = np.asarray(
                [mpl.colors.to_hex(i) for i in model[f"{key}_rgba"]]).flatten()

        # One legend entry per distinct (label, color) pair.
        legend_data = np.concatenate(
            [legend_label_data.reshape(-1, 1), legend_color_data.reshape(-1, 1)], axis=1)
        unique_legend_data = np.unique(legend_data, axis=0)

        add_str_legend(
            plotter=plotter,
            labels=unique_legend_data[:, 0],
            colors=unique_legend_data[:, 1],
            font_family=font_family,
            legend_size=categorical_legend_size,
            legend_loc=categorical_legend_loc
        )
    else:
        if scalar_bar_size is None:
            scalar_bar_size = (0.1, 0.4)
        if scalar_bar_loc is None:
            scalar_bar_loc = (0.85, 0.3)

        add_num_legend(
            plotter=plotter,
            legend_size=scalar_bar_size,
            legend_loc=scalar_bar_loc,
            title=scalar_bar_title,
            n_labels=scalar_bar_n_labels,
            title_font_size=scalar_bar_title_size,
            label_font_size=scalar_bar_label_size,
            font_color=scalar_bar_font_color,
            font_family=font_family,
            fmt=fmt,
            vertical=vertical
        )
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def add_outline(
    plotter,
    model,
    outline_width=1.0,
    outline_color="black",
):
    """Draw a bounding box on ``plotter`` with the given line width and color.

    Uses ``plotter.add_bounding_box``, which per PyVista spans the plot's bounds.
    ``model`` is not consulted and is kept only for signature compatibility.
    The previous ``model.outline()`` call built an outline mesh that was then
    discarded, so it has been removed.
    """
    plotter.add_bounding_box(
        color=outline_color,
        line_width=outline_width
    )
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def add_text(
    plotter,
    text,
    font_family="arial",
    text_font_size=15,
    text_font_color="black",
    text_loc="upper_edge"
):
    """Write ``text`` onto ``plotter`` via ``plotter.add_text``.

    A ``text_loc`` of None falls back to the "upper_edge" position.
    """
    position = "upper_edge" if text_loc is None else text_loc
    plotter.add_text(
        text=text,
        font=font_family,
        color=text_font_color,
        font_size=text_font_size,
        position=position,
    )
|