gsMap3D 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,202 @@
1
+ import matplotlib as mpl
2
+ import numpy as np
3
+ import pyvista as pv
4
+ from matplotlib.colors import LinearSegmentedColormap
5
+ from pandas.core.frame import DataFrame
6
+
7
# Ordered hex colors (dark blue -> dark red) used as the evenly spaced nodes
# of the "default_cmap" colormap registered by _get_default_cmap().
p_color = ['#313695', '#4575b4', '#74add1','#fee090', '#fdae61', '#f46d43', '#d73027', '#a50026']
8
+
9
+
10
+
11
def _get_default_cmap():
    """Register the ``"default_cmap"`` colormap with matplotlib if it is missing.

    The colormap interpolates linearly between the colors of ``p_color``,
    which are placed at evenly spaced positions on [0, 1]. Registration only
    happens when the name is not yet known to ``mpl.colormaps``, so repeated
    calls are cheap no-ops.

    Returns:
        str: The colormap name, ``"default_cmap"``.
    """
    cmap_name = "default_cmap"
    already_registered = cmap_name in mpl.colormaps()
    if not already_registered:
        positions = np.linspace(0, 1, len(p_color))
        color_nodes = list(zip(positions, p_color, strict=False))
        mpl.colormaps.register(
            LinearSegmentedColormap.from_list(cmap_name, color_nodes))
    return cmap_name
19
+
20
+
21
def create_plotter(
    jupyter=False,
    off_screen=False,
    window_size=(512, 512),
    background="white",
    shape=(1, 1),
    show_camera_orientation=True,
    show_axis_orientation=False
):
    """Create a ``pyvista.Plotter`` with the lighting and widgets used for 3D plots.

    Parameters:
        jupyter: ``False`` for a regular render window; any other value (e.g.
            ``True`` or ``"trame"``) turns on pyvista's notebook mode. With
            ``"trame"`` no orientation widget or axes are added.
        off_screen: Render off screen.
        window_size: Render window size in pixels.
        background: Background color of the active render window.
        shape: Subplot grid shape forwarded to ``pv.Plotter``.
        show_camera_orientation: Add a camera-orientation widget.
        show_axis_orientation: Add unlabeled axes; only considered when
            ``show_camera_orientation`` is false.

    Returns:
        pv.Plotter: The configured plotter.
    """
    # Ensure "default_cmap" exists before anything refers to it by name.
    _get_default_cmap()

    use_notebook = jupyter is not False
    plotter = pv.Plotter(
        off_screen=off_screen,
        window_size=window_size,
        notebook=use_notebook,
        lighting="light_kit",
        shape=shape,
    )
    plotter.background_color = background

    # Orientation helpers are skipped entirely for the trame backend.
    if jupyter == "trame":
        return plotter
    if show_camera_orientation:
        plotter.add_camera_orientation_widget()
    elif show_axis_orientation:
        plotter.add_axes(labels_off=True)
    return plotter
