gsMap3D 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
import matplotlib as mpl
|
|
2
|
+
import numpy as np
|
|
3
|
+
import pyvista as pv
|
|
4
|
+
from matplotlib.colors import LinearSegmentedColormap
|
|
5
|
+
from pandas.core.frame import DataFrame
|
|
6
|
+
|
|
7
|
+
# Eight color stops (dark blue through pale yellow to dark red) used to build
# the "default_cmap" colormap registered by ``_get_default_cmap``.
p_color = [
    '#313695', '#4575b4', '#74add1', '#fee090',
    '#fdae61', '#f46d43', '#d73027', '#a50026',
]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _get_default_cmap():
    """Ensure the ``"default_cmap"`` colormap is registered and return its name.

    The colormap is a ``LinearSegmentedColormap`` built from ``p_color`` with
    the stops spaced evenly over [0, 1]. Registration only happens when no
    colormap of that name is known to matplotlib yet, so calling this
    repeatedly is safe.
    """
    name = "default_cmap"
    if name in mpl.colormaps():
        return name

    stops = np.linspace(0, 1, len(p_color))
    segments = list(zip(stops, p_color, strict=False))
    mpl.colormaps.register(LinearSegmentedColormap.from_list(name, segments))
    return name
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def create_plotter(
    jupyter=False,
    off_screen=False,
    window_size=(512, 512),
    background="white",
    shape=(1, 1),
    show_camera_orientation=True,
    show_axis_orientation=False
):
    """Create a ``pyvista.Plotter`` with the project's default look.

    Parameters
    ----------
    jupyter
        ``False`` for a regular window; any other value enables notebook mode.
        The value ``"trame"`` additionally skips the orientation widgets.
    off_screen
        Render off screen.
    window_size
        Window size in pixels, forwarded to ``pyvista.Plotter``.
    background
        Background color of the render window.
    shape
        Subplot grid shape, forwarded to ``pyvista.Plotter``.
    show_camera_orientation
        Add a camera orientation widget.
    show_axis_orientation
        Add an axes widget; only used when ``show_camera_orientation`` is False.

    Returns
    -------
    pyvista.Plotter
    """
    # Make sure "default_cmap" is registered before any model gets plotted.
    _get_default_cmap()

    use_notebook = jupyter is not False
    plotter = pv.Plotter(
        off_screen=off_screen,
        window_size=window_size,
        notebook=use_notebook,
        lighting="light_kit",
        shape=shape,
    )
    plotter.background_color = background

    # No orientation widget is added in "trame" mode.
    if jupyter == "trame":
        return plotter

    if show_camera_orientation:
        plotter.add_camera_orientation_widget()
    elif show_axis_orientation:
        plotter.add_axes(labels_off=True)
    return plotter
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def add_point_labels(
    model,
    labels,
    key_added="groups",
    where="point_data",
    colormap="rainbow",
    alphamap=1.0,
    mask_color="black",
    mask_alpha=0.0,
    inplace=False,
):
    """Attach labels (and, for categorical labels, RGBA colors) to a model.

    Parameters
    ----------
    model
        Object with ``point_data`` / ``cell_data`` mappings and a ``copy()``
        method, e.g. a ``pyvista.PolyData``.
    labels
        Array-like with one label per point (or cell); flattened to 1-D.
    key_added
        Name of the stored label array. For non-numeric labels an additional
        ``f"{key_added}_rgba"`` float32 array of RGBA colors is stored.
    where
        ``"point_data"`` stores on ``model.point_data``; any other value
        stores on ``model.cell_data``.
    colormap
        Used only for non-numeric labels:

        * ``str`` naming a registered matplotlib colormap -- sampled evenly
          over the sorted unique labels;
        * any other ``str`` -- a single matplotlib color for every unmasked
          entry;
        * ``dict`` -- label -> color (labels missing from the dict keep their
          raw value, which matplotlib must then be able to parse as a color);
        * ``list`` / ``np.ndarray`` -- colors assigned to the sorted unique
          labels in order.

        For numeric labels it is returned unchanged as ``plot_cmap``.
    alphamap
        Used only for non-numeric labels: a scalar alpha for every unmasked
        entry, a ``dict`` label -> alpha, or a ``list`` / ``np.ndarray`` of
        per-entry alphas (which replaces every alpha, masked ones included).
    mask_color, mask_alpha
        Color and alpha for entries whose label is exactly ``"mask"``.
    inplace
        If True, modify ``model`` directly and return ``None`` in its place.

    Returns
    -------
    tuple
        ``(model_or_None, plot_cmap)``. ``plot_cmap`` is ``None`` for
        non-numeric labels (colors are baked into the RGBA array) and
        ``colormap`` for numeric labels.

    Raises
    ------
    ValueError
        If ``colormap`` or ``alphamap`` has an unsupported type.
    """
    model = model if inplace else model.copy()
    labels = np.asarray(labels).flatten()

    def _store(name, values):
        # Attach ``values`` to the attribute container selected by ``where``.
        if where == "point_data":
            model.point_data[name] = values
        else:
            model.cell_data[name] = values

    if np.issubdtype(labels.dtype, np.number):
        # Numeric labels are colored at plot time through the scalar colormap.
        plot_cmap = colormap
    else:
        labels_obj = labels.astype(object)
        # Decide which entries are masked from the ORIGINAL labels. Testing
        # ``!= "mask"`` after the masked entries were overwritten always
        # matches, so a scalar alphamap (the default) or a single-color
        # colormap used to clobber mask_alpha / mask_color.
        is_mask = labels_obj == "mask"
        unmasked = ~is_mask
        cu_arr = np.sort(np.unique(labels), axis=0).astype(object)

        raw_labels_hex = labels_obj.copy()
        raw_labels_alpha = labels_obj.copy()
        raw_labels_hex[is_mask] = mpl.colors.to_hex(mask_color)
        raw_labels_alpha[is_mask] = mask_alpha

        # Set raw hex. Matching is done against the original labels so that an
        # already-assigned hex string can never be mistaken for a label.
        if isinstance(colormap, str):
            if colormap in list(mpl.colormaps()):
                lscmap = mpl.colormaps[colormap]
                raw_hex_list = [mpl.colors.to_hex(lscmap(i)) for i in np.linspace(0, 1, len(cu_arr))]
                for label, color in zip(cu_arr, raw_hex_list, strict=False):
                    raw_labels_hex[unmasked & (labels_obj == label)] = color
            else:
                raw_labels_hex[unmasked] = mpl.colors.to_hex(colormap)
        elif isinstance(colormap, dict):
            for label, color in colormap.items():
                raw_labels_hex[unmasked & (labels_obj == label)] = mpl.colors.to_hex(color)
        elif isinstance(colormap, (list, np.ndarray)):
            raw_hex_list = np.array([mpl.colors.to_hex(color) for color in colormap]).astype(object)
            for label, color in zip(cu_arr, raw_hex_list, strict=False):
                raw_labels_hex[unmasked & (labels_obj == label)] = color
        else:
            raise ValueError(
                "`colormap` value is wrong." "\nAvailable `colormap` types are: `str`, `list` and `dict`.")

        # Set raw alpha.
        if isinstance(alphamap, (float, int)):
            raw_labels_alpha[unmasked] = alphamap
        elif isinstance(alphamap, dict):
            for label, alpha in alphamap.items():
                raw_labels_alpha[unmasked & (labels_obj == label)] = alpha
        elif isinstance(alphamap, (list, np.ndarray)):
            raw_labels_alpha = np.asarray(alphamap).astype(object)
        else:
            raise ValueError(
                "`alphamap` value is wrong." "\nAvailable `alphamap` types are: `float`, `list` and `dict`."
            )

        # Set rgba.
        labels_rgba = np.array(
            [mpl.colors.to_rgba(c, alpha=a) for c, a in zip(raw_labels_hex, raw_labels_alpha, strict=False)]
        ).astype(np.float32)
        _store(f"{key_added}_rgba", labels_rgba)

        plot_cmap = None

    # Added labels.
    _store(key_added, labels)

    return (None if inplace else model), plot_cmap
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def construct_pc(
    adata,
    layer="X",
    spatial_key="spatial",
    groupby=None,
    key_added="groups",
    mask=None,
    colormap="default_cmap",
    alphamap=1.0
):
    """Build a ``pyvista.PolyData`` point cloud and attach group labels to it.

    Parameters
    ----------
    adata
        Either a ``pandas.DataFrame`` with ``sx``/``sy``/``sz`` coordinate
        columns, or an AnnData-like object with coordinates stored in
        ``adata.obsm[spatial_key]``.
    layer
        Matrix summed when ``groupby`` names genes: ``"X"`` uses ``.X``,
        anything else uses ``.layers[layer]``.
    spatial_key
        Key in ``adata.obsm`` holding the coordinates (ignored for DataFrames).
    groupby
        ``None`` (every point labelled ``"same"``), a column of the DataFrame
        or of ``adata.obs``, or a gene name / list of gene names whose values
        are summed per observation.
    key_added
        Name of the label array stored on the point cloud.
    mask
        Value, or list of values, of a ``groupby`` column to relabel as
        ``"mask"``. Only applied to DataFrame / ``adata.obs`` columns.
    colormap, alphamap
        Forwarded to ``add_point_labels``.

    Returns
    -------
    tuple
        ``(pc, plot_cmap)``: the point cloud (carrying the labels and an
        ``"obs_index"`` array of observation names) and the plotting colormap
        returned by ``add_point_labels``.

    Raises
    ------
    ValueError
        If the coordinates or the ``groupby`` target cannot be found.
    """
    # Ensure mask is a list
    mask_list = mask if isinstance(mask, list) else [mask] if mask is not None else []

    # Extract spatial coordinates
    if isinstance(adata, DataFrame):
        cell_names = np.array(adata.index.tolist())
        try:
            bucket_xyz = adata[['sx', 'sy', 'sz']].values
        except KeyError as err:
            raise ValueError("Spatial coordinates ('sx','sy','sz') not found in meta data.") from err
    else:
        cell_names = np.array(adata.obs_names.tolist())
        if spatial_key not in adata.obsm:
            raise ValueError(f"Spatial key {spatial_key} not found in adata.obsm.")
        bucket_xyz = adata.obsm[spatial_key].astype(np.float64)
        if isinstance(bucket_xyz, DataFrame):
            bucket_xyz = bucket_xyz.values

    # Handle grouping
    if groupby is None:
        groups = np.array(["same"] * bucket_xyz.shape[0], dtype=str)
    elif isinstance(adata, DataFrame):
        # If adata is a DataFrame, check if groupby is a valid column
        if groupby not in adata.columns:
            raise ValueError(f"`groupby` column '{groupby}' not found in DataFrame.")
        groups = adata[groupby].map(lambda x: "mask" if x in mask_list else x).values
    elif groupby in adata.obs_keys():
        # Group by observation metadata
        groups = adata.obs[groupby].map(lambda x: "mask" if x in mask_list else x).values
    else:
        # Group by gene expression. Normalise to a list of gene names first:
        # testing a list with ``in adata.var_names`` raises TypeError, and
        # ``set(some_str)`` would compare single characters.
        try:
            genes = [groupby] if isinstance(groupby, str) else list(groupby)
        except TypeError:
            genes = []
        if not genes or not set(genes) <= set(adata.var_names):
            raise ValueError(
                f"`groupby` value '{groupby}' is invalid. "
                "It must be a column in adata.obs, a gene in adata.var_names, "
                "or a list of genes in adata.var_names."
            )
        subset = adata[:, genes]
        # Honour ``layer`` (previously evaluated but ignored in favour of .X).
        expr = subset.X if layer == "X" else subset.layers[layer]
        # ravel the (n, 1) np.matrix that sparse sums return into a 1-D array.
        groups = np.asarray(expr.sum(axis=1)).flatten()

    pc = pv.PolyData(bucket_xyz)
    _, plot_cmap = add_point_labels(
        model=pc,
        labels=groups,
        key_added=key_added,
        where="point_data",
        colormap=colormap,
        alphamap=alphamap,
        inplace=True,
    )

    # The obs_index of each coordinate in the original adata.
    pc.point_data["obs_index"] = cell_names

    return pc, plot_cmap
