ipyvasp 1.1.0__py2.py3-none-any.whl → 1.1.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ipyvasp/_lattice.py CHANGED
@@ -909,6 +909,7 @@ def splot_bz(
909
909
  shade=True,
910
910
  alpha=0.4,
911
911
  zoffset=0,
912
+ center=(0,0,0),
912
913
  **kwargs,
913
914
  ):
914
915
  """Plots matplotlib's static figure of BZ/Cell. You can also plot in 2D on a 3D axes.
@@ -937,8 +938,9 @@ def splot_bz(
937
938
  alpha : float
938
939
  Opacity of filling in range [0,1]. Increase for clear viewpoint.
939
940
  zoffset : float
940
- Only used if plotting in 2D over a 3D axis. Default is 0. Any plane 'xy','yz' etc.
941
-
941
+ Only used if plotting in 2D over a 3D axis. Default is 0. Any plane 'xy','yz' etc can be offset to it's own normal.
942
+ center : (3,) array_like
943
+ Translation of origin in *basis coordinates* (fractional along the plotted basis). Use this to tile BZ with help of ``BrZoneData.tile`` fuction.
942
944
 
943
945
  kwargs are passed to `plt.plot` or `Poly3DCollection` if `fill=True`.
944
946
 
@@ -962,6 +964,17 @@ def splot_bz(
962
964
  if v not in [0, 1, 2]:
963
965
  raise ValueError(f"`vectors` expects values in [0,1,2], got {vectors!r}")
964
966
 
967
+ if not isinstance(center, (tuple, list, np.ndarray)) or len(center) != 3:
968
+ raise ValueError("`center` must be a 3-sequence like (0,0,0) in basis coordinates.")
969
+ try:
970
+ center = np.array(center, dtype=float).reshape(3)
971
+ except Exception as e:
972
+ raise ValueError(f"`center` must be numeric, got {center!r}") from e
973
+
974
+ origin = to_R3(bz_data.basis, [center])[0] # (3,) cartesian shift
975
+ bz_data = bz_data.copy()
976
+ bz_data.vertices[:,:] += origin # apply on view, assignment is restricted
977
+
965
978
  name = kwargs.pop("label", None) # will set only on single line
966
979
  kwargs.pop("zdir", None) # remove , no need
967
980
  is_subzone = hasattr(bz_data, "_specials") # For subzone
@@ -1031,13 +1044,14 @@ def splot_bz(
1031
1044
 
1032
1045
  if vectors and not is_subzone:
1033
1046
  s_basis = to_plane(normals[plane], bz_data.basis[(vectors,)])
1047
+ s_origin = to_plane(normals[plane], [origin]*len(vectors))
1034
1048
 
1035
1049
  for k, b in zip(vectors, s_basis):
1036
1050
  x, y = b[idxs[plane]]
1037
1051
  l = r" ${}_{} $".format(_label, k + 1)
1038
1052
  l = l + "\n" if y < 0 else "\n" + l
1039
1053
  ha = "right" if x < 0 else "left"
1040
- xyz = 0.8 * b + z0 if is3d else np.array([0.8 * x, 0.8 * y])
1054
+ xyz = 0.8 * b + z0 + s_origin[0] if is3d else np.array([0.8 * x, 0.8 * y]) + s_origin[0, idxs[plane]]
1041
1055
  ax.text(
1042
1056
  *xyz, l, va="center", ha=ha, clip_on=True
1043
1057
  ) # must clip to have limits of axes working.
@@ -1045,7 +1059,7 @@ def splot_bz(
1045
1059
  *(xyz / 0.8), color="w", s=0.0005
1046
1060
  ) # Must be to scale below arrow.
1047
1061
  if is3d:
1048
- XYZ, UVW = (np.ones_like(s_basis) * z0).T, s_basis.T
1062
+ XYZ, UVW = (np.ones_like(s_basis) * z0 + s_origin).T, s_basis.T
1049
1063
  quiver3d(
1050
1064
  *XYZ,
1051
1065
  *UVW,
@@ -1056,10 +1070,8 @@ def splot_bz(
1056
1070
  mutation_scale=7,
1057
1071
  )
1058
1072
  else:
1059
- s_zero = [0 for _ in s_basis] # either 3 or 2.
1060
1073
  ax.quiver(
1061
- s_zero,
1062
- s_zero,
1074
+ *s_origin[:, idxs[plane]].T,
1063
1075
  *s_basis[:, idxs[plane]].T,
1064
1076
  lw=0.7,
1065
1077
  color=color,
@@ -1126,9 +1138,9 @@ def splot_bz(
1126
1138
 
1127
1139
  if vectors and not is_subzone:
1128
1140
  for k, v in enumerate(0.35 * bz_data.basis):
1129
- ax.text(*v, r"${}_{}$".format(_label, k + 1), va="center", ha="center")
1141
+ ax.text(*(v + origin), r"${}_{}$".format(_label, k + 1), va="center", ha="center")
1130
1142
 
1131
- XYZ, UVW = [[0, 0, 0], [0, 0, 0], [0, 0, 0]], 0.3 * bz_data.basis.T
1143
+ XYZ, UVW = np.array([origin] * 3).T, 0.3 * bz_data.basis.T
1132
1144
  quiver3d(
1133
1145
  *XYZ, *UVW, C="k", L=0.7, ax=ax, arrowstyle="-|>", mutation_scale=7
1134
1146
  )
@@ -1147,7 +1159,7 @@ def splot_bz(
1147
1159
  ax.set_zlabel(label.format("z"))
1148
1160
 
1149
1161
  if vname == "b": # These needed for splot_kpath internally
1150
- type(bz_data)._splot_kws = dict(plane=plane, zoffset=zoffset, ax=ax)
1162
+ type(bz_data)._splot_kws = dict(plane=plane, zoffset=zoffset, ax=ax, shift=origin)
1151
1163
 
1152
1164
  return ax
1153
1165
 
@@ -1155,7 +1167,7 @@ def splot_bz(
1155
1167
  def splot_kpath(
1156
1168
  bz_data, kpoints, labels=None, fmt_label=lambda x: (x, {"color": "blue"}), **kwargs
1157
1169
  ):
1158
- """Plot k-path over existing BZ. It will take ``ax``, ``plane`` and ``zoffset`` internally from most recent call to ``splot_bz``/``bz.splot``.
1170
+ """Plot k-path over last plotted BZ. It will take ``ax``, ``plane`` and ``zoffset`` internally from most recent call to ``splot_bz``/``bz.splot``.
1159
1171
 
1160
1172
  Parameters
1161
1173
  ----------
@@ -1181,9 +1193,9 @@ def splot_kpath(
1181
1193
  if not np.ndim(kpoints) == 2 and np.shape(kpoints)[-1] == 3:
1182
1194
  raise ValueError("kpoints must be 2D array of shape (N,3)")
1183
1195
 
1184
- plane, ax, zoffset = [
1196
+ plane, ax, zoffset, shift = [
1185
1197
  bz_data._splot_kws.get(attr, default) # class level attributes
1186
- for attr, default in zip(["plane", "ax", "zoffset"], [None, None, 0])
1198
+ for attr, default in zip(["plane", "ax", "zoffset", "shift"], [None, None, 0,np.array([0.0, 0.0, 0.0])])
1187
1199
  ]
1188
1200
 
1189
1201
  ijk = [0, 1, 2]
@@ -1219,9 +1231,9 @@ def splot_kpath(
1219
1231
  if fmt_label is None:
1220
1232
  fmt_label = lambda x: (x, {"color": "blue"})
1221
1233
 
1222
- _validate_label_func(fmt_label,labels[0])
1234
+ _validate_label_func(fmt_label,labels[0])
1223
1235
 
1224
- coords = bz_data.to_cartesian(kpoints)
1236
+ coords = bz_data.to_cartesian(kpoints) + shift
1225
1237
  if _zoffset and plane:
1226
1238
  normal = (
1227
1239
  [0, 0, 1] if plane in "xyx" else [0, 1, 0] if plane in "xzx" else [1, 0, 0]
ipyvasp/_version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "1.1.0"
1
+ __version__ = "1.1.2"
@@ -1,6 +1,7 @@
1
1
  from uuid import uuid1
2
2
  from io import BytesIO
3
3
 
4
+ import types
4
5
  import numpy as np
5
6
  import matplotlib as mpl
6
7
  import matplotlib.pyplot as plt
@@ -124,112 +125,193 @@ def quiver3d(X, Y, Z, U, V, W, ax=None, C="r", L=0.7, mutation_scale=10, **kwarg
124
125
 
125
126
 
126
127
  def get_axes(
127
- figsize=(3.4, 2.6),
128
- nrows=1,
129
- ncols=1,
130
- widths=[],
131
- heights=[],
132
- axes_off=[],
133
- axes_3d=[],
134
- sharex=False,
135
- sharey=False,
136
- azim=45,
137
- elev=15,
138
- ortho3d=True,
139
- **subplots_adjust_kwargs,
140
- ):
141
- """Returns axes of initialized figure, based on plt.subplots().
142
- If you want to access parent figure, use ax.get_figure() or current figure as plt.gcf().
143
-
128
+ shape=None, figsize=(3.4, 2.6),
129
+ widths=None, heights=None,
130
+ sharex=False, sharey=False,
131
+ layout='constrained',
132
+ axes_3d=None, axes_off=None,
133
+ axes_kws=None, fig_kws=None,
134
+ ortho3d=True, azim=45, elev=15,
135
+ **grid_spec_kws):
136
+ r"""Returns axes of initialized figure, based on given shape.
137
+
144
138
  Parameters
145
139
  ----------
146
- figsize : tuple
147
- Tuple (width, height). Default is (3.4,2.6).
148
- nrows : int
149
- Default 1.
150
- ncols : int
151
- Default 1.
152
- widths : list
153
- List with len(widths)==nrows, to set width ratios of subplots.
154
- heights : list
155
- List with len(heights)==ncols, to set height ratios of subplots.
156
- sharex : bool
157
- Share x-axis between plots, this removes shared ticks automatically.
158
- sharey : bool
159
- Share y-axis between plots, this removes shared ticks automatically.
160
- axes_off : bool or list
161
- Turn off axes visibility, If `nrows = ncols = 1, set True/False`.
162
- If anyone of `nrows or ncols > 1`, provide list of axes indices to turn off.
163
- If both `nrows and ncols > 1`, provide list of tuples (x_index,y_index) of axes.
164
- axes_3d : bool or list
165
- Change axes to 3D. If `nrows = ncols = 1, set True/False`.
166
- If anyone of `nrows or ncols > 1`, provide list of axes indices to turn off.
167
- If both `nrows and ncols > 1`, provide list of tuples (x_index,y_index) of axes.
168
- ortho3d : bool
169
- Only works for 3D axes. If True, x,y,z are orthogonal, otherwise perspective.
170
-
171
-
172
- `azim, elev` are passed to `ax.view_init`. Defualt values are 45,15 respectively.
173
-
174
- `subplots_adjust_kwargs` are passed to `plt.subplots_adjust`.
175
-
140
+ shape : int, tuple, list, or str, optional
141
+ - int: Number of columns in a 1xN grid.
142
+ - tuple (R, C): Number of rows and columns.
143
+ - str: Mosaic layout (e.g., "AA;BC" or "A\\nB"). Supports ';' and '\\n' as row separators.
144
+ - list: Mosaic layout as a list of lists.
145
+ figsize : tuple, optional
146
+ Width and height of the figure in inches. Defaults to (3.4, 2.6).
147
+ widths, heights : list, optional
148
+ Width and height ratios for the grid or mosaic.
149
+ sharex, sharey : bool, optional
150
+ If True, axes will share the x or y axis. Defaults to False.
151
+ layout : str, optional
152
+ The layout engine to use ('constrained', 'tight', or None).
153
+ Set to None if using manual hspace/wspace inside grid_spec_kws.
154
+ axes_3d : bool or list/tuple of keys, optional
155
+ - True: All created axes will be 3D projections.
156
+ - list/tuple: Only specified keys/indices will be 3D.
157
+ axes_off : int, str, or list of keys, optional
158
+ Keys or indices of axes to turn off (hide spines, ticks, and labels).
159
+ axes_kws : dict, optional
160
+ A dictionary of keywords for specific axes.
161
+ - Key -1: Applied to all axes (global style).
162
+ - Key [int/str]: Applied to specific axis index or mosaic label.
163
+ fig_kws : dict, optional
164
+ Additional keywords passed to plt.figure().
165
+ ortho3d : bool, optional
166
+ If True, uses orthographic projection for 3D axes. Defaults to True.
167
+ azim, elev : int, optional
168
+ Initial viewing angles for 3D axes.
169
+ **grid_spec_kws : dict
170
+ Keywords passed to GridSpec (e.g., hspace, wspace, left, right).
171
+
172
+ Returns
173
+ -------
174
+ matplotlib.axes.Axes, numpy.ndarray, or dict
175
+ - A single Axes object if 1x1.
176
+ - An array of Axes if grid mode.
177
+ - A dictionary of Axes mapping labels to objects if mosaic mode.
178
+
179
+
176
180
  .. note::
177
181
  There are extra methods added to each axes (only 2D) object.
178
182
  ``add_text``, ``add_legend``, ``add_colorbar``, ``color_wheel``,
179
- ``break_spines``, ``adjust_axes``, ``append_axes``, ``join_axes``.
183
+ ``break_spines``, ``adjust_axes``, ``append_axes``, ``join_axes`` and ``stitch_axes``.
180
184
  """