51
+
52
+
53
def add_point_labels(
    model,
    labels,
    key_added="groups",
    where="point_data",
    colormap="rainbow",
    alphamap=1.0,
    mask_color="black",
    mask_alpha=0.0,
    inplace=False,
):
    """Attach labels (and, for categorical labels, RGBA colors) to a pyvista dataset.

    Numeric labels are stored unchanged under ``key_added`` and ``colormap`` is
    returned so the caller can use it for scalar coloring. Non-numeric labels
    are additionally converted to an RGBA array stored under
    ``f"{key_added}_rgba"``. Entries equal to ``"mask"`` always receive
    ``mask_color`` / ``mask_alpha`` (unless ``alphamap`` is a per-entry list,
    which replaces every alpha).

    Parameters:
        model: Dataset exposing ``point_data`` / ``cell_data`` (and ``copy()``
            when ``inplace`` is false).
        labels: One label per point (or cell); flattened before use.
        key_added: Name of the array that receives the labels.
        where: ``"point_data"`` stores on points; any other value stores on cells.
        colormap: For categorical labels, one of:
            - a registered matplotlib colormap name, sampled evenly across the
              sorted unique non-mask labels;
            - any other color string, applied to every non-mask label;
            - a ``{label: color}`` dict;
            - a list/array of colors matched to the sorted unique non-mask labels.
        alphamap: For categorical labels, a scalar alpha for all non-mask
            labels, a ``{label: alpha}`` dict, or a per-entry list/array.
        mask_color: Color used for ``"mask"`` entries.
        mask_alpha: Alpha used for ``"mask"`` entries.
        inplace: Modify ``model`` directly instead of a copy.

    Returns:
        tuple: ``(model_or_None, plot_cmap)``. The first item is the modified
        copy, or ``None`` when ``inplace``. ``plot_cmap`` is ``colormap`` for
        numeric labels and ``None`` for categorical labels, which are colored
        via the RGBA array.

    Raises:
        ValueError: If ``colormap`` or ``alphamap`` has an unsupported type.
    """
    model = model.copy() if not inplace else model
    labels = np.asarray(labels).flatten()
    data = model.point_data if where == "point_data" else model.cell_data

    if np.issubdtype(labels.dtype, np.number):
        # Numeric labels are rendered as scalars with the given colormap.
        plot_cmap = colormap
    else:
        obj_labels = labels.astype(object)
        # Selections are always made on the ORIGINAL labels. The color/alpha
        # arrays below are overwritten in place, so matching against them once
        # "mask" entries have been replaced would treat masked entries as
        # ordinary labels (overwriting mask_color / mask_alpha), and a label
        # could collide with an already-assigned value.
        is_mask = obj_labels == "mask"
        is_label = ~is_mask
        # Sorted unique labels WITHOUT "mask", so masking does not consume a
        # colormap sample or a color from a user-supplied list.
        cu_arr = np.unique(labels[is_label]).astype(object)

        raw_labels_hex = obj_labels.copy()
        raw_labels_alpha = obj_labels.copy()
        raw_labels_hex[is_mask] = mpl.colors.to_hex(mask_color)
        raw_labels_alpha[is_mask] = mask_alpha

        # Set raw hex.
        if isinstance(colormap, str):
            if colormap in list(mpl.colormaps()):
                lscmap = mpl.colormaps[colormap]
                samples = np.linspace(0, 1, len(cu_arr))
                for label, pos in zip(cu_arr, samples, strict=True):
                    raw_labels_hex[obj_labels == label] = mpl.colors.to_hex(lscmap(pos))
            else:
                # Any other string is interpreted as a single color for all labels.
                raw_labels_hex[is_label] = mpl.colors.to_hex(colormap)
        elif isinstance(colormap, dict):
            for label, color in colormap.items():
                raw_labels_hex[(obj_labels == label) & is_label] = mpl.colors.to_hex(color)
        elif isinstance(colormap, (list, np.ndarray)):
            hex_list = [mpl.colors.to_hex(color) for color in colormap]
            # Colors pair with the sorted unique labels; surplus entries are ignored.
            for label, color in zip(cu_arr, hex_list, strict=False):
                raw_labels_hex[obj_labels == label] = color
        else:
            raise ValueError(
                "`colormap` value is wrong." "\nAvailable `colormap` types are: `str`, `list` and `dict`.")

        # Set raw alpha.
        if isinstance(alphamap, (int, float, np.integer, np.floating)):
            raw_labels_alpha[is_label] = alphamap
        elif isinstance(alphamap, dict):
            for label, alpha in alphamap.items():
                raw_labels_alpha[(obj_labels == label) & is_label] = alpha
        elif isinstance(alphamap, (list, np.ndarray)):
            # Explicit per-entry alphas replace everything, masked entries included.
            raw_labels_alpha = np.asarray(alphamap).astype(object)
        else:
            raise ValueError(
                "`alphamap` value is wrong." "\nAvailable `alphamap` types are: `float`, `list` and `dict`."
            )

        # Set rgba.
        labels_rgba = [
            mpl.colors.to_rgba(c, alpha=a)
            for c, a in zip(raw_labels_hex, raw_labels_alpha, strict=False)
        ]
        data[f"{key_added}_rgba"] = np.array(labels_rgba).astype(np.float32)
        plot_cmap = None

    # Added labels.
    data[key_added] = labels

    return (model if not inplace else None), plot_cmap