|
|
@@ -0,0 +1,425 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import math
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
import matplotlib as mpl
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pyvista as pv
|
|
8
|
+
from pyvista import MultiBlock, Plotter, PolyData
|
|
9
|
+
|
|
10
|
+
from .three_d_plot_decorate import add_legend, add_model, add_outline, add_text
|
|
11
|
+
from .three_d_plot_prepare import _get_default_cmap, construct_pc, create_plotter
|
|
12
|
+
|
|
13
|
+
# Module-level logger for the 3D plotting helpers.
logger: logging.Logger = logging.getLogger(__name__)

# Register "default_cmap" with matplotlib as soon as this module is imported,
# so the name can be used before any plotting function runs.
_get_default_cmap()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def wrap_to_plotter(
    plotter: Plotter,
    model: PolyData | MultiBlock,
    key: str | None = None,
    colormap: str | list | None = None,

    # parameters for model settings
    ambient: float = 0.2,
    opacity: float | list = 1,
    point_size: float = 1,
    model_style: Literal["points", "surface", "wireframe"] = "surface",
    font_family: Literal["times", "courier", "arial"] = "arial",
    background: str = "black",
    cpo: str | list = "iso",
    clim: list | None = None,
    legend_kwargs: dict | None = None,
    outline_kwargs: dict | None = None,
    text: str | None = None,
    scalar_bar_title: str | None = None,
    text_kwargs: dict | None = None,
    show_outline: bool = False,
    show_text: bool = True,
    show_legend: bool = True
):
    """
    Add a model and its decorations (legend, outline, text) to the active subplot of a plotter.

    Parameters:
        plotter (Plotter): The plotter to draw into (its currently active subplot).
        model (PolyData | MultiBlock): The model to be added.
        key (str | None): Name of the data array used to color the model.
        colormap (str | list | None): Colormap forwarded to ``add_model`` and ``add_legend``.
        ambient (float): The ambient lighting coefficient.
        opacity (float | list): The opacity of the model.
        point_size (float): The size of the points in the model.
        model_style (Literal["points", "surface", "wireframe"]): The style of the model.
        font_family (Literal["times", "courier", "arial"]): Font family for the legend and text.
        background (str): Background color (any matplotlib color). The legend font, outline and
            text colors default to its RGB complement.
        cpo (str | list): Camera position. Accepted for API compatibility but currently not applied.
        clim (list | None): Color limits forwarded to ``add_model``.
        legend_kwargs (dict | None): Overrides for the legend defaults; keys that are not among
            the defaults are ignored.
        outline_kwargs (dict | None): Overrides for ``outline_width`` / ``outline_color``; other
            keys are ignored.
        text (str | None): The text to add to the plotter.
        scalar_bar_title (str | None): The title of the scalar bar.
        text_kwargs (dict | None): Overrides for the text defaults; keys that are not among the
            defaults are ignored.
        show_outline (bool): Whether to show the outline.
        show_text (bool): Whether to show the text.
        show_legend (bool): Whether to show the legend.
    """
    # Set the basic settings for the plotter. Decorations use the RGB
    # complement of the background so they remain visible on it.
    bg_rgb = mpl.colors.to_rgb(background)
    cbg_rgb = (1 - bg_rgb[0], 1 - bg_rgb[1], 1 - bg_rgb[2])

    # Add model(s) basic settings to the plotter.
    add_model(
        plotter=plotter,
        model=model,
        key=key,
        colormap=colormap,
        ambient=ambient,
        opacity=opacity,
        point_size=point_size,
        model_style=model_style,
        clim=clim,
    )

    # Add legends to the plotter.
    if show_legend:
        lg_kwargs = dict(
            categorical_legend_size=None,
            categorical_legend_loc=None,
            scalar_bar_size=None,
            scalar_bar_loc=None,
            scalar_bar_title_size=None,
            scalar_bar_label_size=None,
            scalar_bar_font_color=cbg_rgb,
            scalar_bar_n_labels=5,
            font_family=font_family,
            fmt="%.1e",
            vertical=True,
        )
        if legend_kwargs is not None:
            # Only known defaults can be overridden; unknown keys are dropped.
            lg_kwargs.update({k: v for k, v in legend_kwargs.items() if k in lg_kwargs})

        add_legend(plotter=plotter, model=model, key=key,
                   colormap=colormap, scalar_bar_title=scalar_bar_title, **lg_kwargs)

    # Add an outline to the plotter.
    if show_outline:
        ol_kwargs = dict(
            outline_width=1.0,
            outline_color=cbg_rgb,
        )
        if outline_kwargs is not None:
            ol_kwargs.update({k: v for k, v in outline_kwargs.items() if k in ol_kwargs})

        add_outline(plotter=plotter, model=model, **ol_kwargs)

    # Add text to the plotter.
    if show_text:
        t_kwargs = dict(
            font_family=font_family,
            text_font_size=12,
            text_font_color=cbg_rgb,
            text_loc="upper_edge",
        )
        if text_kwargs is not None:
            t_kwargs.update({k: v for k, v in text_kwargs.items() if k in t_kwargs})

        add_text(plotter=plotter, text=text, **t_kwargs)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def three_d_plot(
    # parameters for plotter
    adata,
    spatial_key: str,
    keys: list,
    cmaps: str | list | dict | None = 'default_cmap',
    scalar_bar_titles: str | list | None = None,
    texts: str | list | None = None,

    window_size: tuple[int, int] | None = None,
    off_screen: bool = False,
    shape: tuple | None = None,
    show_camera_orientation: bool = True,
    show_axis_orientation: bool = False,
    jupyter: bool = True,

    # parameters for model settings
    ambient: float = 0.2,
    opacity: float | list[float] = 1,
    point_size: float = 1,
    clim: list | None = None,
    model_style: Literal["points", "surface", "wireframe"] = "surface",
    font_family: Literal["times", "courier", "arial"] = "arial",
    background: str = "black",
    cpo: str | list = "iso",

    # parameters for show decoration
    show_outline: bool = False,
    show_text: bool = True,
    show_legend: bool = True,

    # parameters for legends, outline, and text
    legend_kwargs: dict | None = None,
    outline_kwargs: dict | None = None,
    text_kwargs: dict | None = None
):
    """
    Generate a 3D plot using pyvista for spatial data visualization.

    Parameters:
    - adata: AnnData object containing spatial data (a DataFrame with 'sx', 'sy', 'sz' columns is also accepted).
    - spatial_key: Key in adata.obsm holding the spatial coordinates (ignored for a DataFrame).
    - keys: Keys to plot, one subplot each: columns of adata.obs (or of the DataFrame), gene names,
            or lists of gene names.
    - cmaps: Colormap(s) for the keys. A single value (a colormap/color name, a label -> color dict,
             or None) is shared by every key; a list gives one entry per key. Default is 'default_cmap'.
    - scalar_bar_titles: Titles for the scalar bars. Can be a string or a list of strings. Default is None.
    - texts: Text(s) to display on the plots; a string or a list with one entry per key. Default is None.
    - window_size: Size of each subplot in pixels, as (width, height). Default is None (1500 x 1500).
                   The total window is this size multiplied by the grid shape.
    - off_screen: Whether to render the plot off-screen. Default is False.
    - shape: Shape of the subplot grid as (n_rows, n_cols). Default is None, which uses at most
             3 columns and as many rows as needed.
    - show_camera_orientation: Whether to show the camera orientation widget. Default is True.
    - show_axis_orientation: Whether to show an axes widget when the camera orientation widget is
                             not shown. Default is False.
    - jupyter: Whether to display the plot in a Jupyter notebook ("trame" is also accepted). Default is True.
    - ambient: Ambient lighting intensity. Default is 0.2.
    - opacity: Opacity of the plot objects. Default is 1.
    - point_size: Size of the points in the plot. Default is 1.
    - clim: Color limits for the scalar colormap. Default is None.
    - model_style: Style of the plot objects. Can be 'points', 'surface', or 'wireframe'. Default is 'surface'.
    - font_family: Font family for the plot text. Can be 'times', 'courier', or 'arial'. Default is 'arial'.
    - background: Background color of the plot. Default is 'black'.
    - cpo: Camera position, forwarded to wrap_to_plotter, which currently does not apply it. Default is 'iso'.
    - show_outline: Whether to show the plot outline. Default is False.
    - show_text: Whether to show the plot text. Default is True.
    - show_legend: Whether to show the plot legend. Default is True.
    - legend_kwargs (dict, optional): Overrides for the legend; the default values are:
        categorical_legend_size: [tuple] = None, # (gap, size)
        categorical_legend_loc: [Literal["upper right", "upper left", "lower left", "lower right",
                                        "center left", "center right", "lower center", "upper center",
                                        "center"]] = None
        scalar_bar_size: Optional[tuple] = None
        scalar_bar_loc: Optional[tuple] = None
        scalar_bar_title_size: Union[int, float] = None
        scalar_bar_label_size: Union[int, float] = None
        scalar_bar_font_color: Optional[str] = None
        scalar_bar_n_labels: int = 5
        fmt="%.1e",
        vertical: bool = True
      Other keys are ignored; use `scalar_bar_titles` to set scalar bar titles.

    - outline_kwargs (dict, optional): Overrides for the plot outline; the default values are:
        outline_width: float = 1.0
        outline_color: Optional[str] = None
      Other keys are ignored.

    - text_kwargs (dict, optional): Overrides for the text; the default values are:
        text_font_size: Optional[float] = None,
        text_font_color: Optional[str] = None,
        text_loc: Optional[Literal["lower_left", "lower_right", "upper_left",
                                   "upper_right", "lower_edge", "upper_edge",
                                   "right_edge", "left_edge",]] = None
      Other keys are ignored.

    Returns:
    - plotter: The pyvista plotter object.

    Raises:
    - ValueError: If `shape` is not a tuple or list.
    """
    _get_default_cmap()

    n_window = len(keys)

    # A single colormap spec (a name, a label -> color dict, or None) is shared
    # by every key; a list gives one entry per key. A dict or None would
    # otherwise be indexed by position below.
    if cmaps is None or isinstance(cmaps, (str, dict)):
        cmaps = [cmaps] * n_window

    if scalar_bar_titles is None or isinstance(scalar_bar_titles, str):
        scalar_bar_titles = [scalar_bar_titles] * n_window

    if texts is None or isinstance(texts, str):
        texts = [texts] * n_window

    # Build the pyvista objects. construct_pc only reads `adata`, so the
    # object is not copied for every key.
    models = pv.MultiBlock()
    plot_cmaps = []
    for i, key in enumerate(keys):
        _model, _plot_cmap = construct_pc(adata=adata,
                                          spatial_key=spatial_key,
                                          groupby=key,
                                          key_added=key,
                                          colormap=cmaps[i])
        models[f"model_{i}"] = _model
        plot_cmaps.append(_plot_cmap)

    # Set the shape and window size of the plot
    if shape is None:
        shape = (math.ceil(n_window / 3), min(n_window, 3))
    if not isinstance(shape, tuple | list):
        raise ValueError(f"`shape` must be a (n_rows, n_cols) tuple or list, got {shape!r}.")

    n_rows, n_cols = shape[0], shape[1]
    # Row-major (row, column) position of every subplot in the grid.
    subplots = [divmod(i, n_cols) for i in range(n_rows * n_cols)]

    if window_size is None:
        window_size = (1500 * n_cols, 1500 * n_rows)
    else:
        window_size = (window_size[0] * n_cols, window_size[1] * n_rows)

    # Create the plotter
    plotter = create_plotter(
        background=background,
        off_screen=off_screen,
        shape=shape,
        show_camera_orientation=show_camera_orientation,
        show_axis_orientation=show_axis_orientation,
        window_size=window_size,
        jupyter=jupyter
    )

    # Set the plotter
    for (model, key, plot_cmap, subplot_index, scalar_bar_title, text) in zip(
            models, keys, plot_cmaps, subplots, scalar_bar_titles, texts, strict=False):
        plotter.subplot(subplot_index[0], subplot_index[1])

        wrap_to_plotter(
            # parameters for plotter
            plotter=plotter,
            model=model,
            key=key,
            colormap=plot_cmap,

            # parameters for model settings
            clim=clim,
            ambient=ambient,
            opacity=opacity,
            point_size=point_size,
            model_style=model_style,
            font_family=font_family,
            background=background,
            cpo=cpo,

            # parameters for legends, outline, and text
            legend_kwargs=legend_kwargs,
            scalar_bar_title=scalar_bar_title,
            outline_kwargs=outline_kwargs,
            text_kwargs=text_kwargs,
            text=text,

            # parameters for show decoration
            show_outline=show_outline,
            show_text=show_text,
            show_legend=show_legend
        )

    plotter.link_views()

    return plotter
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def three_d_plot_save(
    plotter: Plotter,
    filename: str,
    view_up_1=(0, 0, 0),
    view_up_2=(0, 0, 1),
    n_points: int = 150,
    factor: float = 2.0,
    shift: float = 0,
    step: int = 1,
    quality: int = 9,
    framerate: int = 10,
    save_mp4 : bool = False,
    save_gif : bool = False
):
    """
    Save a 3D plot as HTML and, optionally, as GIF and/or MP4 orbit animations.

    Args:
        plotter (Plotter): The plotter to export. It is closed after an animation is
            written; it stays open if only the HTML is saved.
        filename (str): Base path; '.html', '.gif' and '.mp4' are appended.
        view_up_1 (tuple, optional): `viewup` passed to `generate_orbital_path`. Defaults to (0, 0, 0).
        view_up_2 (tuple, optional): `viewup` passed to `orbit_on_path`. Defaults to (0, 0, 1).
        n_points (int, optional): The number of points on the orbital path. Defaults to 150.
        factor (float, optional): The factor for scaling the orbital path. Defaults to 2.0.
        shift (float, optional): The shift value for the orbital path. Defaults to 0.
        step (int, optional): The step size for writing frames. Defaults to 1.
        quality (int, optional): The quality of the MP4 file (passed to `open_movie`). Defaults to 9.
        framerate (int, optional): The framerate of the MP4 file. Defaults to 10.
        save_mp4 (bool, optional): Also write '{filename}.mp4'. Defaults to False.
        save_gif (bool, optional): Also write '{filename}.gif'. Defaults to False.
    """
    # save html
    logger.info('saving 3d plot as html...')

    # Workaround for asyncio conflict: use the 'static' backend for the export.
    # Remember the active backend first -- reading the theme after switching
    # would only return 'static' again, so the old backend was never restored.
    previous_backend = pv.global_theme.jupyter_backend
    pv.set_jupyter_backend('static')
    try:
        plotter.export_html(f'{filename}.html')
    finally:
        try:
            pv.set_jupyter_backend(previous_backend)
        except (ImportError, ValueError) as err:
            # Restoring the backend is best-effort; don't mask the export result.
            logger.warning('could not restore jupyter backend %r: %s', previous_backend, err)

    # NOTE(review): the default view_up_1 of (0, 0, 0) is a zero vector; confirm it
    # yields a usable orbit in generate_orbital_path.

    # save gif
    if save_gif:
        logger.info('saving 3d plot as gif...')
        path = plotter.generate_orbital_path(factor=factor, shift=shift, viewup=view_up_1, n_points=n_points)
        plotter.open_gif(filename=f'{filename}.gif')
        plotter.orbit_on_path(path, write_frames=True, viewup=view_up_2, step=step)
        if save_mp4:
            # Finalise the GIF writer but keep the plotter alive for the MP4 pass.
            plotter.mwriter.close()
        else:
            plotter.close()

    # save mp4
    if save_mp4:
        logger.info('saving 3d plot as mp4...')
        path = plotter.generate_orbital_path(factor=factor, shift=shift, viewup=view_up_1, n_points=n_points)
        plotter.open_movie(filename=f'{filename}.mp4', framerate=framerate, quality=quality)
        plotter.orbit_on_path(path, write_frames=True, viewup=view_up_2, step=step)
        plotter.close()
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
def rotate_around_xyz(
    camera_coordinates,
    angle_x=0,
    angle_y=0,
    angle_z=0
):
    """
    Rotate camera coordinates about the x axis, then the y axis, then the z axis.

    Parameters:
    - camera_coordinates: array-like of length 3 with the coordinates to rotate
      (a 3 x N array also works; each column is rotated)
    - angle_x: rotation angle around the x-axis in degrees (default: 0)
    - angle_y: rotation angle around the y-axis in degrees (default: 0)
    - angle_z: rotation angle around the z-axis in degrees (default: 0)

    Returns:
    - numpy array equal to Rz @ (Ry @ (Rx @ camera_coordinates))
    """
    cos_x, sin_x = np.cos(np.radians(angle_x)), np.sin(np.radians(angle_x))
    cos_y, sin_y = np.cos(np.radians(angle_y)), np.sin(np.radians(angle_y))
    cos_z, sin_z = np.cos(np.radians(angle_z)), np.sin(np.radians(angle_z))

    about_x = np.array([[1, 0, 0],
                        [0, cos_x, -sin_x],
                        [0, sin_x, cos_x]])
    about_y = np.array([[cos_y, 0, sin_y],
                        [0, 1, 0],
                        [-sin_y, 0, cos_y]])
    about_z = np.array([[cos_z, -sin_z, 0],
                        [sin_z, cos_z, 0],
                        [0, 0, 1]])

    # Apply the rotations one after another (x first, z last).
    rotated = camera_coordinates
    for matrix in (about_x, about_y, about_z):
        rotated = np.dot(matrix, rotated)
    return rotated
|