181
- if figsize[0] <= 2.38:
182
- mpl.rc("font", size=8)
183
- gs_kw = dict({}) # Define Empty Dictionary.
184
- if widths != [] and len(widths) == ncols:
185
- gs_kw = dict({**gs_kw, "width_ratios": widths})
186
- if heights != [] and len(heights) == nrows:
187
- gs_kw = dict({**gs_kw, "height_ratios": heights})
188
- fig, axs = plt.subplots(
189
- nrows, ncols, figsize=figsize, gridspec_kw=gs_kw, sharex=sharex, sharey=sharey
190
- )
191
- proj = {"proj_type": "ortho"} if ortho3d else {} # For 3D only
192
- if nrows * ncols == 1:
193
- adjust_axes(ax=axs)
194
- if axes_off == True:
195
- axs.set_axis_off()
196
- if axes_3d == True:
197
- pos = axs.get_position()
198
- axs.remove()
199
- axs = fig.add_axes(pos, projection="3d", azim=azim, elev=elev, **proj)
200
- setattr(axs, add_legend.__name__, add_legend.__get__(axs, type(axs)))
185
+ if axes_kws is not None and not isinstance(axes_kws, dict):
186
+ raise TypeError("axes_kws must be None or a dictionary.")
201
187
 
188
+ if fig_kws is not None and not isinstance(fig_kws, dict):
189
+ raise TypeError("fig_kws must be None or a dictionary.")
190
+
191
+ if axes_3d is not None and not isinstance(axes_3d, (bool, list, tuple)):
192
+ raise TypeError("axes_3d must be None, True, or a list/tuple of int/str.")
193
+
194
+ # Validate keys and values
195
+ axes_kws = axes_kws or {}
196
+ for key, value in axes_kws.items():
197
+ if not isinstance(key, (str, int)):
198
+ raise TypeError(f"axes_kws keys must be str or int, got {type(key).__name__} at key {key}")
199
+ if not isinstance(value, dict):
200
+ raise TypeError(f"axes_kws values must be dicts, got {type(value).__name__} for key '{key}'")
201
+
202
+ global_kw = axes_kws.pop(-1, {}) # The 'all-axes' style
203
+ p3d = {"projection": "3d", "proj_type": "ortho" if ortho3d else "persp"}
204
+ if axes_3d is True:
205
+ global_kw.update(p3d) # if user wants it
206
+
207
+ if axes_3d in (None, True, False):
208
+ axes_3d = () # done for it above
209
+ elif not isinstance(axes_3d, (list, tuple)): # If it's not None/True/False/list/tuple, it's an invalid type
210
+ raise TypeError(f"axes_3d must be None, True, or a list/tuple of keys. Got {type(axes_3d).__name__}")
211
+
212
+ if shape is None:
213
+ nr, nc, mode = 1, 1, 'grid'
214
+ elif isinstance(shape, int):
215
+ nr, nc = 1, shape
216
+ mode = 'grid'
217
+ elif isinstance(shape, (tuple, list)):
218
+ if len(shape) == 2 and all(isinstance(i, int) for i in shape):
219
+ nr, nc = shape; mode = 'grid'
220
+ elif all(isinstance(s, (str, list)) for s in shape):
221
+ mode = 'mosaic'
222
+ else: raise ValueError("shape must be (rows, cols) or a mosaic design.")
223
+ else: mode = 'mosaic'
224
+
225
+ if figsize[0] <= 2.38: mpl.rc("font", size=8)
226
+
227
+ if widths: grid_spec_kws['width_ratios'] = widths
228
+ if heights: grid_spec_kws['height_ratios'] = heights
229
+
230
+ f_kws = fig_kws or {}
231
+ f_kws.update({'figsize': figsize, 'layout': layout})
232
+ if any(k in grid_spec_kws for k in ['hspace', 'wspace']):
233
+ f_kws['layout'] = None
234
+
235
+ fig = plt.figure(**f_kws)
236
+
237
+ def is_match(target, key):
238
+ if target is True and (key in [0, "ax"]): return True
239
+ search = target if isinstance(target, (list, tuple)) else [target]
240
+ return key in search if target is not None else False
241
+
242
+ axs_dict = {}
243
+ if mode == 'mosaic':
244
+ pkws = axes_kws.copy()
245
+ for a3d in axes_3d:
246
+ pkws[a3d] = {**p3d, **pkws.get(a3d, {})}
247
+ axs_dict = fig.subplot_mosaic(shape, sharex=sharex, sharey=sharey,
248
+ subplot_kw=global_kw, gridspec_kw=grid_spec_kws, per_subplot_kw=pkws)
202
249
  else:
203
- _ = [adjust_axes(ax=ax) for ax in axs.ravel()]
204
- _ = [axs[inds].set_axis_off() for inds in axes_off if axes_off != []]
205
- if axes_3d != []:
206
- for inds in axes_3d:
207
- pos = axs[inds].get_position()
208
- axs[inds].remove()
209
- axs[inds] = fig.add_axes(
210
- pos, projection="3d", azim=azim, elev=elev, **proj
211
- )
212
- try:
213
- for ax in np.array([axs]).flatten():
214
- for f in [
215
- add_text,
216
- add_legend,
217
- add_colorbar,
218
- color_wheel,
219
- color_cube,
220
- break_spines,
221
- adjust_axes,
222
- append_axes,
223
- ]:
224
- if ax.name != "3d":
225
- setattr(ax, f.__name__, f.__get__(ax, type(ax)))
226
- except:
227
- pass
228
-
229
- plt.subplots_adjust(**subplots_adjust_kwargs)
230
- return axs
250
+ gs = fig.add_gridspec(nr, nc, **grid_spec_kws)
251
+ main_ax = None
252
+ for i in range(nr * nc):
253
+ kw = {**global_kw, **axes_kws.get(i, {})}
254
+ if is_match(axes_3d, i): kw.update(p3d)
255
+ if main_ax and not kw.get("projection") == "3d":
256
+ if sharex: kw.setdefault('sharex', main_ax)
257
+ if sharey: kw.setdefault('sharey', main_ax)
258
+ ax = fig.add_subplot(gs[i // nc, i % nc], **kw)
259
+ if i == 0: main_ax = ax
260
+ axs_dict[i] = ax
261
+
262
+ for key, ax in axs_dict.items():
263
+ if is_match(axes_off, key): ax.set_axis_off()
264
+ if ax.name == "3d": ax.view_init(elev=elev, azim=azim)
265
+ _monkey_patch(ax)
266
+
267
+ if len(axs_dict) == 1: return list(axs_dict.values())[0]
268
+ return axs_dict if mode == 'mosaic' else np.array([axs_dict[k] for k in sorted(axs_dict.keys())], dtype=object)
269
+
270
+ def _monkey_patch(ax):
271
+ if ax.name != "3d":
272
+ for f in (
273
+ add_text, add_legend, add_colorbar,
274
+ color_wheel, color_cube, break_spines,
275
+ adjust_axes, append_axes, join_axes, stitch_axes
276
+ ):
277
+ adjust_axes(ax)
278
+ if not hasattr(ax, f.__name__): # avoid resetting
279
+ setattr(ax, f.__name__, types.MethodType(f, ax))
280
+
281
+ def stitch_axes(ax1, ax2, symbol="\u2571", **kwargs):
282
+ """Simulates broken axes by stitching ax1 and ax2 together. Need to fix heights/widths according
283
+ to given data for real aspect. Also plot the same data on each axes and set axes limits.
284
+
285
+ Parameters
286
+ ----------
287
+ symbol: str
288
+ Defult is u'\u2571'. Its at 60 degrees. so you can apply rotation to make it any angle.
231
289
 
232
290
 
291
+ kwargs are passed to plt.text.
292
+ """
293
+ p1 = ax1.get_position().get_points().mean(axis=0)
294
+ p2 = ax2.get_position().get_points().mean(axis=0)
295
+ is_vertical = abs(p1[1] - p2[1]) > abs(p1[0] - p2[0])
296
+ if is_vertical:
297
+ top, bot = (ax1, ax2) if p1[1] > p2[1] else (ax2, ax1)
298
+ top.spines['bottom'].set_visible(False)
299
+ top.tick_params(bottom=False, labelbottom=False)
300
+ bot.spines['top'].set_visible(False)
301
+ bot.tick_params(top=False, labeltop=False)
302
+ for ax, y in [(top, 0), (bot, 1)]:
303
+ kw = {**kwargs, 'transform': ax.transAxes, 'ha': 'center', 'va': 'center', 'clip_on': False}
304
+ [ax.text(x, y, symbol, **kw) for x in [0, 1]]
305
+ else:
306
+ left, right = (ax1, ax2) if p1[0] < p2[0] else (ax2, ax1)
307
+ left.spines['right'].set_visible(False)
308
+ left.tick_params(right=False, labelright=False)
309
+ right.spines['left'].set_visible(False)
310
+ right.tick_params(left=False, labelleft=False)
311
+ for ax, x in [(left, 1), (right, 0)]:
312
+ kw = {**kwargs, 'transform': ax.transAxes, 'ha': 'center', 'va': 'center', 'clip_on': False}
313
+ [ax.text(x, y, symbol, **kw) for y in [0, 1]]
314
+
233
315
  def adjust_axes(
234
316
  ax=None,
235
317
  xticks=[],
@@ -332,18 +414,7 @@ def join_axes(ax1, ax2, **kwargs):
332
414
  ax1.remove()
333
415
  ax2.remove()
334
416
  new_ax = fig.add_axes(new_bbox, **kwargs)
335
- _ = adjust_axes(new_ax)
336
- for f in [
337
- add_text,
338
- add_legend,
339
- add_colorbar,
340
- color_wheel,
341
- break_spines,
342
- adjust_axes,
343
- append_axes,
344
- ]:
345
- if new_ax.name != "3d":
346
- setattr(new_ax, f.__name__, f.__get__(new_ax, type(new_ax)))
417
+ _monkey_patch(new_ax)
347
418
  return new_ax
348
419
 
349
420
 
@@ -362,6 +433,9 @@ def break_spines(ax, spines, symbol="\u2571", **kwargs):
362
433
 
363
434
 
364
435
  kwargs are passed to plt.text.
436
+
437
+ .. note::
438
+ Use ``stitch_axes` for better and automatically hiding spines.
365
439
  """
366
440
  kwargs.update(transform=ax.transAxes, ha="center", va="center")
367
441
  _spines = [spines] if isinstance(spines, str) else spines
@@ -523,18 +597,12 @@ def add_colorbar(
523
597
  if cax is None:
524
598
  position = "right" if vertical == True else "top"
525
599
  cax = append_axes(ax, position=position, size="5%", pad=0.05)
526
- if cmap_or_clist is None:
527
- colors = np.array(
528
- [
529
- [1, 0, 1],
530
- [1, 0, 0],
531
- [1, 1, 0],
532
- [0, 1, 0],
533
- [0, 1, 1],
534
- [0, 0, 1],
535
- [1, 0, 1],
536
- ]
537
- )
600
+
601
+ mappable = ax.images[-1] if ax.images else ax.collections[-1] if ax.collections else None
602
+ if cmap_or_clist is None and mappable is not None:
603
+ _hsv_ = mappable.get_cmap()
604
+ elif cmap_or_clist is None:
605
+ colors = np.array([[1,0,1], [1,0,0], [1,1,0], [0,1,0], [0,1,1], [0,0,1], [1,0,1]])
538
606
  _hsv_ = LSC.from_list("_hsv_", colors, N=N)
539
607
  elif isinstance(cmap_or_clist, (list, np.ndarray)):
540
608
  try:
@@ -542,16 +610,26 @@ def add_colorbar(
542
610
  except Exception as e:
543
611
  print(e, "\nFalling back to default color map!")
544
612
  _hsv_ = None # fallback
545
- elif isinstance(cmap_or_clist, str):
546
- _hsv_ = cmap_or_clist # colormap name
547
613
  else:
548
- _hsv_ = None # default fallback
614
+ _hsv_ = cmap_or_clist
615
+
616
+ # Extract vmin/vmax if not provided
617
+ if mappable is not None:
618
+ d_min, d_max = mappable.get_clim()
619
+ vmin = vmin if vmin is not None else d_min
620
+ vmax = vmax if vmax is not None else d_max
621
+
622
+ # Fallback if no mappable exists and user didn't provide limits
623
+ vmin = 0 if vmin is None else vmin
624
+ vmax = 1 if vmax is None else vmax
549
625
 
550
626
  if ticks != []:
551
627
  if ticks is None: # should be before labels
552
- ticks = np.linspace(1 / 6, 5 / 6, 3, endpoint=True)
628
+ ticks = np.linspace(vmin, vmax, 3, endpoint=True)
553
629
  if ticklabels is None:
554
630
  ticklabels = ticks.round(digits).astype(str)
631
+ # Renormalize ticks after assigning ticklabels
632
+ ticks = (ticks - vmin) / (vmax - vmin)
555
633
 
556
634
  elif isinstance(ticks, (list, tuple, np.ndarray)):
557
635
  ticks = np.array(ticks)
@@ -562,6 +640,7 @@ def add_colorbar(
562
640
 
563
641
  if ticklabels is None:
564
642
  ticklabels = ticks.round(digits).astype(str)
643
+
565
644
  # Renormalize ticks after assigning ticklabels
566
645
  ticks = (ticks - _vmin) / (_vmax - _vmin)
567
646
  else:
@@ -580,7 +659,7 @@ def add_colorbar(
580
659
  grid_color=(1, 1, 1, 0),
581
660
  grid_alpha=0,
582
661
  )
583
- ticks_param.update({tickloc: True}) # Only show on given side
662
+ ticks_param.update({tickloc: True,"labelsize":fontsize}) # Only show on given side
584
663
  cax.tick_params(**ticks_param)
585
664
  if vertical == False:
586
665
  cax.imshow(
@@ -607,8 +686,6 @@ def add_colorbar(
607
686
  cax.set_yticklabels(ticklabels, rotation=90, va="center")
608
687
  cax.set_ylim([0, 1]) # enforce limit
609
688
 
610
- for tick in cax.xaxis.get_major_ticks():
611
- tick.label.set_fontsize(fontsize)
612
689
  for child in cax.get_children():
613
690
  if isinstance(child, mpl.spines.Spine):
614
691
  child.set_color((1, 1, 1, 0.4))
@@ -644,10 +721,9 @@ def color_wheel(
644
721
  if ax is None:
645
722
  ax = get_axes()
646
723
  if colormap is None:
647
- try:
648
- colormap = plt.cm.get_cmap("hsv")
649
- except:
650
- colormap = "viridis"
724
+ mappable = ax.images[-1] if ax.images else ax.collections[-1] if ax.collections else None
725
+ colormap = mappable.get_cmap() if mappable is not None else "hsv"
726
+
651
727
  pos = ax.get_position()
652
728
  ratio = pos.height / pos.width
653
729
  cpos = [
@@ -765,10 +765,51 @@ class BrZoneData(Dict2Data):
765
765
  d = self.copy().to_dict()
766
766
  d.update({"faces": faces, "vertices": vertices, "_specials": specials})
767
767
  return self.__class__(d)
768
+
769
+ def tile(self, nxyz, filter=None):
770
+ """Create a tiled array of BZ centers for visualization.
771
+
772
+ Parameters
773
+ ----------
774
+ nxyz : list or tuple of 3 ints
775
+ Number of tiles along each cartesian direction [nx, ny, nz].
776
+ Must be 3 positive integers.
777
+ filter : callable, optional
778
+ Function filter(x,y,z) that takes cartesian coordinates and returns bool.
779
+
780
+ Returns
781
+ -------
782
+ numpy.ndarray
783
+ Array of shape (N, 3) containing BZ center positions in fractional coordinates.
784
+
785
+ Examples
786
+ --------
787
+ >>> centers = bz_data.tile([3, 3, 3])
788
+ >>> centers = bz_data.tile([5, 5, 1], filter=lambda x,y,z: x**2 + y**2 <= 2**2)
789
+ """
790
+ if not isinstance(nxyz, (list, tuple, np.ndarray)) or len(nxyz) != 3:
791
+ raise ValueError("nxyz must be a list/tuple/array of 3 integers")
792
+
793
+ for i, n in enumerate(nxyz):
794
+ if not isinstance(n, int) or n < 1:
795
+ raise ValueError(f"nxyz[{i}] must be a positive integer")
796
+
797
+ xyz = self.to_cartesian(np.indices(np.ceil(nxyz).astype(int)).reshape((3,-1)).T)
798
+ # Apply filter if provided
799
+ if filter is not None:
800
+ if not callable(filter):
801
+ raise TypeError("filter must be callable")
802
+
803
+ mask = np.array([filter(x, y, z) for x, y, z in xyz])
804
+ xyz = xyz[mask]
805
+
806
+ # Convert to fractional coordinates and return
807
+ return self.to_fractional(xyz)
768
808
 
769
809
 
770
810
  class CellData(Dict2Data):
771
811
  splot, iplot = _methods_imported()
812
+ tile = BrZoneData.tile
772
813
  _req_keys = ("basis", "faces", "vertices")
773
814
 
774
815
  def __init__(self, d):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ipyvasp
3
- Version: 1.1.0
3
+ Version: 1.1.2
4
4
  Summary: A processing tool for VASP DFT input/output processing in Jupyter Notebook.
5
5
  Home-page: https://github.com/massgh/ipyvasp
6
6
  Author: Abdul Saboor
@@ -1,8 +1,8 @@
1
1
  ipyvasp/__init__.py,sha256=pzTqeKuf6sN2GQmaexmMgG677ggT3sxIFyQDXq_2whU,1422
2
2
  ipyvasp/__main__.py,sha256=eJV1TZSiT8mC_VqAeksNnBI2I8mKMiPkEIlwikbtOjI,216
3
3
  ipyvasp/_enplots.py,sha256=gJ7S9WBmrxvDEbmoccDRaJG01kpx9oNlRf7mozigbgY,37872
4
- ipyvasp/_lattice.py,sha256=-JQaIdB2e6yFh5j_U2343FH2aM_c3cSdqSC19cJINq4,107218
5
- ipyvasp/_version.py,sha256=Zrv57EzpjdsuSPqsYvFkVsQKKRUOHFG7yURCf7qN-Tk,23
4
+ ipyvasp/_lattice.py,sha256=Lh-ip60M6APUckR4I0bQ6GMN2aaY7gpkA1Z0uuFt1A4,108117
5
+ ipyvasp/_version.py,sha256=JhNfc49cF1z8HnPYAyTem2s5Sn4s2XfgPp_HRwRQSfQ,23
6
6
  ipyvasp/bsdos.py,sha256=hVHpxkdT2ImRsxwFvMSMHxRSo4LqDM90DnUhwTP8vcs,32192
7
7
  ipyvasp/cli.py,sha256=-Lf-qdTvs7WyrA4ALNLaoqxMjLsZkXdPviyQps3ezqg,6880
8
8
  ipyvasp/evals_dataframe.py,sha256=n2iSH4D4ZbrxlAV4yDTVHbcl3ycfD0zfQYmTBcxjfkE,20789
@@ -14,12 +14,12 @@ ipyvasp/utils.py,sha256=1eVDhYzK3dr0AC_CouWrU3xIhbVJu7AABscV-qi_vAA,18000
14
14
  ipyvasp/widgets.py,sha256=Bpa4Y3Eopk_ZPFsVetfysClZP2q_2ONvmOwUol9vVGI,53154
15
15
  ipyvasp/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  ipyvasp/core/parser.py,sha256=o-uHyL_w9W0pmxoSt3JmLwwmmT3WRHlHSs_NoiHO-hs,39401
17
- ipyvasp/core/plot_toolkit.py,sha256=aURiPAu0tWTNctuA9LMEL3CfraRgAOyTJURzDrrO0r0,36232
18
- ipyvasp/core/serializer.py,sha256=aEc7K5jVga8gxm9Tt2OgBw8wnsmWZGtODBnwRJ_5sf0,38423
17
+ ipyvasp/core/plot_toolkit.py,sha256=ru8-FLJp8-X2p_Ft0F3K68qNSUJhy3jQD_S2zW50FWg,40930
18
+ ipyvasp/core/serializer.py,sha256=9xuLfl9LtUXPxTX2_zRs7nvfZsZOZN00bnxErAxDA5w,40065
19
19
  ipyvasp/core/spatial_toolkit.py,sha256=dXowREhiFzBvvr5f_bApzFhf8IzjH2E2Ix90oCBUetY,14885
20
- ipyvasp-1.1.0.dist-info/LICENSE,sha256=F3SO5RiAZOMfmMGf1KOuk2g_c4ObvuBJhd9iBLDgXoQ,1263
21
- ipyvasp-1.1.0.dist-info/METADATA,sha256=DCZLpgRjMRdVVJqT23OoP4AVFf_HeLUICtdBdSYa4_I,3218
22
- ipyvasp-1.1.0.dist-info/WHEEL,sha256=Kh9pAotZVRFj97E15yTA4iADqXdQfIVTHcNaZTjxeGM,110
23
- ipyvasp-1.1.0.dist-info/entry_points.txt,sha256=aU-gGjQG2Q8XfxDlNc_8__cwfp8WG2K5ZgFPInTm2jg,45
24
- ipyvasp-1.1.0.dist-info/top_level.txt,sha256=ftziWlMWu_1VpDP1sRTFrkfBnWxAi393HYDVu4wRhUk,8
25
- ipyvasp-1.1.0.dist-info/RECORD,,
20
+ ipyvasp-1.1.2.dist-info/LICENSE,sha256=F3SO5RiAZOMfmMGf1KOuk2g_c4ObvuBJhd9iBLDgXoQ,1263
21
+ ipyvasp-1.1.2.dist-info/METADATA,sha256=2wWh4djtZU6uiYMOCpsCuRDFiCAIQHToqTAf3Clkp3g,3218
22
+ ipyvasp-1.1.2.dist-info/WHEEL,sha256=Kh9pAotZVRFj97E15yTA4iADqXdQfIVTHcNaZTjxeGM,110
23
+ ipyvasp-1.1.2.dist-info/entry_points.txt,sha256=aU-gGjQG2Q8XfxDlNc_8__cwfp8WG2K5ZgFPInTm2jg,45
24
+ ipyvasp-1.1.2.dist-info/top_level.txt,sha256=ftziWlMWu_1VpDP1sRTFrkfBnWxAi393HYDVu4wRhUk,8
25
+ ipyvasp-1.1.2.dist-info/RECORD,,