131
+
132
+
133
def construct_pc(
    adata,
    layer="X",
    spatial_key="spatial",
    groupby=None,
    key_added="groups",
    mask=None,
    colormap="default_cmap",
    alphamap=1.0
):
    """Build a labeled ``pyvista.PolyData`` point cloud from spatial data.

    Parameters:
        adata: An AnnData-like object (coordinates read from
            ``adata.obsm[spatial_key]``) or a ``pandas.DataFrame`` with
            ``sx``, ``sy`` and ``sz`` columns.
        layer: Expression matrix used when ``groupby`` names genes: ``"X"``
            for ``adata.X``, otherwise ``adata.layers[layer]``.
        spatial_key: ``adata.obsm`` key holding the coordinates (AnnData only).
            NOTE(review): coordinates are presumably (n_obs, 3); confirm 2-D
            input is accepted by ``pv.PolyData``.
        groupby: How to label points:
            - ``None``: every point is labeled ``"same"``;
            - a DataFrame column or an ``adata.obs`` column;
            - a gene name, or a list of gene names, in ``adata.var_names``.
              The label is the expression summed over those genes.
        key_added: Name of the point-data array that receives the labels.
        mask: A value, or list of values, of a ``groupby`` column to relabel
            as ``"mask"``. Only applies to column-based grouping.
        colormap: Forwarded to ``add_point_labels``.
        alphamap: Forwarded to ``add_point_labels``.

    Returns:
        tuple: ``(pc, plot_cmap)``. ``pc`` is the point cloud carrying the
        labels, the optional ``f"{key_added}_rgba"`` colors and an
        ``obs_index`` array of observation names. ``plot_cmap`` is the
        colormap for scalar plotting (``None`` for categorical labels).

    Raises:
        ValueError: If the coordinates, the ``groupby`` column/genes or the
            expression ``layer`` cannot be found.
    """
    # Normalize ``mask`` to a list of values that should become "mask".
    mask_list = mask if isinstance(mask, list) else [mask] if mask is not None else []

    def _relabel(series):
        # Replace masked values with the literal label "mask".
        return series.map(lambda x: "mask" if x in mask_list else x).values

    # Extract spatial coordinates and observation names.
    if isinstance(adata, DataFrame):
        cell_names = np.array(adata.index.tolist())
        try:
            bucket_xyz = adata[['sx', 'sy', 'sz']].values
        except KeyError as err:
            raise ValueError("Spatial coordinates ('sx','sy','sz') not found in meta data.") from err
    else:
        cell_names = np.array(adata.obs_names.tolist())
        if spatial_key not in adata.obsm:
            raise ValueError(f"Spatial key {spatial_key} not found in adata.obsm.")
        bucket_xyz = adata.obsm[spatial_key].astype(np.float64)
        if isinstance(bucket_xyz, DataFrame):
            bucket_xyz = bucket_xyz.values

    # Derive one label per observation.
    if groupby is None:
        groups = np.array(["same"] * bucket_xyz.shape[0], dtype=str)
    elif isinstance(adata, DataFrame):
        if groupby not in adata.columns:
            raise ValueError(f"`groupby` column '{groupby}' not found in DataFrame.")
        groups = _relabel(adata[groupby])
    elif np.ndim(groupby) == 0 and groupby in adata.obs_keys():
        # Group by observation metadata.
        groups = _relabel(adata.obs[groupby])
    else:
        # Group by gene expression. Normalize to a list BEFORE testing
        # membership: ``set(groupby)`` on a string would test its characters,
        # and ``list in adata.var_names`` raises TypeError (lists are unhashable).
        if isinstance(groupby, str):
            genes = [groupby]
        else:
            try:
                genes = list(groupby)
            except TypeError:  # not iterable, so it cannot be a gene list
                genes = []
        if not genes or not all(gene in adata.var_names for gene in genes):
            raise ValueError(
                f"`groupby` value '{groupby}' is invalid. "
                "It must be a column in adata.obs, a gene in adata.var_names, "
                "or a list of genes in adata.var_names."
            )
        if layer != "X" and layer not in adata.layers:
            raise ValueError(f"Layer {layer} not found in adata.layers.")
        subset = adata[:, genes]
        expression = subset.X if layer == "X" else subset.layers[layer]
        # np.asarray also flattens the np.matrix returned by sparse ``.sum``.
        groups = np.asarray(expression.sum(axis=1)).flatten()

    pc = pv.PolyData(bucket_xyz)
    _, plot_cmap = add_point_labels(
        model=pc,
        labels=groups,
        key_added=key_added,
        where="point_data",
        colormap=colormap,
        alphamap=alphamap,
        inplace=True,
    )

    # The obs_index of each coordinate in the original adata.
    pc.point_data["obs_index"] = cell_names

    return pc, plot_cmap
