ipyvasp 1.1.0__tar.gz → 1.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/PKG-INFO +1 -1
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp/_lattice.py +27 -15
- ipyvasp-1.1.2/ipyvasp/_version.py +1 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp/core/plot_toolkit.py +207 -131
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp/core/serializer.py +41 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp.egg-info/PKG-INFO +1 -1
- ipyvasp-1.1.0/ipyvasp/_version.py +0 -1
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/LICENSE +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/README.md +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp/__init__.py +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp/__main__.py +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp/_enplots.py +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp/bsdos.py +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp/cli.py +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp/core/__init__.py +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp/core/parser.py +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp/core/spatial_toolkit.py +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp/evals_dataframe.py +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp/lattice.py +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp/misc.py +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp/potential.py +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp/utils.py +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp/widgets.py +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp.egg-info/SOURCES.txt +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp.egg-info/dependency_links.txt +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp.egg-info/entry_points.txt +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp.egg-info/requires.txt +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/ipyvasp.egg-info/top_level.txt +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/setup.cfg +0 -0
- {ipyvasp-1.1.0 → ipyvasp-1.1.2}/setup.py +0 -0
|
@@ -909,6 +909,7 @@ def splot_bz(
|
|
|
909
909
|
shade=True,
|
|
910
910
|
alpha=0.4,
|
|
911
911
|
zoffset=0,
|
|
912
|
+
center=(0,0,0),
|
|
912
913
|
**kwargs,
|
|
913
914
|
):
|
|
914
915
|
"""Plots matplotlib's static figure of BZ/Cell. You can also plot in 2D on a 3D axes.
|
|
@@ -937,8 +938,9 @@ def splot_bz(
|
|
|
937
938
|
alpha : float
|
|
938
939
|
Opacity of filling in range [0,1]. Increase for clear viewpoint.
|
|
939
940
|
zoffset : float
|
|
940
|
-
Only used if plotting in 2D over a 3D axis. Default is 0. Any plane 'xy','yz' etc.
|
|
941
|
-
|
|
941
|
+
Only used if plotting in 2D over a 3D axis. Default is 0. Any plane 'xy','yz' etc. can be offset along its own normal.
|
|
942
|
+
center : (3,) array_like
|
|
943
|
+
Translation of origin in *basis coordinates* (fractional along the plotted basis). Use this to tile BZ with help of ``BrZoneData.tile`` function.
|
|
942
944
|
|
|
943
945
|
kwargs are passed to `plt.plot` or `Poly3DCollection` if `fill=True`.
|
|
944
946
|
|
|
@@ -962,6 +964,17 @@ def splot_bz(
|
|
|
962
964
|
if v not in [0, 1, 2]:
|
|
963
965
|
raise ValueError(f"`vectors` expects values in [0,1,2], got {vectors!r}")
|
|
964
966
|
|
|
967
|
+
if not isinstance(center, (tuple, list, np.ndarray)) or len(center) != 3:
|
|
968
|
+
raise ValueError("`center` must be a 3-sequence like (0,0,0) in basis coordinates.")
|
|
969
|
+
try:
|
|
970
|
+
center = np.array(center, dtype=float).reshape(3)
|
|
971
|
+
except Exception as e:
|
|
972
|
+
raise ValueError(f"`center` must be numeric, got {center!r}") from e
|
|
973
|
+
|
|
974
|
+
origin = to_R3(bz_data.basis, [center])[0] # (3,) cartesian shift
|
|
975
|
+
bz_data = bz_data.copy()
|
|
976
|
+
bz_data.vertices[:,:] += origin # apply on view, assignment is restricted
|
|
977
|
+
|
|
965
978
|
name = kwargs.pop("label", None) # will set only on single line
|
|
966
979
|
kwargs.pop("zdir", None) # remove , no need
|
|
967
980
|
is_subzone = hasattr(bz_data, "_specials") # For subzone
|
|
@@ -1031,13 +1044,14 @@ def splot_bz(
|
|
|
1031
1044
|
|
|
1032
1045
|
if vectors and not is_subzone:
|
|
1033
1046
|
s_basis = to_plane(normals[plane], bz_data.basis[(vectors,)])
|
|
1047
|
+
s_origin = to_plane(normals[plane], [origin]*len(vectors))
|
|
1034
1048
|
|
|
1035
1049
|
for k, b in zip(vectors, s_basis):
|
|
1036
1050
|
x, y = b[idxs[plane]]
|
|
1037
1051
|
l = r" ${}_{} $".format(_label, k + 1)
|
|
1038
1052
|
l = l + "\n" if y < 0 else "\n" + l
|
|
1039
1053
|
ha = "right" if x < 0 else "left"
|
|
1040
|
-
xyz = 0.8 * b + z0 if is3d else np.array([0.8 * x, 0.8 * y])
|
|
1054
|
+
xyz = 0.8 * b + z0 + s_origin[0] if is3d else np.array([0.8 * x, 0.8 * y]) + s_origin[0, idxs[plane]]
|
|
1041
1055
|
ax.text(
|
|
1042
1056
|
*xyz, l, va="center", ha=ha, clip_on=True
|
|
1043
1057
|
) # must clip to have limits of axes working.
|
|
@@ -1045,7 +1059,7 @@ def splot_bz(
|
|
|
1045
1059
|
*(xyz / 0.8), color="w", s=0.0005
|
|
1046
1060
|
) # Must be to scale below arrow.
|
|
1047
1061
|
if is3d:
|
|
1048
|
-
XYZ, UVW = (np.ones_like(s_basis) * z0).T, s_basis.T
|
|
1062
|
+
XYZ, UVW = (np.ones_like(s_basis) * z0 + s_origin).T, s_basis.T
|
|
1049
1063
|
quiver3d(
|
|
1050
1064
|
*XYZ,
|
|
1051
1065
|
*UVW,
|
|
@@ -1056,10 +1070,8 @@ def splot_bz(
|
|
|
1056
1070
|
mutation_scale=7,
|
|
1057
1071
|
)
|
|
1058
1072
|
else:
|
|
1059
|
-
s_zero = [0 for _ in s_basis] # either 3 or 2.
|
|
1060
1073
|
ax.quiver(
|
|
1061
|
-
|
|
1062
|
-
s_zero,
|
|
1074
|
+
*s_origin[:, idxs[plane]].T,
|
|
1063
1075
|
*s_basis[:, idxs[plane]].T,
|
|
1064
1076
|
lw=0.7,
|
|
1065
1077
|
color=color,
|
|
@@ -1126,9 +1138,9 @@ def splot_bz(
|
|
|
1126
1138
|
|
|
1127
1139
|
if vectors and not is_subzone:
|
|
1128
1140
|
for k, v in enumerate(0.35 * bz_data.basis):
|
|
1129
|
-
ax.text(*v, r"${}_{}$".format(_label, k + 1), va="center", ha="center")
|
|
1141
|
+
ax.text(*(v + origin), r"${}_{}$".format(_label, k + 1), va="center", ha="center")
|
|
1130
1142
|
|
|
1131
|
-
XYZ, UVW = [
|
|
1143
|
+
XYZ, UVW = np.array([origin] * 3).T, 0.3 * bz_data.basis.T
|
|
1132
1144
|
quiver3d(
|
|
1133
1145
|
*XYZ, *UVW, C="k", L=0.7, ax=ax, arrowstyle="-|>", mutation_scale=7
|
|
1134
1146
|
)
|
|
@@ -1147,7 +1159,7 @@ def splot_bz(
|
|
|
1147
1159
|
ax.set_zlabel(label.format("z"))
|
|
1148
1160
|
|
|
1149
1161
|
if vname == "b": # These needed for splot_kpath internally
|
|
1150
|
-
type(bz_data)._splot_kws = dict(plane=plane, zoffset=zoffset, ax=ax)
|
|
1162
|
+
type(bz_data)._splot_kws = dict(plane=plane, zoffset=zoffset, ax=ax, shift=origin)
|
|
1151
1163
|
|
|
1152
1164
|
return ax
|
|
1153
1165
|
|
|
@@ -1155,7 +1167,7 @@ def splot_bz(
|
|
|
1155
1167
|
def splot_kpath(
|
|
1156
1168
|
bz_data, kpoints, labels=None, fmt_label=lambda x: (x, {"color": "blue"}), **kwargs
|
|
1157
1169
|
):
|
|
1158
|
-
"""Plot k-path over
|
|
1170
|
+
"""Plot k-path over last plotted BZ. It will take ``ax``, ``plane`` and ``zoffset`` internally from most recent call to ``splot_bz``/``bz.splot``.
|
|
1159
1171
|
|
|
1160
1172
|
Parameters
|
|
1161
1173
|
----------
|
|
@@ -1181,9 +1193,9 @@ def splot_kpath(
|
|
|
1181
1193
|
if not np.ndim(kpoints) == 2 and np.shape(kpoints)[-1] == 3:
|
|
1182
1194
|
raise ValueError("kpoints must be 2D array of shape (N,3)")
|
|
1183
1195
|
|
|
1184
|
-
plane, ax, zoffset = [
|
|
1196
|
+
plane, ax, zoffset, shift = [
|
|
1185
1197
|
bz_data._splot_kws.get(attr, default) # class level attributes
|
|
1186
|
-
for attr, default in zip(["plane", "ax", "zoffset"], [None, None, 0])
|
|
1198
|
+
for attr, default in zip(["plane", "ax", "zoffset", "shift"], [None, None, 0,np.array([0.0, 0.0, 0.0])])
|
|
1187
1199
|
]
|
|
1188
1200
|
|
|
1189
1201
|
ijk = [0, 1, 2]
|
|
@@ -1219,9 +1231,9 @@ def splot_kpath(
|
|
|
1219
1231
|
if fmt_label is None:
|
|
1220
1232
|
fmt_label = lambda x: (x, {"color": "blue"})
|
|
1221
1233
|
|
|
1222
|
-
_validate_label_func(fmt_label,labels[0])
|
|
1234
|
+
_validate_label_func(fmt_label,labels[0])
|
|
1223
1235
|
|
|
1224
|
-
coords = bz_data.to_cartesian(kpoints)
|
|
1236
|
+
coords = bz_data.to_cartesian(kpoints) + shift
|
|
1225
1237
|
if _zoffset and plane:
|
|
1226
1238
|
normal = (
|
|
1227
1239
|
[0, 0, 1] if plane in "xyx" else [0, 1, 0] if plane in "xzx" else [1, 0, 0]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "1.1.2"
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from uuid import uuid1
|
|
2
2
|
from io import BytesIO
|
|
3
3
|
|
|
4
|
+
import types
|
|
4
5
|
import numpy as np
|
|
5
6
|
import matplotlib as mpl
|
|
6
7
|
import matplotlib.pyplot as plt
|
|
@@ -124,112 +125,193 @@ def quiver3d(X, Y, Z, U, V, W, ax=None, C="r", L=0.7, mutation_scale=10, **kwarg
|
|
|
124
125
|
|
|
125
126
|
|
|
126
127
|
def get_axes(
    shape=None, figsize=(3.4, 2.6),
    widths=None, heights=None,
    sharex=False, sharey=False,
    layout='constrained',
    axes_3d=None, axes_off=None,
    axes_kws=None, fig_kws=None,
    ortho3d=True, azim=45, elev=15,
    **grid_spec_kws):
    r"""Returns axes of initialized figure, based on given shape.

    Parameters
    ----------
    shape : int, tuple, list, or str, optional
        - None: A single axes (1x1 grid).
        - int: Number of columns in a 1xN grid.
        - tuple (R, C): Number of rows and columns.
        - str: Mosaic layout (e.g., "AA;BC" or "A\nB"), passed as is to ``Figure.subplot_mosaic``.
        - list: Mosaic layout as a list of strings/lists.
    figsize : tuple, optional
        Width and height of the figure in inches. Defaults to (3.4, 2.6).
        If width is 2.38 or less, the global matplotlib font size is set to 8.
    widths, heights : list, optional
        Width and height ratios for the grid or mosaic.
    sharex, sharey : bool, optional
        If True, axes will share the x or y axis. Defaults to False.
        In grid mode, non-3D axes share with the first created axes.
    layout : str, optional
        The layout engine to use ('constrained', 'tight', or None).
        It is forced to None when ``hspace`` or ``wspace`` is given in grid_spec_kws.
    axes_3d : bool or list/tuple of keys, optional
        - True: All created axes will be 3D projections.
        - list/tuple: Only specified keys/indices will be 3D.
    axes_off : bool, int, str, or list of keys, optional
        Keys or indices of axes to turn off (hide spines, ticks, and labels).
        True turns off only the first axes (key 0 or "ax"); None/False turns off nothing.
    axes_kws : dict, optional
        A dictionary of keywords for specific axes. It is not modified.
        - Key -1: Applied to all axes (global style).
        - Key [int/str]: Applied to specific axis index or mosaic label.
    fig_kws : dict, optional
        Additional keywords passed to plt.figure(). It is not modified.
    ortho3d : bool, optional
        If True, uses orthographic projection for 3D axes. Defaults to True.
    azim, elev : int, optional
        Initial viewing angles for 3D axes.
    **grid_spec_kws : dict
        Keywords passed to GridSpec (e.g., hspace, wspace, left, right).

    Returns
    -------
    matplotlib.axes.Axes, numpy.ndarray, or dict
        - A single Axes object if only one axes was created.
        - A flat (row-major) object array of Axes if grid mode.
        - A dictionary of Axes mapping labels to objects if mosaic mode.

    Raises
    ------
    TypeError
        If ``axes_kws``/``fig_kws`` is not None or a dict, if ``axes_kws`` has
        keys other than str/int or values other than dict, or if ``axes_3d``
        is not None, bool, list or tuple.
    ValueError
        If a tuple/list ``shape`` is neither (rows, cols) nor a mosaic design.

    .. note::
        There are extra methods added to each axes (only 2D) object.
        ``add_text``, ``add_legend``, ``add_colorbar``, ``color_wheel``,
        ``break_spines``, ``adjust_axes``, ``append_axes``, ``join_axes`` and ``stitch_axes``.
    """
    if axes_kws is not None and not isinstance(axes_kws, dict):
        raise TypeError("axes_kws must be None or a dictionary.")

    if fig_kws is not None and not isinstance(fig_kws, dict):
        raise TypeError("fig_kws must be None or a dictionary.")

    if axes_3d is not None and not isinstance(axes_3d, (bool, list, tuple)):
        raise TypeError("axes_3d must be None, True, or a list/tuple of int/str.")

    # Work on a copy: key -1 is popped below and must not vanish from caller's dict.
    axes_kws = dict(axes_kws or {})
    for key, value in axes_kws.items():
        if not isinstance(key, (str, int)):
            raise TypeError(f"axes_kws keys must be str or int, got {type(key).__name__} at key {key}")
        if not isinstance(value, dict):
            raise TypeError(f"axes_kws values must be dicts, got {type(value).__name__} for key '{key}'")

    # Copy again, since 3D settings may be merged into the 'all-axes' style.
    global_kw = dict(axes_kws.pop(-1, {}))
    p3d = {"projection": "3d", "proj_type": "ortho" if ortho3d else "persp"}
    if axes_3d is True:
        global_kw.update(p3d)

    # None/True/False are fully handled above; only explicit keys remain per-axes.
    if axes_3d is None or isinstance(axes_3d, bool):
        axes_3d = ()

    if shape is None:
        nr, nc, mode = 1, 1, 'grid'
    elif isinstance(shape, int):
        nr, nc, mode = 1, shape, 'grid'
    elif isinstance(shape, (tuple, list)):
        if len(shape) == 2 and all(isinstance(i, int) for i in shape):
            nr, nc = shape
            mode = 'grid'
        elif all(isinstance(s, (str, list)) for s in shape):
            mode = 'mosaic'
        else:
            raise ValueError("shape must be (rows, cols) or a mosaic design.")
    else:
        mode = 'mosaic'

    if figsize[0] <= 2.38:
        mpl.rc("font", size=8)  # NOTE: changes global rc settings, not just this figure

    if widths:
        grid_spec_kws['width_ratios'] = widths
    if heights:
        grid_spec_kws['height_ratios'] = heights

    # Build a new dict instead of updating the caller's fig_kws in place.
    f_kws = {**(fig_kws or {}), 'figsize': figsize, 'layout': layout}
    if any(k in grid_spec_kws for k in ('hspace', 'wspace')):
        f_kws['layout'] = None  # manual spacing is given, so no layout engine

    fig = plt.figure(**f_kws)

    def is_match(target, key):
        # Handle bools explicitly: 0 == False and 1 == True in Python, so a plain
        # membership test would make False select axes 0 and True select axes 1.
        if target is None or target is False:
            return False
        if target is True:
            return key in (0, "ax")
        search = target if isinstance(target, (list, tuple)) else [target]
        return key in search

    axs_dict = {}
    if mode == 'mosaic':
        pkws = dict(axes_kws)
        for a3d in axes_3d:
            # user supplied keywords for this label win over the 3D defaults
            pkws[a3d] = {**p3d, **pkws.get(a3d, {})}
        axs_dict = fig.subplot_mosaic(
            shape, sharex=sharex, sharey=sharey,
            subplot_kw=global_kw, gridspec_kw=grid_spec_kws, per_subplot_kw=pkws,
        )
    else:
        gs = fig.add_gridspec(nr, nc, **grid_spec_kws)
        main_ax = None
        for i in range(nr * nc):
            kw = {**global_kw, **axes_kws.get(i, {})}
            if is_match(axes_3d, i):
                kw.update(p3d)
            if main_ax and not kw.get("projection") == "3d":
                if sharex:
                    kw.setdefault('sharex', main_ax)
                if sharey:
                    kw.setdefault('sharey', main_ax)
            ax = fig.add_subplot(gs[i // nc, i % nc], **kw)
            if i == 0:
                main_ax = ax
            axs_dict[i] = ax

    for key, ax in axs_dict.items():
        if is_match(axes_off, key):
            ax.set_axis_off()
        if ax.name == "3d":
            ax.view_init(elev=elev, azim=azim)
        _monkey_patch(ax)

    if len(axs_dict) == 1:
        return list(axs_dict.values())[0]
    if mode == 'mosaic':
        return axs_dict
    return np.array([axs_dict[k] for k in sorted(axs_dict)], dtype=object)
|
|
269
|
+
|
|
270
|
+
def _monkey_patch(ax):
    """Style a 2D axes with ``adjust_axes`` and bind plotting helpers to it as methods.

    Nothing is done for axes whose ``name`` is "3d". For other axes, each helper
    is attached under its own function name, but only if the axes does not
    already have an attribute of that name.
    """
    if ax.name == "3d":
        return

    # Called once instead of once per helper.
    # NOTE(review): assumes repeated default calls of adjust_axes are idempotent - confirm.
    adjust_axes(ax)

    helpers = (
        add_text, add_legend, add_colorbar,
        color_wheel, color_cube, break_spines,
        adjust_axes, append_axes, join_axes, stitch_axes,
    )
    for func in helpers:
        if not hasattr(ax, func.__name__):  # avoid resetting
            setattr(ax, func.__name__, types.MethodType(func, ax))
|
|
280
|
+
|
|
281
|
+
def stitch_axes(ax1, ax2, symbol="\u2571", **kwargs):
    r"""Simulates broken axes by stitching ax1 and ax2 together. Need to fix heights/widths according
    to given data for real aspect. Also plot the same data on each axes and set axes limits.

    The two axes are treated as stacked vertically if their centers (from
    ``get_position``) differ more in y than in x, otherwise as side by side.
    The spines, ticks and tick labels facing each other are hidden and
    ``symbol`` is drawn at both ends of each hidden spine.

    Parameters
    ----------
    ax1, ax2 : matplotlib axes
        Axes to stitch, in any order.
    symbol: str
        Default is the diagonal character U+2571. It is at 60 degrees, so you can apply rotation to make it any angle.

    kwargs are passed to plt.text. ``transform``, ``ha``, ``va`` and ``clip_on`` are always overridden.
    """
    c1 = ax1.get_position().get_points().mean(axis=0)
    c2 = ax2.get_position().get_points().mean(axis=0)
    stacked = abs(c1[1] - c2[1]) > abs(c1[0] - c2[0])

    # Each entry: (axes, spine facing the break, marker anchors in axes coordinates)
    if stacked:
        upper, lower = (ax1, ax2) if c1[1] > c2[1] else (ax2, ax1)
        plan = [
            (upper, "bottom", [(0, 0), (1, 0)]),
            (lower, "top", [(0, 1), (1, 1)]),
        ]
    else:
        west, east = (ax1, ax2) if c1[0] < c2[0] else (ax2, ax1)
        plan = [
            (west, "right", [(1, 0), (1, 1)]),
            (east, "left", [(0, 0), (0, 1)]),
        ]

    for ax, side, anchors in plan:
        ax.spines[side].set_visible(False)
        ax.tick_params(**{side: False, f"label{side}": False})
        # clip_on=False because markers sit on the axes boundary
        text_kws = {**kwargs, "transform": ax.transAxes, "ha": "center", "va": "center", "clip_on": False}
        for x, y in anchors:
            ax.text(x, y, symbol, **text_kws)
|
|
314
|
+
|
|
233
315
|
def adjust_axes(
|
|
234
316
|
ax=None,
|
|
235
317
|
xticks=[],
|
|
@@ -332,18 +414,7 @@ def join_axes(ax1, ax2, **kwargs):
|
|
|
332
414
|
ax1.remove()
|
|
333
415
|
ax2.remove()
|
|
334
416
|
new_ax = fig.add_axes(new_bbox, **kwargs)
|
|
335
|
-
|
|
336
|
-
for f in [
|
|
337
|
-
add_text,
|
|
338
|
-
add_legend,
|
|
339
|
-
add_colorbar,
|
|
340
|
-
color_wheel,
|
|
341
|
-
break_spines,
|
|
342
|
-
adjust_axes,
|
|
343
|
-
append_axes,
|
|
344
|
-
]:
|
|
345
|
-
if new_ax.name != "3d":
|
|
346
|
-
setattr(new_ax, f.__name__, f.__get__(new_ax, type(new_ax)))
|
|
417
|
+
_monkey_patch(new_ax)
|
|
347
418
|
return new_ax
|
|
348
419
|
|
|
349
420
|
|
|
@@ -362,6 +433,9 @@ def break_spines(ax, spines, symbol="\u2571", **kwargs):
|
|
|
362
433
|
|
|
363
434
|
|
|
364
435
|
kwargs are passed to plt.text.
|
|
436
|
+
|
|
437
|
+
.. note::
|
|
438
|
+
Use ``stitch_axes`` for a better result that hides the spines automatically.
|
|
365
439
|
"""
|
|
366
440
|
kwargs.update(transform=ax.transAxes, ha="center", va="center")
|
|
367
441
|
_spines = [spines] if isinstance(spines, str) else spines
|
|
@@ -523,18 +597,12 @@ def add_colorbar(
|
|
|
523
597
|
if cax is None:
|
|
524
598
|
position = "right" if vertical == True else "top"
|
|
525
599
|
cax = append_axes(ax, position=position, size="5%", pad=0.05)
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
[0, 1, 0],
|
|
533
|
-
[0, 1, 1],
|
|
534
|
-
[0, 0, 1],
|
|
535
|
-
[1, 0, 1],
|
|
536
|
-
]
|
|
537
|
-
)
|
|
600
|
+
|
|
601
|
+
mappable = ax.images[-1] if ax.images else ax.collections[-1] if ax.collections else None
|
|
602
|
+
if cmap_or_clist is None and mappable is not None:
|
|
603
|
+
_hsv_ = mappable.get_cmap()
|
|
604
|
+
elif cmap_or_clist is None:
|
|
605
|
+
colors = np.array([[1,0,1], [1,0,0], [1,1,0], [0,1,0], [0,1,1], [0,0,1], [1,0,1]])
|
|
538
606
|
_hsv_ = LSC.from_list("_hsv_", colors, N=N)
|
|
539
607
|
elif isinstance(cmap_or_clist, (list, np.ndarray)):
|
|
540
608
|
try:
|
|
@@ -542,16 +610,26 @@ def add_colorbar(
|
|
|
542
610
|
except Exception as e:
|
|
543
611
|
print(e, "\nFalling back to default color map!")
|
|
544
612
|
_hsv_ = None # fallback
|
|
545
|
-
elif isinstance(cmap_or_clist, str):
|
|
546
|
-
_hsv_ = cmap_or_clist # colormap name
|
|
547
613
|
else:
|
|
548
|
-
_hsv_ =
|
|
614
|
+
_hsv_ = cmap_or_clist
|
|
615
|
+
|
|
616
|
+
# Extract vmin/vmax if not provided
|
|
617
|
+
if mappable is not None:
|
|
618
|
+
d_min, d_max = mappable.get_clim()
|
|
619
|
+
vmin = vmin if vmin is not None else d_min
|
|
620
|
+
vmax = vmax if vmax is not None else d_max
|
|
621
|
+
|
|
622
|
+
# Fallback if no mappable exists and user didn't provide limits
|
|
623
|
+
vmin = 0 if vmin is None else vmin
|
|
624
|
+
vmax = 1 if vmax is None else vmax
|
|
549
625
|
|
|
550
626
|
if ticks != []:
|
|
551
627
|
if ticks is None: # should be before labels
|
|
552
|
-
ticks = np.linspace(
|
|
628
|
+
ticks = np.linspace(vmin, vmax, 3, endpoint=True)
|
|
553
629
|
if ticklabels is None:
|
|
554
630
|
ticklabels = ticks.round(digits).astype(str)
|
|
631
|
+
# Renormalize ticks after assigning ticklabels
|
|
632
|
+
ticks = (ticks - vmin) / (vmax - vmin)
|
|
555
633
|
|
|
556
634
|
elif isinstance(ticks, (list, tuple, np.ndarray)):
|
|
557
635
|
ticks = np.array(ticks)
|
|
@@ -562,6 +640,7 @@ def add_colorbar(
|
|
|
562
640
|
|
|
563
641
|
if ticklabels is None:
|
|
564
642
|
ticklabels = ticks.round(digits).astype(str)
|
|
643
|
+
|
|
565
644
|
# Renormalize ticks after assigning ticklabels
|
|
566
645
|
ticks = (ticks - _vmin) / (_vmax - _vmin)
|
|
567
646
|
else:
|
|
@@ -580,7 +659,7 @@ def add_colorbar(
|
|
|
580
659
|
grid_color=(1, 1, 1, 0),
|
|
581
660
|
grid_alpha=0,
|
|
582
661
|
)
|
|
583
|
-
ticks_param.update({tickloc: True}) # Only show on given side
|
|
662
|
+
ticks_param.update({tickloc: True,"labelsize":fontsize}) # Only show on given side
|
|
584
663
|
cax.tick_params(**ticks_param)
|
|
585
664
|
if vertical == False:
|
|
586
665
|
cax.imshow(
|
|
@@ -607,8 +686,6 @@ def add_colorbar(
|
|
|
607
686
|
cax.set_yticklabels(ticklabels, rotation=90, va="center")
|
|
608
687
|
cax.set_ylim([0, 1]) # enforce limit
|
|
609
688
|
|
|
610
|
-
for tick in cax.xaxis.get_major_ticks():
|
|
611
|
-
tick.label.set_fontsize(fontsize)
|
|
612
689
|
for child in cax.get_children():
|
|
613
690
|
if isinstance(child, mpl.spines.Spine):
|
|
614
691
|
child.set_color((1, 1, 1, 0.4))
|
|
@@ -644,10 +721,9 @@ def color_wheel(
|
|
|
644
721
|
if ax is None:
|
|
645
722
|
ax = get_axes()
|
|
646
723
|
if colormap is None:
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
colormap = "viridis"
|
|
724
|
+
mappable = ax.images[-1] if ax.images else ax.collections[-1] if ax.collections else None
|
|
725
|
+
colormap = mappable.get_cmap() if mappable is not None else "hsv"
|
|
726
|
+
|
|
651
727
|
pos = ax.get_position()
|
|
652
728
|
ratio = pos.height / pos.width
|
|
653
729
|
cpos = [
|
|
@@ -765,10 +765,51 @@ class BrZoneData(Dict2Data):
|
|
|
765
765
|
d = self.copy().to_dict()
|
|
766
766
|
d.update({"faces": faces, "vertices": vertices, "_specials": specials})
|
|
767
767
|
return self.__class__(d)
|
|
768
|
+
|
|
769
|
+
def tile(self, nxyz, filter=None):
    """Create a tiled array of BZ centers for visualization.

    An integer grid of indices ``0..n-1`` is built along each of the three
    axes, mapped through ``self.to_cartesian``, optionally filtered, and
    mapped back through ``self.to_fractional``.

    Parameters
    ----------
    nxyz : list, tuple or array of 3 ints
        Number of tiles along each of the three grid directions [nx, ny, nz].
        Must be 3 positive integers (Python or numpy integers).
        NOTE(review): grid indices are fed to ``to_cartesian``, so these are
        presumably counts along basis vectors, not cartesian axes - confirm.
    filter : callable, optional
        Function filter(x,y,z) that takes cartesian coordinates and returns bool.
        The result is interpreted by its truth value.

    Returns
    -------
    numpy.ndarray
        Array of shape (N, 3) containing BZ center positions in fractional coordinates.

    Raises
    ------
    ValueError
        If ``nxyz`` is not a sequence of 3 positive integers.
    TypeError
        If ``filter`` is given but is not callable.

    Examples
    --------
    >>> centers = bz_data.tile([3, 3, 3])
    >>> centers = bz_data.tile([5, 5, 1], filter=lambda x,y,z: x**2 + y**2 <= 2**2)
    """
    if not isinstance(nxyz, (list, tuple, np.ndarray)) or len(nxyz) != 3:
        raise ValueError("nxyz must be a list/tuple/array of 3 integers")

    for i, n in enumerate(nxyz):
        # np.integer is accepted because elements of an integer ndarray are not Python ints
        if not isinstance(n, (int, np.integer)) or n < 1:
            raise ValueError(f"nxyz[{i}] must be a positive integer")

    # Validate before doing any work
    if filter is not None and not callable(filter):
        raise TypeError("filter must be callable")

    counts = tuple(int(n) for n in nxyz)
    grid = np.indices(counts).reshape((3, -1)).T  # (N, 3) integer grid points
    xyz = self.to_cartesian(grid)

    if filter is not None:
        # dtype=bool: an integer mask (e.g. 0/1 returns) would otherwise be
        # used as row indices instead of a selection mask.
        mask = np.array([filter(x, y, z) for x, y, z in xyz], dtype=bool)
        xyz = xyz[mask]

    # Convert to fractional coordinates and return
    return self.to_fractional(xyz)
|
|
768
808
|
|
|
769
809
|
|
|
770
810
|
class CellData(Dict2Data):
|
|
771
811
|
splot, iplot = _methods_imported()
|
|
812
|
+
tile = BrZoneData.tile
|
|
772
813
|
_req_keys = ("basis", "faces", "vertices")
|
|
773
814
|
|
|
774
815
|
def __init__(self, d):
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
__version__ = "1.1.0"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|