gsMap3D 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/report/three_d_combine.py
@@ -0,0 +1,312 @@
+ import logging
+ import os
+ from pathlib import Path
+
+ import anndata as ad
+ import numpy as np
+ import pandas as pd
+ import pyvista as pv
+ import statsmodels.api as sm
+ import statsmodels.stats.multitest as smm
+ from scipy.stats import fisher_exact
+
+ from gsMap.cauchy_combination_test import _acat_test
+ from gsMap.config import ThreeDCombineConfig
+
+ from .three_d_plot.three_d_plots import three_d_plot, three_d_plot_save
+
+ pv.start_xvfb()
+ logger = logging.getLogger(__name__)
+
+
+ def combine_ldsc(args):
+
+     # Set the output path
+     ldsc_root = Path(args.project_dir) / "3D_combine" / "spatial_ldsc"
+     ldsc_root.mkdir(parents=True, exist_ok=True)
+     name = ldsc_root / f"{args.trait_name}.csv.gz"
+
+     # Collect the per-section ldsc results
+     pth = Path(args.project_dir) / "spatial_ldsc"
+     sldsc_pth = []
+     for section in os.listdir(pth):
+         filtemp = pth / section / f"{section}_{args.trait_name}.csv.gz"
+         if filtemp.exists():
+             sldsc_pth.append(filtemp)
+
+     if not name.exists():
+         logger.info(f"Found {len(sldsc_pth)} ST sections for {args.trait_name}; merging the results...")
+         # Load and merge the per-section results
+         frames = []
+         for file in sldsc_pth:
+             ldsc_temp = pd.read_csv(file, compression="gzip")
+             ldsc_temp["ST_id"] = file.name.split(f"_{args.trait_name}")[0]
+             frames.append(ldsc_temp)
+         ldsc_merge = pd.concat(frames, axis=0)
+
+         # Check for duplicated spot names
+         if (ldsc_merge.spot.value_counts() > 1).any():
+             logger.info("Duplicated spot names found; using ST_id + spot_id as the spot index.")
+             ldsc_merge["spot_index"] = ldsc_merge["ST_id"] + "_" + ldsc_merge["spot"].astype(str)
+         else:
+             ldsc_merge["spot_index"] = ldsc_merge["spot"]
+
+         # Save the merged results
+         ldsc_merge.to_csv(name, compression="gzip", index=False)
+         logger.info(f"Saved the merged 3D results to {name}")
+     else:
+         logger.info(f"The merged gsMap results already exist; loading them from {name}...")
+         ldsc_merge = pd.read_csv(name, compression="gzip")
+
+     return ldsc_merge
+
+
+ def cauchy_combination_3d(ldsc):
+     p_cauchy = []
+     p_median = []
+     gc_median = []
+     for ct in np.unique(ldsc.annotation):
+         p_temp = ldsc.loc[ldsc["annotation"] == ct, "p"]
+         z_temp = ldsc.loc[ldsc["annotation"] == ct, "z"]
+         p_temp = p_temp.dropna()
+
+         # The Cauchy test is sensitive to very small p-values, so extreme outliers should be considered for removal.
+         p_temp_log = -np.log10(p_temp)
+         median_log = np.median(p_temp_log)
+         IQR_log = np.percentile(p_temp_log, 75) - np.percentile(p_temp_log, 25)
+
+         keep = p_temp_log < median_log + 3 * IQR_log
+         p_use = p_temp[keep]
+         z_use = z_temp[keep]
+         n_remove = len(p_temp) - len(p_use)
+
+         # Drop outliers (-log10(p) > median + 3*IQR) only when they make up less than 5% of the spots.
+         # if 0 < n_remove < max(len(p_temp) * 0.001, 100):
+         if 0 < n_remove < len(p_temp) * 0.05:
+             logger.info(f"Remove {n_remove}/{len(p_temp)} outliers (median + 3*IQR) for {ct}.")
+             p_cauchy_temp = _acat_test(p_use)
+         else:
+             p_cauchy_temp = _acat_test(p_temp)
+
+         p_median_temp = np.median(p_use)
+         # 0.4549 is the median of a chi-squared distribution with 1 degree of freedom.
+         gc_median_temp = np.median(z_use**2) / 0.4549
+
+         p_cauchy.append(p_cauchy_temp)
+         p_median.append(p_median_temp)
+         gc_median.append(gc_median_temp)
+
+     data = {
+         "p_cauchy": p_cauchy,
+         "p_median": p_median,
+         "inflation_factor": gc_median,
+         "annotation": np.unique(ldsc.annotation),
+     }
+     p_tissue = pd.DataFrame(data)
+     p_tissue.sort_values("p_cauchy", inplace=True)
+     return p_tissue
+
+
+ # def cauchy_combination_3d(args):
+
+ #     # Load the cauchy combination results of each ST slice
+ #     pth = Path(args.project_dir) / "cauchy_combination"
+ #     st_file = os.listdir(pth)
+ #     logger.info(f"Find {len(st_file)} sections of cauchy combination results for {args.trait_name}...")
+
+ #     cauchy_all = pd.DataFrame()
+ #     for slice in st_file:
+ #         filtemp = pth / slice / f"{slice}_{args.trait_name}.Cauchy.csv.gz"
+ #         if filtemp.exists():
+ #             cauchy = pd.read_csv(filtemp, compression="gzip")
+ #             cauchy_all = pd.concat([cauchy_all, cauchy], axis=0)
+
+ #     cauchy_all = cauchy_all[~cauchy_all.annotation.isna()]
+
+ #     # Do the cauchy combination test across slices
+ #     p_cauchy = []
+ #     p_median = []
+ #     for ct in cauchy_all.annotation.unique():
+ #         cauchy_temp = cauchy_all.loc[cauchy_all.annotation == ct]
+ #         p_cauchy_temp = cauchy_temp.p_cauchy
+ #         p_median_temp = cauchy_temp.p_median
+ #         n_cell = cauchy_temp.n_cell
+
+ #         p_cauchy_temp_log = -np.log10(p_cauchy_temp)
+ #         median_log = np.median(p_cauchy_temp_log)
+ #         IQR_log = np.percentile(p_cauchy_temp_log, 75) - np.percentile(p_cauchy_temp_log, 25)
+
+ #         w_use = n_cell
+ #         p_use = p_cauchy_temp
+ #         if len(p_cauchy_temp) > 15:
+ #             index = p_cauchy_temp_log < median_log + 2 * IQR_log
+ #             w_use = n_cell[index]
+ #             p_use = p_cauchy_temp[index]
+ #             n_remove = len(p_cauchy_temp) - len(p_use)
+ #             if n_remove > 0:
+ #                 logger.info(f"Remove {n_remove} outlier (median + 2*IQR) sections for {ct}")
+
+ #         p_cauchy_new = acat_test(pvalues=p_use.to_list(), weights=w_use.to_list())
+ #         p_median_new = (p_median_temp * n_cell / n_cell.sum()).sum()
+
+ #         p_cauchy.append(p_cauchy_new)
+ #         p_median.append(p_median_new)
+
+ #     data = {
+ #         "p_cauchy": p_cauchy,
+ #         "p_median": p_median,
+ #         "annotation": cauchy_all.annotation.unique(),
+ #     }
+ #     p_tissue = pd.DataFrame(data)
+ #     p_tissue.columns = ["p_cauchy", "p_median", "annotation"]
+ #     p_tissue.sort_values("p_cauchy", inplace=True)
+ #     return p_tissue
+
+
+ def odds_test_3d(ldsc_merge):
+     # BH FDR correction; the default method of multipletests is Holm-Sidak, so request FDR explicitly.
+     _, corrected_p_values, _, _ = smm.multipletests(ldsc_merge.p, alpha=0.05, method="fdr_bh")
+     ldsc_merge['p_fdr'] = corrected_p_values.tolist()
+
+     Odds = []
+     for focal_annotation in ldsc_merge.annotation.unique():
+         try:
+             # Count significant/non-significant spots explicitly; value_counts() orders
+             # by frequency, so unpacking it directly can swap the two counts.
+             focal_sig = ldsc_merge.loc[ldsc_merge.annotation == focal_annotation, 'p_fdr'] < 0.05
+             other_sig = ldsc_merge.loc[ldsc_merge.annotation != focal_annotation, 'p_fdr'] < 0.05
+             focal_yes, focal_no = int(focal_sig.sum()), int((~focal_sig).sum())
+             other_yes, other_no = int(other_sig.sum()), int((~other_sig).sum())
+             contingency_table = [[focal_yes, focal_no], [other_yes, other_no]]
+             odds_ratio, p_value = fisher_exact(contingency_table)
+             table = sm.stats.Table2x2(contingency_table)
+             conf_int = table.oddsratio_confint()
+         except Exception:
+             odds_ratio = 0
+             p_value = 1
+             conf_int = (0, 0)
+         Odds.append({
+             'annotation': focal_annotation,
+             'odds_ratio': f"{odds_ratio:.3f}",
+             '95%_ci_low': f"{conf_int[0]:.3f}",
+             '95%_ci_high': f"{conf_int[1]:.3f}",
+             'p_odds_ratio': p_value
+         })
+     Odds = pd.DataFrame(Odds)
+     return Odds
+
+
+ def three_d_combine(args: ThreeDCombineConfig):
+
+     # Load the ldsc results
+     ldsc_merge = combine_ldsc(args)
+     ldsc_merge.spot_index = ldsc_merge.spot_index.astype(str).replace(r"\.0", "", regex=True)
+     ldsc_merge.index = ldsc_merge.spot_index
+
+     # Load the spatial data
+     logger.info(f"Loading {args.adata_3d}.")
+     adata_3d_path = str(args.adata_3d)
+     if adata_3d_path.endswith('.parquet'):
+         logger.info("The input data is the metadata file of adata.")
+         meta_merged = pd.read_parquet(adata_3d_path)
+     elif adata_3d_path.endswith('.h5ad'):
+         logger.info("The input data is the h5ad.")
+         adata_merge = ad.read_h5ad(adata_3d_path, backed='r')
+         adata_merge.obs.index.name = 'index'
+         spatial = pd.DataFrame(adata_merge.obsm[args.spatial_key], columns=['sx', 'sy', 'sz'], index=adata_merge.obs_names).copy()
+         spatial = spatial.reset_index()
+         meta = adata_merge.obs.copy()
+         meta_merged = spatial.merge(meta, left_on='index', right_index=True, how='left')
+         meta_merged.index = adata_merge.obs_names
+     else:
+         raise ValueError("args.adata_3d must be a .parquet metadata file or an .h5ad file.")
+
+     # Rebuild the index when it has duplicates and st_id is provided
+     if args.st_id is not None and (meta_merged.index.value_counts() > 1).any():
+         if len(np.intersect1d(ldsc_merge.index, meta_merged.index)) == 0:
+             # If no common cells, create a new index using st_id
+             logger.info(f"Using {args.st_id} + adata.obs_names as the new cell index.")
+             meta_merged.index = meta_merged[args.st_id].astype(str) + '_' + meta_merged.index.astype(str)
+
+     # Find common cells
+     common_cell = np.intersect1d(ldsc_merge.index, meta_merged.index)
+     if len(common_cell) == 0:
+         raise ValueError("No common cells between the spatial data and the ldsc results.")
+
+     logger.info(f"Found {len(common_cell)} common cells between the 3D spatial data and the mapping results.")
+
+     # Subset the data to common cells
+     meta_merged = meta_merged.loc[common_cell].copy()
+     ldsc_merge = ldsc_merge.loc[common_cell]
+
+     # Do the cauchy combination test and odds ratio test
+     if args.annotation is not None:
+         ldsc_merge['annotation'] = meta_merged[args.annotation]
+         ldsc_merge = ldsc_merge[~ldsc_merge.annotation.isna()]
+
+         cauchy = cauchy_combination_3d(ldsc_merge)
+         odds = odds_test_3d(ldsc_merge)
+         cauchy_odds = pd.merge(odds, cauchy, on='annotation')
+
+         # Save the results
+         cauchy_root = Path(args.project_dir) / "3D_combine" / "cauchy_combination"
+         cauchy_root.mkdir(parents=True, exist_ok=True, mode=0o755)
+         cauchy_name = cauchy_root / f"{args.trait_name}.{args.annotation}.Cauchy.csv.gz"
+         # odds_ratio is stored as a formatted string, so sort it numerically
+         cauchy_odds = cauchy_odds.sort_values('odds_ratio', ascending=False, key=lambda s: s.astype(float))
+         cauchy_odds.to_csv(cauchy_name, compression="gzip", index=False)
+         logger.info(f"Saving the 3D Cauchy combination results to {cauchy_name}")
+     else:
+         logger.info("No annotation provided for the cauchy combination test.")
+
+
+     # Plot the 3D results
+     p_color = ['#313695', '#4575b4', '#74add1', '#fee090', '#fdae61', '#f46d43', '#d73027', '#a50026']
+     meta_merged["logp"] = -np.log10(ldsc_merge.p)
+
+     required_columns = {'sx', 'sy', 'sz'}
+     if required_columns.issubset(meta_merged.columns):
+         logger.info("Generating 3D plot...")
+
+         # Set the legend and text
+         legend_kwargs = dict(scalar_bar_title_size=30, scalar_bar_label_size=30, fmt="%.1e")
+         text_kwargs = dict(text_font_size=15, text_loc="upper_edge")
+
+         # Set the opacity for each point
+         meta_merged['logp'] = meta_merged['logp'].fillna(0)
+         bins = np.linspace(meta_merged['logp'].min(), meta_merged['logp'].max(), 5)
+         alpha = np.exp(np.linspace(0.1, 1.0, num=len(bins) - 1)) - 1
+         alpha = alpha / max(alpha)
+         opacity_show = pd.cut(meta_merged['logp'], bins=bins, labels=alpha, include_lowest=True).values.tolist()
+
+         # Set the clim from the median of the top 20 -log10(p) values
+         max_v = np.round(np.median(np.sort(meta_merged['logp'])[::-1][0:20]))
+
+         # Plot the 3D results
+         plotter = three_d_plot(
+             clim=[0, max_v],
+             point_size=args.point_size,
+             opacity=opacity_show,
+             window_size=(1200, 1008),
+             adata=meta_merged,
+             spatial_key=args.spatial_key,
+             keys=["logp"],
+             cmaps=[args.cmap] if args.cmap is not None else [p_color],
+             scalar_bar_titles=["-log10(p)"],
+             texts=[args.trait_name],
+             jupyter=False,
+             background=args.background_color,
+             show_outline=args.show_outline,
+             legend_kwargs=legend_kwargs,
+             text_kwargs=text_kwargs,
+         )
+
+         # Save the results
+         plot_root = Path(args.project_dir) / "3D_combine" / "3D_plot"
+         plot_root.mkdir(parents=True, exist_ok=True, mode=0o755)
+         plot_name = plot_root / args.trait_name
+
+         three_d_plot_save(
+             plotter,
+             save_mp4=args.save_mp4,
+             save_gif=args.save_gif,
+             n_points=args.n_snapshot if args.n_snapshot is not None else 200,
+             filename=plot_name,
+         )
+     else:
+         logger.info("The spatial data does not contain 3D spatial coordinates for 3D plotting.")
gsMap/report/three_d_plot/three_d_plot_decorate.py
@@ -0,0 +1,246 @@
+ import matplotlib as mpl
+ import numpy as np
+ from pyvista import MultiBlock
+
+ categorical_legend_loc_legal = [
+     "upper right",
+     "upper left",
+     "lower left",
+     "lower right",
+     "center left",
+     "center right",
+     "lower center",
+     "upper center",
+     "center",
+ ]
+
+
+ def add_model(
+     plotter,
+     model,
+     key=None,
+     colormap=None,
+     clim=None,
+     ambient=0.2,
+     opacity=1.0,
+     model_style="surface",
+     point_size=3.0,
+ ):
+
+     def _add_model(_p, _model, _key, _colormap, _style, _ambient, _opacity, _point_size, _clim):
+         """Add any PyVista/VTK model to the scene."""
+         if _style == "points":
+             _render_spheres, _render_tubes, _smooth_shading = True, False, True
+         elif _style == "wireframe":
+             _render_spheres, _render_tubes, _smooth_shading = False, True, False
+         else:
+             _render_spheres, _render_tubes, _smooth_shading = False, False, True
+         mesh_kwargs = dict(
+             style=_style,
+             render_points_as_spheres=_render_spheres,
+             render_lines_as_tubes=_render_tubes,
+             point_size=_point_size,
+             line_width=_point_size,
+             ambient=_ambient,
+             opacity=_opacity,
+             smooth_shading=_smooth_shading,
+             clim=_clim,
+             show_scalar_bar=False,
+         )
+
+         if _colormap is None:
+             # Categorical data: use the pre-computed per-point RGBA colors.
+             added_kwargs = dict(
+                 scalars=f"{_key}_rgba" if _key in _model.array_names else _model.active_scalars_name,
+                 rgba=True,
+             )
+         else:
+             # Continuous data: map the scalar field through the colormap.
+             added_kwargs = dict(
+                 scalars=_key if _key in _model.array_names else _model.active_scalars_name,
+                 cmap=_colormap,
+             )
+
+         mesh_kwargs.update(added_kwargs)
+         _p.add_mesh(_model, **mesh_kwargs)
+
+     # Add model(s) to the plotter.
+     _add_model(
+         _p=plotter,
+         _model=model,
+         _key=key,
+         _colormap=colormap,
+         _style=model_style,
+         _point_size=point_size,
+         _ambient=ambient,
+         _opacity=opacity,
+         _clim=clim,
+     )
+
+
+ def add_str_legend(
+     plotter,
+     labels,
+     colors,
+     font_family='arial',
+     legend_size=None,
+     legend_loc="center right"
+ ):
+
+     legend_data = np.concatenate(
+         [labels.reshape(-1, 1).astype(object), colors.reshape(-1, 1).astype(object)], axis=1)
+     legend_data = legend_data[legend_data[:, 0] != "mask", :]
+     assert len(legend_data) != 0, "No legend can be added, please set `show_legend=False`."
+
+     legend_entries = legend_data[np.lexsort(legend_data[:, ::-1].T)]
+     if legend_size is None:
+         legend_num = min(len(legend_entries), 10)
+         legend_size = (0.1 + 0.01 * legend_num, 0.1 + 0.012 * legend_num)
+
+     plotter.add_legend(
+         legend_entries.tolist(),
+         face="none",
+         font_family=font_family,
+         bcolor=None,
+         loc=legend_loc,
+         size=legend_size
+     )
+
+
+ def add_num_legend(
+     plotter,
+     title="",
+     n_labels=5,
+     title_font_size=None,
+     label_font_size=None,
+     font_color="black",
+     font_family="arial",
+     legend_size=(0.1, 0.4),
+     legend_loc=(0.85, 0.3),
+     vertical=True,
+     fmt="%.2e",
+ ):
+
+     plotter.add_scalar_bar(
+         title=title,
+         n_labels=n_labels,
+         title_font_size=title_font_size,
+         label_font_size=label_font_size,
+         color=font_color,
+         font_family=font_family,
+         use_opacity=True,
+         width=legend_size[0],
+         height=legend_size[1],
+         position_x=legend_loc[0],
+         position_y=legend_loc[1],
+         vertical=vertical,
+         fmt=fmt,
+     )
+
+
+ def add_legend(
+     plotter,
+     model,
+     key=None,
+     colormap=None,
+     categorical_legend_size=None,
+     categorical_legend_loc=None,
+     scalar_bar_title="",
+     scalar_bar_size=None,
+     scalar_bar_loc=None,
+     scalar_bar_title_size=None,
+     scalar_bar_label_size=None,
+     scalar_bar_font_color="black",
+     font_family="arial",
+     fmt="%.2e",
+     scalar_bar_n_labels=5,
+     vertical=True,
+ ):
+
+     # colormap is None: categorical legend; colormap is not None: continuous scalar bar.
+
+     if colormap is None:
+         assert key is not None, "When colormap is None, key cannot be None at the same time."
+
+         # Fall back to the default location when none is given or the given one is not legal.
+         if categorical_legend_loc is None or categorical_legend_loc not in categorical_legend_loc_legal:
+             categorical_legend_loc = 'center right'
+
+         if isinstance(model, MultiBlock):
+             keys = key if isinstance(key, list) else [key] * len(model)
+
+             legend_label_data, legend_color_data = [], []
+             for m, k in zip(model, keys, strict=False):
+                 legend_label_data.append(np.asarray(m[k]).flatten())
+                 legend_color_data.append(np.asarray(
+                     [mpl.colors.to_hex(i) for i in m[f"{k}_rgba"]]).flatten())
+             legend_label_data = np.concatenate(legend_label_data, axis=0)
+             legend_color_data = np.concatenate(legend_color_data, axis=0)
+         else:
+             legend_label_data = np.asarray(model[key]).flatten()
+             legend_color_data = np.asarray(
+                 [mpl.colors.to_hex(i) for i in model[f"{key}_rgba"]]).flatten()
+
+         legend_data = np.concatenate(
+             [legend_label_data.reshape(-1, 1), legend_color_data.reshape(-1, 1)], axis=1)
+         unique_legend_data = np.unique(legend_data, axis=0)
+
+         add_str_legend(
+             plotter=plotter,
+             labels=unique_legend_data[:, 0],
+             colors=unique_legend_data[:, 1],
+             font_family=font_family,
+             legend_size=categorical_legend_size,
+             legend_loc=categorical_legend_loc
+         )
+     else:
+         if scalar_bar_size is None:
+             scalar_bar_size = (0.1, 0.4)
+         if scalar_bar_loc is None:
+             scalar_bar_loc = (0.85, 0.3)
+
+         add_num_legend(
+             plotter=plotter,
+             legend_size=scalar_bar_size,
+             legend_loc=scalar_bar_loc,
+             title=scalar_bar_title,
+             n_labels=scalar_bar_n_labels,
+             title_font_size=scalar_bar_title_size,
+             label_font_size=scalar_bar_label_size,
+             font_color=scalar_bar_font_color,
+             font_family=font_family,
+             fmt=fmt,
+             vertical=vertical
+         )
+
+
+ def add_outline(
+     plotter,
+     model,
+     outline_width=1.0,
+     outline_color="black",
+ ):
+
+     # `model` is kept for API symmetry; add_bounding_box draws a box around
+     # the bounds of everything already added to the scene.
+     plotter.add_bounding_box(
+         color=outline_color,
+         line_width=outline_width
+     )
+
+
+
+ def add_text(
+     plotter,
+     text,
+     font_family="arial",
+     text_font_size=15,
+     text_font_color="black",
+     text_loc="upper_edge"
+ ):
+
+     plotter.add_text(
+         text=text,
+         font=font_family,
+         color=text_font_color,
+         font_size=text_font_size,
+         position=text_loc if text_loc is not None else "upper_edge"
+     )
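Taken together, these helpers are thin wrappers around a PyVista plotter: add_model renders the point cloud, add_legend attaches either a categorical legend or a scalar bar, and add_text/add_outline decorate the scene. A minimal usage sketch with hypothetical random data (off-screen rendering is assumed available, e.g. via pv.start_xvfb() as in three_d_combine.py):

import numpy as np
import pyvista as pv

cloud = pv.PolyData(np.random.rand(500, 3))   # 500 random 3D points
cloud["logp"] = np.random.rand(500) * 10      # a continuous scalar field

plotter = pv.Plotter(off_screen=True)
add_model(plotter, cloud, key="logp", colormap="viridis",
          clim=[0, 10], model_style="points", point_size=5.0)
add_legend(plotter, cloud, key="logp", colormap="viridis",
           scalar_bar_title="-log10(p)")
add_text(plotter, text="demo trait")
add_outline(plotter, cloud)
plotter.screenshot("demo.png")

Passing colormap=None instead takes the categorical branch, which expects a pre-computed f"{key}_rgba" color array on the model.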