@@ -0,0 +1,425 @@
1
+ import logging
2
+ import math
3
+ from typing import Literal
4
+
5
+ import matplotlib as mpl
6
+ import numpy as np
7
+ import pyvista as pv
8
+ from pyvista import MultiBlock, Plotter, PolyData
9
+
10
+ from .three_d_plot_decorate import add_legend, add_model, add_outline, add_text
11
+ from .three_d_plot_prepare import _get_default_cmap, construct_pc, create_plotter
12
+
13
# Module-level logger; three_d_plot_save reports its progress through it.
logger = logging.getLogger(__name__)

# Register "default_cmap" with matplotlib at import time so the name is
# available as a colormap before any plotting function runs.
_get_default_cmap()
16
+
17
+
18
def _merge_known_kwargs(defaults, overrides, name):
    """Return a copy of ``defaults`` updated with matching keys from ``overrides``.

    Only keys already present in ``defaults`` are taken over. Any other keys
    in ``overrides`` are reported with a warning instead of being dropped
    silently.
    """
    merged = dict(defaults)
    if overrides:
        ignored = sorted(str(k) for k in overrides.keys() - defaults.keys())
        if ignored:
            logger.warning("Ignoring unsupported %s keys: %s", name, ", ".join(ignored))
        merged.update((k, overrides[k]) for k in overrides.keys() & defaults.keys())
    return merged


def wrap_to_plotter(
    plotter: Plotter,
    model: PolyData | MultiBlock,
    key: str | None = None,
    colormap: str | list | None = None,

    # parameters for model settings
    ambient: float = 0.2,
    opacity: float | list = 1,
    point_size: float = 1,
    model_style: Literal["points", "surface", "wireframe"] = "surface",
    font_family: Literal["times", "courier", "arial"] = "arial",
    background: str = "black",
    cpo: str | list = "iso",
    clim: list | None = None,
    legend_kwargs: dict | None = None,
    outline_kwargs: dict | None = None,
    text: str | None = None,
    scalar_bar_title: str | None = None,
    text_kwargs: dict | None = None,
    show_outline: bool = False,
    show_text: bool = True,
    show_legend: bool = True
):
    """
    Add a model and its decorations (legend, outline, text) to the active subplot.

    Parameters:
        plotter (Plotter): The plotter to draw on (its active subplot).
        model (Union[PolyData, MultiBlock]): The model to add.
        key (Optional[str]): Name of the data array used to color the model.
        colormap (Optional[Union[str, list]]): Colormap for scalar coloring;
            ``None`` when the model carries precomputed RGBA colors.
        ambient (float): The ambient lighting coefficient.
        opacity (Union[float, list]): The opacity of the model.
        point_size (float): The size of the points in the model.
        model_style (Literal["points", "surface", "wireframe"]): The style of the model.
        font_family (Literal["times", "courier", "arial"]): The font family to use.
        background (str): Background color. Decorations default to its
            complementary color so they stay visible.
        cpo (Union[str, list]): Camera position. Currently not applied.
        clim (Optional[list]): Color limits forwarded to ``add_model``.
        legend_kwargs (Optional[dict]): Overrides for the legend defaults below;
            other keys are ignored with a warning.
        outline_kwargs (Optional[dict]): Overrides for ``outline_width`` /
            ``outline_color``; other keys are ignored with a warning.
        text (Optional[str]): The text to add to the plotter.
        scalar_bar_title (Optional[str]): The title of the scalar bar.
        text_kwargs (Optional[dict]): Overrides for ``font_family``,
            ``text_font_size``, ``text_font_color`` and ``text_loc``; other
            keys are ignored with a warning.
        show_outline (bool): Whether to show the outline.
        show_text (bool): Whether to show the text.
        show_legend (bool): Whether to show the legend.
    """
    # NOTE: ``cpo`` is accepted for API compatibility but not applied.
    bg_rgb = mpl.colors.to_rgb(background)
    # Complement of the background: the default color for all decorations.
    cbg_rgb = (1 - bg_rgb[0], 1 - bg_rgb[1], 1 - bg_rgb[2])

    # Add model(s) basic settings to the plotter.
    add_model(
        plotter=plotter,
        model=model,
        key=key,
        colormap=colormap,
        ambient=ambient,
        opacity=opacity,
        point_size=point_size,
        model_style=model_style,
        clim=clim
    )

    # Add legends to the plotter.
    if show_legend:
        lg_defaults = dict(
            categorical_legend_size=None,
            categorical_legend_loc=None,
            scalar_bar_size=None,
            scalar_bar_loc=None,
            scalar_bar_title_size=None,
            scalar_bar_label_size=None,
            scalar_bar_font_color=cbg_rgb,
            scalar_bar_n_labels=5,
            font_family=font_family,
            fmt="%.1e",
            vertical=True,
        )
        lg_kwargs = _merge_known_kwargs(lg_defaults, legend_kwargs, "legend_kwargs")
        add_legend(plotter=plotter, model=model, key=key,
                   colormap=colormap, scalar_bar_title=scalar_bar_title, **lg_kwargs)

    # Add an outline to the plotter.
    if show_outline:
        ol_defaults = dict(
            outline_width=1.0,
            outline_color=cbg_rgb,
        )
        ol_kwargs = _merge_known_kwargs(ol_defaults, outline_kwargs, "outline_kwargs")
        add_outline(plotter=plotter, model=model, **ol_kwargs)

    # Add text to the plotter.
    if show_text:
        t_defaults = dict(
            font_family=font_family,
            text_font_size=12,
            text_font_color=cbg_rgb,
            text_loc="upper_edge",
        )
        t_kwargs = _merge_known_kwargs(t_defaults, text_kwargs, "text_kwargs")
        add_text(plotter=plotter, text=text, **t_kwargs)
134
+
135
+
136
def three_d_plot(
    # parameters for plotter
    adata,
    spatial_key: str,
    keys: list,
    cmaps: str | list | dict | None = 'default_cmap',
    scalar_bar_titles: str | list | None = None,
    texts: str | list | None = None,

    window_size: tuple[int, int] | None = None,
    off_screen: bool = False,
    shape: tuple | None = None,
    show_camera_orientation: bool = True,
    show_axis_orientation: bool = False,
    jupyter: bool = True,

    # parameters for model settings
    ambient: float = 0.2,
    opacity: float | list[float] = 1,
    point_size: float = 1,
    clim: list | None = None,
    model_style: Literal["points", "surface", "wireframe"] = "surface",
    font_family: Literal["times", "courier", "arial"] = "arial",
    background: str = "black",
    cpo: str | list = "iso",

    # parameters for show decoration
    show_outline: bool = False,
    show_text: bool = True,
    show_legend: bool = True,

    # parameters for legends, outline, and text
    legend_kwargs: dict | None = None,
    outline_kwargs: dict | None = None,
    text_kwargs: dict | None = None
):
    """
    Generate a 3D plot using pyvista for spatial data visualization.

    Parameters:
    - adata: AnnData object (or a DataFrame with ``sx``/``sy``/``sz`` columns) with spatial data.
    - spatial_key: Key in ``adata.obsm`` holding the spatial coordinates.
    - keys: Keys to color by, one subplot per key. Each key is a column of ``adata.obs``
        (or of the DataFrame), a gene in ``adata.var_names``, or a list of genes whose
        expression is summed. A single string is treated as ``[keys]``.
    - cmaps: Colormap(s). A string, ``None`` or a ``{label: color}`` dict is used for
        every key; a list gives one entry per key. Default is 'default_cmap'.
    - scalar_bar_titles: Titles for the scalar bars. A string (or None) applies to all keys,
        a list gives one title per key. Default is None.
    - texts: Text shown in each subplot. A string (or None) applies to all keys,
        a list gives one text per key. Default is None.
    - window_size: Size (width, height) in pixels of EACH subplot; the full window is this
        size multiplied by the grid shape. Default None means 1500 x 1500 per subplot.
    - off_screen: Whether to render the plot off-screen. Default is False.
    - shape: Subplot grid as (n_rows, n_cols). Default None uses at most 3 columns and as
        many rows as needed. Keys that do not fit into the grid are not drawn (a warning is logged).
    - show_camera_orientation: Whether to show the camera orientation widget. Default is True.
    - show_axis_orientation: Whether to show unlabeled axes when the camera orientation
        widget is off. Default is False.
    - jupyter: Whether to display the plot in a Jupyter notebook. Default is True.
    - ambient: Ambient lighting intensity. Default is 0.2.
    - opacity: Opacity of the plot objects. Default is 1.
    - point_size: Size of the points in the plot. Default is 1.
    - clim: Color limits forwarded to the model. Default is None.
    - model_style: Style of the plot objects. Can be 'points', 'surface', or 'wireframe'. Default is 'surface'.
    - font_family: Font family for the plot text. Can be 'times', 'courier', or 'arial'. Default is 'arial'.
    - background: Background color of the plot. Default is 'black'.
    - cpo: Camera position. Currently not applied. Default is 'iso'.
    - show_outline: Whether to show the plot outline. Default is False.
    - show_text: Whether to show the plot text. Default is True.
    - show_legend: Whether to show the plot legend. Default is True.
    - legend_kwargs (dict, optional): Overrides for the legend. Only these keys are used (defaults shown):
        categorical_legend_size = None  # (gap, size)
        categorical_legend_loc = None   # e.g. "upper right", "lower left", "center", ...
        scalar_bar_size = None
        scalar_bar_loc = None
        scalar_bar_title_size = None
        scalar_bar_label_size = None
        scalar_bar_font_color = complement of ``background``
        scalar_bar_n_labels = 5
        font_family = ``font_family``
        fmt = "%.1e"
        vertical = True
      Scalar bar titles are set through ``scalar_bar_titles``.

    - outline_kwargs (dict, optional): Overrides for the outline. Only these keys are used:
        outline_width = 1.0
        outline_color = complement of ``background``

    - text_kwargs (dict, optional): Overrides for the text. Only these keys are used:
        font_family = ``font_family``
        text_font_size = 12
        text_font_color = complement of ``background``
        text_loc = "upper_edge"  # or "lower_left", "lower_right", "upper_left",
                                 # "upper_right", "lower_edge", "right_edge", "left_edge"

    Returns:
    - plotter: The pyvista plotter object.

    Raises:
    - ValueError: If ``shape`` is not an (n_rows, n_cols) tuple/list, or if a key cannot
        be resolved when building the point clouds.
    """
    _get_default_cmap()

    # A bare string would otherwise be iterated character by character.
    if isinstance(keys, str):
        keys = [keys]
    n_keys = len(keys)

    # A single colormap spec (str, None, or a {label: color} dict) applies to every key.
    if cmaps is None or isinstance(cmaps, str | dict):
        cmaps = [cmaps] * n_keys

    if scalar_bar_titles is None or isinstance(scalar_bar_titles, str):
        scalar_bar_titles = [scalar_bar_titles] * n_keys

    if texts is None or isinstance(texts, str):
        texts = [texts] * n_keys

    # Build one point cloud per key. construct_pc only reads from adata, so no
    # defensive copy of the (potentially large) object is needed.
    models = pv.MultiBlock()
    plot_cmaps = []
    for i, key in enumerate(keys):
        _model, _plot_cmap = construct_pc(adata=adata,
                                          spatial_key=spatial_key,
                                          groupby=key,
                                          key_added=key,
                                          colormap=cmaps[i])
        models[f"model_{i}"] = _model
        plot_cmaps.append(_plot_cmap)

    # Resolve the subplot grid: at most 3 columns by default.
    if shape is None:
        shape = (math.ceil(n_keys / 3), min(n_keys, 3))
    if not isinstance(shape, tuple | list) or len(shape) != 2:
        raise ValueError("`shape` must be a (n_rows, n_cols) tuple or list.")
    n_rows, n_cols = shape
    # Row-major (row, col) index for every cell of the grid.
    subplots = [divmod(i, n_cols) for i in range(n_rows * n_cols)]
    if n_keys > len(subplots):
        logger.warning(
            "Only %d of %d keys fit into a %dx%d subplot grid; the remaining keys are not drawn.",
            len(subplots), n_keys, n_rows, n_cols)

    # window_size is per subplot; scale it by the grid dimensions.
    window_size = ((1500 * n_cols, 1500 * n_rows)
                   if window_size is None else (window_size[0] * n_cols, window_size[1] * n_rows))

    # Create the plotter
    plotter = create_plotter(
        background=background,
        off_screen=off_screen,
        shape=shape,
        show_camera_orientation=show_camera_orientation,
        show_axis_orientation=show_axis_orientation,
        window_size=window_size,
        jupyter=jupyter
    )

    # Draw each model into its own subplot.
    for model, key, plot_cmap, (row, col), scalar_bar_title, text in zip(
            models, keys, plot_cmaps, subplots, scalar_bar_titles, texts, strict=False):
        plotter.subplot(row, col)

        wrap_to_plotter(
            # parameters for plotter
            plotter=plotter,
            model=model,
            key=key,
            colormap=plot_cmap,

            # parameters for model settings
            clim=clim,
            ambient=ambient,
            opacity=opacity,
            point_size=point_size,
            model_style=model_style,
            font_family=font_family,
            background=background,
            cpo=cpo,

            # parameters for legends, outline, and text
            legend_kwargs=legend_kwargs,
            scalar_bar_title=scalar_bar_title,
            outline_kwargs=outline_kwargs,
            text_kwargs=text_kwargs,
            text=text,

            # parameters for show decoration
            show_outline=show_outline,
            show_text=show_text,
            show_legend=show_legend
        )

    plotter.link_views()

    return plotter
318
+
319
+
320
def three_d_plot_save(
    plotter: Plotter,
    filename: str,
    view_up_1=(0, 0, 0),
    view_up_2=(0, 0, 1),
    n_points: int = 150,
    factor: float = 2.0,
    shift: float = 0,
    step: int = 1,
    quality: int = 9,
    framerate: int = 10,
    save_mp4 : bool = False,
    save_gif : bool = False
):
    """
    Save a 3D plot as HTML and, optionally, as an orbiting GIF and/or MP4.

    Args:
        plotter (Plotter): The Plotter object used for generating the plot.
        filename (str): Base filename; ``.html`` / ``.gif`` / ``.mp4`` are appended.
        view_up_1 (tuple, optional): ``viewup`` passed to ``generate_orbital_path``.
            Defaults to (0, 0, 0).
        view_up_2 (tuple, optional): ``viewup`` passed to ``orbit_on_path``.
            Defaults to (0, 0, 1).
        n_points (int, optional): The number of points on the orbital path. Defaults to 150.
        factor (float, optional): The factor for scaling the orbital path. Defaults to 2.0.
        shift (float, optional): The shift value for the orbital path. Defaults to 0.
        step (int, optional): Passed to ``orbit_on_path`` as ``step``. Defaults to 1.
        quality (int, optional): Quality of the MP4 file (passed to ``open_movie``). Defaults to 9.
        framerate (int, optional): The framerate of the MP4 file. Defaults to 10.
        save_mp4 (bool, optional): Also write ``{filename}.mp4``. Defaults to False.
        save_gif (bool, optional): Also write ``{filename}.gif``. Defaults to False.

    Note:
        When a GIF or MP4 is written, the plotter is closed afterwards.
    """
    # save html
    logger.info('saving 3d plot as html...')

    # Workaround for asyncio conflict: export with the 'static' backend.
    # The current backend must be captured BEFORE switching; reading it back
    # afterwards would always return 'static' and make the restore a no-op.
    previous_backend = pv.global_theme.jupyter_backend
    pv.set_jupyter_backend('static')
    try:
        plotter.export_html(f'{filename}.html')
    finally:
        # Assign through the theme instead of calling set_jupyter_backend,
        # which would launch a server again for server-based backends.
        try:
            pv.global_theme.jupyter_backend = previous_backend
        except Exception:  # best effort: never mask the export outcome
            logger.warning("Could not restore pyvista jupyter backend %r.",
                           previous_backend, exc_info=True)

    # save gif
    if save_gif:
        logger.info('saving 3d plot as gif...')
        path = plotter.generate_orbital_path(factor=factor, shift=shift, viewup=view_up_1, n_points=n_points)
        plotter.open_gif(filename=f'{filename}.gif')
        plotter.orbit_on_path(path, write_frames=True, viewup=view_up_2, step=step)
        if save_mp4:
            # The plotter must stay open for the MP4 pass, so finalize the
            # GIF writer directly instead of through plotter.close().
            plotter.mwriter.close()

    # save mp4
    if save_mp4:
        logger.info('saving 3d plot as mp4...')
        path = plotter.generate_orbital_path(factor=factor, shift=shift, viewup=view_up_1, n_points=n_points)
        plotter.open_movie(filename=f'{filename}.mp4', framerate=framerate, quality=quality)
        plotter.orbit_on_path(path, write_frames=True, viewup=view_up_2, step=step)

    # Closing finalizes the last open writer; done once so both outputs can be produced.
    if save_gif or save_mp4:
        plotter.close()
374
+
375
+
376
+
377
def rotate_around_xyz(
    camera_coordinates,
    angle_x=0,
    angle_y=0,
    angle_z=0
):
    """
    Rotate coordinates about the x axis, then the y axis, then the z axis.

    Parameters:
    - camera_coordinates: length-3 array-like, e.g. a camera position (a 3xN array
      also works; each column is rotated).
    - angle_x: rotation angle around the x-axis in degrees (default: 0)
    - angle_y: rotation angle around the y-axis in degrees (default: 0)
    - angle_z: rotation angle around the z-axis in degrees (default: 0)

    Returns:
    - numpy array with the rotated coordinates
    """
    def _axis_rotation(axis, degrees):
        # Right-handed rotation matrix about a single coordinate axis.
        theta = np.radians(degrees)
        c, s = np.cos(theta), np.sin(theta)
        if axis == "x":
            return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])
        if axis == "y":
            return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
        return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])

    rotated = camera_coordinates
    # Order matters: x first, then y, then z.
    for axis, degrees in (("x", angle_x), ("y", angle_y), ("z", angle_z)):
        rotated = np.dot(_axis_rotation(axis, degrees), rotated)
    return rotated