tensorcircuit-nightly 1.2.1.dev20250721__py3-none-any.whl → 1.2.1.dev20250722__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tensorcircuit-nightly might be problematic. Click here for more details.

@@ -0,0 +1,1448 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ The lattice module for defining and manipulating lattice geometries.
4
+ """
5
+ import logging
6
+ import abc
7
+ from typing import (
8
+ Any,
9
+ Dict,
10
+ Hashable,
11
+ Iterator,
12
+ List,
13
+ Optional,
14
+ Tuple,
15
+ Union,
16
+ TYPE_CHECKING,
17
+ cast,
18
+ )
19
+
20
+ logger = logging.getLogger(__name__)
21
+ import numpy as np
22
+
23
+ from scipy.spatial import KDTree
24
+ from scipy.spatial.distance import pdist, squareform
25
+
26
+
27
+ # This block resolves a name resolution issue for the static type checker (mypy).
28
+ # GOAL:
29
+ # Keep `matplotlib` as an optional dependency, so it is only imported
30
+ # inside the `show()` method, not at the module level.
31
+ # PROBLEM:
32
+ # The type hint for the `ax` parameter in `show()`'s signature
33
+ # (`ax: Optional["matplotlib.axes.Axes"]`) needs to know what `matplotlib` is.
34
+ # Without this block, mypy would raise a "Name 'matplotlib' is not defined" error.
35
+ # SOLUTION:
36
+ # The `if TYPE_CHECKING:` block is ignored at runtime but processed by mypy.
37
+ # This makes the name `matplotlib` available to the type checker without
38
+ # creating a hard dependency for the user.
39
+ if TYPE_CHECKING:
40
+ import matplotlib.axes
41
+ from mpl_toolkits.mplot3d import Axes3D
42
+
43
# Integer position of a site in a lattice's internal ordering (0 .. N-1).
SiteIndex = int
# Any hashable label that uniquely names a site; TILattice uses tuples of
# the form (uc_coord_1, ..., uc_coord_d, basis_index).
SiteIdentifier = Hashable
# Generic NumPy array alias; used for site coordinates and also for
# matrices such as the N x N distance matrix.
Coordinates = np.ndarray[Any, Any]
# Maps a site index to the list of its neighbor indices for one order k.
NeighborMap = Dict[SiteIndex, List[SiteIndex]]
47
+
48
+
49
class AbstractLattice(abc.ABC):
    """Abstract base class for describing lattice systems.

    This class defines the common interface for all lattice structures,
    providing access to fundamental properties like site information
    (count, coordinates, identifiers) and neighbor relationships.
    Subclasses are responsible for implementing the specific logic for
    generating the lattice points and calculating neighbor connections.

    :param dimensionality: The spatial dimension of the lattice (e.g., 1, 2, 3).
    :type dimensionality: int
    """

    def __init__(self, dimensionality: int):
        """Initializes the base lattice class."""
        self._dimensionality = dimensionality

        # --- Internal data structures (populated by subclasses) ---
        self._indices: List[SiteIndex] = []
        self._identifiers: List[SiteIdentifier] = []
        self._coordinates: List[Coordinates] = []
        self._ident_to_idx: Dict[SiteIdentifier, SiteIndex] = {}
        # Neighbor order k -> {site index -> k-th neighbor indices}.
        self._neighbor_maps: Dict[int, NeighborMap] = {}
        # Lazily computed N x N distance matrix (see `distance_matrix`).
        self._distance_matrix: Optional[Coordinates] = None

    @property
    def num_sites(self) -> int:
        """Returns the total number of sites (N) in the lattice."""
        return len(self._indices)

    @property
    def dimensionality(self) -> int:
        """Returns the spatial dimension of the lattice."""
        return self._dimensionality

    def __len__(self) -> int:
        """Returns the total number of sites, enabling `len(lattice)`."""
        return self.num_sites

    # --- Public API for Accessing Lattice Information ---
    @property
    def distance_matrix(self) -> Coordinates:
        """
        Returns the full N x N distance matrix.
        The matrix is computed on the first access and then cached for
        subsequent calls. This computation can be expensive for large lattices.
        """
        if self._distance_matrix is None:
            logger.info("Distance matrix not cached. Computing now...")
            self._distance_matrix = self._compute_distance_matrix()
        return self._distance_matrix

    def _validate_index(self, index: SiteIndex) -> None:
        """A private helper to check if a site index is within the valid range.

        :raises IndexError: If ``index`` is outside ``[0, num_sites)``.
        """
        if not (0 <= index < self.num_sites):
            raise IndexError(
                f"Site index {index} out of range (0-{self.num_sites - 1})"
            )

    def get_coordinates(self, index: SiteIndex) -> Coordinates:
        """Gets the spatial coordinates of a site by its integer index.

        :param index: The integer index of the site.
        :type index: SiteIndex
        :raises IndexError: If the site index is out of range.
        :return: The spatial coordinates as a NumPy array.
        :rtype: Coordinates
        """
        self._validate_index(index)
        return self._coordinates[index]

    def get_identifier(self, index: SiteIndex) -> SiteIdentifier:
        """Gets the abstract identifier of a site by its integer index.

        :param index: The integer index of the site.
        :type index: SiteIndex
        :raises IndexError: If the site index is out of range.
        :return: The unique, hashable identifier of the site.
        :rtype: SiteIdentifier
        """
        self._validate_index(index)
        return self._identifiers[index]

    def get_index(self, identifier: SiteIdentifier) -> SiteIndex:
        """Gets the integer index of a site by its unique identifier.

        :param identifier: The unique identifier of the site.
        :type identifier: SiteIdentifier
        :raises ValueError: If the identifier is not found in the lattice.
        :return: The corresponding integer index of the site.
        :rtype: SiteIndex
        """
        try:
            return self._ident_to_idx[identifier]
        except KeyError as e:
            raise ValueError(
                f"Identifier {identifier} not found in the lattice."
            ) from e

    def get_site_info(
        self, index_or_identifier: Union[SiteIndex, SiteIdentifier]
    ) -> Tuple[SiteIndex, SiteIdentifier, Coordinates]:
        """Gets all information for a single site.

        This method provides a convenient way to retrieve all relevant data for a
        site (its index, identifier, and coordinates) by using either its
        integer index or its unique identifier. Both Python ``int`` and NumPy
        integer scalars (e.g. ``np.int64``) are treated as indices.

        :param index_or_identifier: The integer
            index or the unique identifier of the site to look up.
        :type index_or_identifier: Union[SiteIndex, SiteIdentifier]
        :raises IndexError: If the given index is out of bounds.
        :raises ValueError: If the given identifier is not found in the lattice.
        :return: A tuple containing:
            - The site's integer index.
            - The site's unique identifier.
            - The site's coordinates as a NumPy array.
        :rtype: Tuple[SiteIndex, SiteIdentifier, Coordinates]
        """
        # NumPy integer scalars are not subclasses of `int`; without this they
        # would be looked up as identifiers and fail with ValueError.
        if isinstance(index_or_identifier, (int, np.integer)):
            idx = int(index_or_identifier)
            self._validate_index(idx)
            return idx, self._identifiers[idx], self._coordinates[idx]
        ident = index_or_identifier
        idx = self.get_index(ident)
        return idx, ident, self._coordinates[idx]

    def sites(self) -> Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]]:
        """Returns an iterator over all sites in the lattice.

        This provides a convenient way to loop through all sites, for example:
        `for idx, ident, coords in my_lattice.sites(): ...`

        :return: An iterator where each item is a tuple containing the site's
            index, identifier, and coordinates.
        :rtype: Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]]
        """
        for i in range(self.num_sites):
            yield i, self._identifiers[i], self._coordinates[i]

    def get_neighbors(self, index: SiteIndex, k: int = 1) -> List[SiteIndex]:
        """Gets the list of k-th nearest neighbor indices for a given site.

        If neighbors of order ``k`` are not cached yet, they are computed on
        demand by calling ``_build_neighbors(max_k=k)``.

        :param index: The integer index of the center site.
        :type index: SiteIndex
        :param k: The order of the neighbors, where k=1 corresponds
            to nearest neighbors (NN), k=2 to next-nearest neighbors (NNN),
            and so on. Defaults to 1.
        :type k: int, optional
        :return: A new list of integer indices for the neighboring sites.
            Returns an empty list if ``k < 1``, if the lattice has fewer than
            ``k`` distinct neighbor shells, or if the site has no such neighbors.
        :rtype: List[SiteIndex]
        """
        # Orders below 1 have no neighbors. Returning early also prevents a
        # rebuild with max_k < 1, which would overwrite the cached maps.
        if k < 1:
            return []

        if k not in self._neighbor_maps:
            logger.info(
                f"Neighbors for k={k} not pre-computed. Building now up to max_k={k}."
            )
            self._build_neighbors(max_k=k)

        if k not in self._neighbor_maps:
            return []

        # Fresh list of plain ints: callers cannot mutate the cache, and
        # subclasses that store NumPy integers still yield `int` indices.
        return [int(j) for j in self._neighbor_maps[k].get(index, [])]

    def get_neighbor_pairs(
        self, k: int = 1, unique: bool = True
    ) -> List[Tuple[SiteIndex, SiteIndex]]:
        """Gets all pairs of k-th nearest neighbors, representing bonds.

        :param k: The order of the neighbors to consider.
            Defaults to 1.
        :type k: int, optional
        :param unique: If True, returns only one representation
            for each pair (i, j) such that i < j, avoiding duplicates
            like (j, i). If False, returns all directed pairs.
            Defaults to True.
        :type unique: bool, optional
        :return: A sorted list of tuples, where each tuple is a pair of
            neighbor indices. Empty if ``k < 1`` or no such shell exists.
        :rtype: List[Tuple[SiteIndex, SiteIndex]]
        """
        if k < 1:
            return []

        if k not in self._neighbor_maps:
            logger.info(
                f"Neighbor pairs for k={k} not pre-computed. Building now up to max_k={k}."
            )
            self._build_neighbors(max_k=k)

        # After attempting to build, check again. If still not found, return empty.
        if k not in self._neighbor_maps:
            return []

        pairs = [
            (int(i), int(j))
            for i, neighbors in self._neighbor_maps[k].items()
            for j in neighbors
            if not unique or i < j
        ]
        # Sorting provides a deterministic output order.
        return sorted(pairs)

    # --- Abstract Methods for Subclass Implementation ---

    @abc.abstractmethod
    def _build_lattice(self, *args: Any, **kwargs: Any) -> None:
        """
        Abstract method for subclasses to generate the lattice data.

        A concrete implementation of this method in a subclass is responsible
        for populating the following internal attributes:
        - self._indices
        - self._identifiers
        - self._coordinates
        - self._ident_to_idx
        """
        pass

    @abc.abstractmethod
    def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None:
        """
        Abstract method for subclasses to calculate neighbor relationships.

        A concrete implementation of this method should calculate the neighbor
        relationships up to `max_k` and populate the `self._neighbor_maps`
        dictionary. The keys of the dictionary should be the neighbor order (k),
        and the values should be a dictionary mapping site indices to their
        list of k-th neighbors.
        """
        pass

    @abc.abstractmethod
    def _compute_distance_matrix(self) -> Coordinates:
        """
        Abstract method for subclasses to implement the actual matrix calculation.
        This method is called by the `distance_matrix` property when the matrix
        needs to be computed for the first time.
        """
        pass

    def show(
        self,
        show_indices: bool = False,
        show_identifiers: bool = False,
        show_bonds_k: Optional[int] = None,
        ax: Optional["matplotlib.axes.Axes"] = None,
        bond_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Visualizes the lattice structure using Matplotlib.

        This method supports 1D, 2D, and 3D plotting. For 1D lattices, sites
        are plotted along the x-axis.

        :param show_indices: If True, displays the integer index
            next to each site. Defaults to False.
        :type show_indices: bool, optional
        :param show_identifiers: If True, displays the unique
            identifier next to each site (takes precedence over
            ``show_indices``). Defaults to False.
        :type show_identifiers: bool, optional
        :param show_bonds_k: Specifies which order of
            neighbor bonds to draw (e.g., 1 for NN, 2 for NNN). If None,
            no bonds are drawn. If the specified neighbors have not been
            calculated, a warning is logged. Defaults to None.
        :type show_bonds_k: Optional[int], optional
        :param ax: An existing Matplotlib Axes object to plot on.
            If None, a new Figure and Axes are created automatically and
            ``plt.show()`` is called at the end. Defaults to None.
        :type ax: Optional["matplotlib.axes.Axes"], optional
        :param bond_kwargs: A dictionary of keyword arguments for customizing bond appearance,
            passed directly to the Matplotlib plot function. Defaults to None.
        :type bond_kwargs: Optional[Dict[str, Any]], optional
        :param kwargs: Additional keyword arguments to be passed directly to the
            `matplotlib.pyplot.scatter` function for customizing site appearance.
            The default marker size ``s`` is 100 (50 in 3D); an explicit ``s``
            is used unchanged.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            logger.error(
                "Matplotlib is required for visualization. "
                "Please install it using 'pip install matplotlib'."
            )
            return

        if self.num_sites == 0:
            logger.info("Lattice is empty, nothing to show.")
            return
        if self.dimensionality not in [1, 2, 3]:
            logger.warning(
                f"show() is not implemented for {self.dimensionality}D lattices."
            )
            return

        # Only render (plt.show) figures created here; a caller-supplied Axes
        # belongs to the caller, who decides when to display it.
        fig_created_internally = ax is None
        if ax is None:
            if self.dimensionality == 3:
                fig = plt.figure(figsize=(8, 8))
                ax = fig.add_subplot(111, projection="3d")
            else:
                _, ax = plt.subplots(figsize=(8, 8))

        coords = np.array(self._coordinates)

        # Smaller default markers in 3D; user-supplied `s` is never rescaled.
        scatter_args: Dict[str, Any] = {
            "s": 50 if self.dimensionality == 3 else 100,
            "zorder": 2,
        }
        scatter_args.update(kwargs)
        if self.dimensionality == 1:
            ax.scatter(coords[:, 0], np.zeros_like(coords[:, 0]), **scatter_args)  # type: ignore
        elif self.dimensionality == 2:
            ax.scatter(coords[:, 0], coords[:, 1], **scatter_args)  # type: ignore
        else:
            ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], **scatter_args)  # type: ignore

        if show_indices or show_identifiers:
            # Label offset: 2% of the largest extent, or 0.1 when all sites
            # coincide along every axis (e.g. a single-site lattice).
            extent = float(np.max(np.ptp(coords, axis=0)))
            offset = 0.02 * extent if extent > 0 else 0.1
            for i in range(self.num_sites):
                label = str(self._identifiers[i]) if show_identifiers else str(i)
                if self.dimensionality == 1:
                    ax.text(coords[i, 0], offset, label, fontsize=9, ha="center")
                elif self.dimensionality == 2:
                    ax.text(
                        coords[i, 0] + offset,
                        coords[i, 1] + offset,
                        label,
                        fontsize=9,
                        zorder=3,
                    )
                else:
                    ax_3d = cast("Axes3D", ax)
                    ax_3d.text(
                        coords[i, 0],
                        coords[i, 1],
                        coords[i, 2] + offset,
                        label,
                        fontsize=9,
                        zorder=3,
                    )

        if show_bonds_k is not None:
            if show_bonds_k not in self._neighbor_maps:
                logger.warning(
                    f"Cannot draw bonds. k={show_bonds_k} neighbors have not been calculated."
                )
            else:
                try:
                    bonds = self.get_neighbor_pairs(k=show_bonds_k, unique=True)
                    plot_bond_kwargs: Dict[str, Any] = {
                        "color": "k",
                        "linestyle": "-",
                        "alpha": 0.6,
                        "zorder": 1,
                    }
                    if bond_kwargs:
                        plot_bond_kwargs.update(bond_kwargs)

                    for i, j in bonds:
                        p1, p2 = self._coordinates[i], self._coordinates[j]
                        if self.dimensionality == 3:
                            ax_3d = cast("Axes3D", ax)
                            ax_3d.plot(
                                [p1[0], p2[0]],
                                [p1[1], p2[1]],
                                [p1[2], p2[2]],
                                **plot_bond_kwargs,
                            )
                        elif self.dimensionality == 2:
                            ax.plot([p1[0], p2[0]], [p1[1], p2[1]], **plot_bond_kwargs)  # type: ignore
                        else:
                            # 1D sites are drawn on the y=0 line.
                            ax.plot([p1[0], p2[0]], [0, 0], **plot_bond_kwargs)  # type: ignore
                except ValueError as e:
                    logger.info(f"Could not draw bonds: {e}")

        ax.set_title(f"{self.__class__.__name__} ({self.num_sites} sites)")
        if self.dimensionality == 2:
            ax.set_aspect("equal", adjustable="box")
        ax.set_xlabel("x")
        if self.dimensionality > 1:
            ax.set_ylabel("y")
        if self.dimensionality > 2 and hasattr(ax, "set_zlabel"):
            ax.set_zlabel("z")
        ax.grid(True)

        if fig_created_internally:
            plt.show()

    def _identify_distance_shells(
        self,
        all_distances_sq: Union[Coordinates, List[float]],
        max_k: int,
        tol: float = 1e-6,
    ) -> List[float]:
        """Identifies unique distance shells from a list of squared distances.

        Squared distances at or below 1e-12 (self-distances) are ignored. The
        remaining values are sorted; the smallest one starts the first shell,
        and each later shell starts at the first value that exceeds the
        previous shell's representative by more than ``tol**2``. At most
        ``max_k`` shells are returned.

        :param all_distances_sq: A list or array
            of all squared distances between pairs of sites.
        :type all_distances_sq: Union[np.ndarray, List[float]]
        :param max_k: The maximum number of neighbor shells to identify.
            Values below 1 yield an empty list.
        :type max_k: int
        :param tol: The numerical tolerance; it is applied to squared
            distances as ``tol**2``.
        :type tol: float
        :return: A sorted list of squared distances representing the shells.
        :rtype: List[float]
        """
        ZERO_THRESHOLD_SQ = 1e-12

        if max_k < 1:
            return []

        distances = np.asarray(all_distances_sq).ravel()
        if distances.size == 0:
            return []

        sorted_dist = np.sort(distances[distances > ZERO_THRESHOLD_SQ])

        shells: List[float] = []
        start = 0
        while start < sorted_dist.size and len(shells) < max_k:
            representative = sorted_dist[start]
            shells.append(float(representative))
            # Jump to the first value strictly greater than rep + tol**2
            # instead of scanning every member of the current shell.
            start = int(
                np.searchsorted(sorted_dist, representative + tol**2, side="right")
            )
        return shells

    def _build_neighbors_by_distance_matrix(
        self, max_k: int = 2, tol: float = 1e-6
    ) -> None:
        """A generic, distance-based neighbor finding method.

        This method calculates the full N x N (non-periodic) distance matrix to
        find neighbor shells. It is computationally expensive for large N
        (O(N^2) memory and time) and is best suited for non-periodic or
        custom-defined lattices. On success it replaces ``self._neighbor_maps``
        and caches the distance matrix.

        :param max_k: The maximum number of neighbor shells to
            calculate. Values below 1 leave existing maps untouched.
            Defaults to 2.
        :type max_k: int, optional
        :param tol: The numerical tolerance for distance
            comparisons (applied as ``tol**2`` to squared distances).
            Defaults to 1e-6.
        :type tol: float, optional
        """
        if self.num_sites < 2 or max_k < 1:
            return

        all_coords = np.array(self._coordinates)
        diffs = all_coords[:, np.newaxis, :] - all_coords[np.newaxis, :, :]
        dist_matrix_sq = np.sum(diffs**2, axis=-1)

        dist_shells_sq = self._identify_distance_shells(
            dist_matrix_sq.flatten(), max_k, tol
        )

        neighbor_maps: Dict[int, NeighborMap] = {}
        for k, target_d_sq in enumerate(dist_shells_sq, start=1):
            current_k_map: Dict[int, List[int]] = {}
            for i in range(self.num_sites):
                mask = np.isclose(
                    dist_matrix_sq[i], target_d_sq, rtol=0, atol=tol**2
                )
                mask[i] = False  # a site is never its own neighbor
                neighbor_indices = np.flatnonzero(mask)  # ascending order
                if neighbor_indices.size > 0:
                    current_k_map[i] = neighbor_indices.tolist()
            neighbor_maps[k] = current_k_map

        self._neighbor_maps = neighbor_maps
        self._distance_matrix = np.sqrt(dist_matrix_sq)
+
541
+
542
class TILattice(AbstractLattice):
    """Describes a periodic lattice with translational invariance.

    This class serves as a base for any lattice defined by a repeating unit
    cell. The geometry is specified by lattice vectors, the coordinates of
    basis sites within a unit cell, and the total size of the lattice in
    terms of unit cells.

    The site identifier for this class is a tuple in the format of
    `(uc_coord_1, ..., uc_coord_d, basis_index)`, where `uc_coord` represents
    the integer coordinate of the unit cell and `basis_index` is the index
    of the site within that unit cell's basis.

    :param dimensionality: The spatial dimension of the lattice.
    :type dimensionality: int
    :param lattice_vectors: The lattice vectors defining the unit
        cell, given as row vectors. Shape: (dimensionality, dimensionality).
        For example, in 2D: `np.array([[ax, ay], [bx, by]])`. Any array-like
        (e.g. nested lists) is accepted.
    :type lattice_vectors: np.ndarray
    :param basis_coords: The Cartesian coordinates of the basis sites
        within the unit cell. Shape: (num_basis_sites, dimensionality).
        For a simple Bravais lattice, this would be `np.array([[0, 0]])`.
        Any array-like is accepted.
    :type basis_coords: np.ndarray
    :param size: A sequence specifying the number of unit cells
        to generate in each lattice vector direction (e.g., (Nx, Ny)).
        It is stored as a tuple.
    :type size: Tuple[int, ...]
    :param pbc: Specifies whether
        periodic boundary conditions are applied. Can be a single boolean
        for all dimensions or a tuple of booleans for each dimension
        individually. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, ...]], optional
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        dimensionality: int,
        lattice_vectors: Coordinates,
        basis_coords: Coordinates,
        size: Tuple[int, ...],
        pbc: Union[bool, Tuple[bool, ...]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the Translationally Invariant Lattice."""
        super().__init__(dimensionality)
        # Accept array-likes (nested lists/tuples) for the geometry; the shape
        # checks below need real ndarrays.
        lattice_vectors = np.asarray(lattice_vectors)
        basis_coords = np.asarray(basis_coords)
        # np.ndindex only interprets a single *tuple* argument as a shape, so
        # normalize lists and other sequences here.
        size = tuple(size)

        assert lattice_vectors.shape == (
            dimensionality,
            dimensionality,
        ), "Lattice vectors shape mismatch"
        assert (
            basis_coords.ndim == 2 and basis_coords.shape[1] == dimensionality
        ), "Basis coordinates dimension mismatch"
        assert len(size) == dimensionality, "Size tuple length mismatch"

        self.lattice_vectors = lattice_vectors
        self.basis_coords = basis_coords
        self.num_basis = basis_coords.shape[0]
        self.size = size
        # np.bool_ is not a subclass of bool, so check for both.
        if isinstance(pbc, (bool, np.bool_)):
            self.pbc = tuple([bool(pbc)] * dimensionality)
        else:
            assert len(pbc) == dimensionality, "PBC tuple length mismatch"
            self.pbc = tuple(pbc)

        # Build the lattice sites and (optionally) their neighbor relationships.
        self._build_lattice()
        if precompute_neighbors is not None and precompute_neighbors > 0:
            logger.info(f"Pre-computing neighbors up to k={precompute_neighbors}...")
            self._build_neighbors(max_k=precompute_neighbors)

    def _build_lattice(self) -> None:
        """Generates all site information for the periodic lattice.

        This method iterates through each unit cell defined by `self.size`
        (in `np.ndindex` order), and for each unit cell, it iterates through
        all basis sites. It then calculates the real-space coordinates and
        creates a unique identifier for each site, populating the internal
        lattice data structures.
        """
        current_index = 0

        for cell_coord in np.ndindex(self.size):
            # R = n1*a1 + n2*a2 + ... (lattice vectors are rows).
            cell_vector = np.dot(np.array(cell_coord), self.lattice_vectors)

            for basis_index in range(self.num_basis):
                coord = cell_vector + self.basis_coords[basis_index]
                identifier = cell_coord + (basis_index,)

                self._indices.append(current_index)
                self._identifiers.append(identifier)
                self._coordinates.append(coord)
                self._ident_to_idx[identifier] = current_index
                current_index += 1

    def _get_distance_matrix_with_mic(self) -> Coordinates:
        """
        Computes the full N x N distance matrix, applying the
        Minimum Image Convention (MIC) for all periodic dimensions.

        For each periodic dimension, images shifted by -1, 0 and +1 times the
        full system vector (lattice vector times size) are considered, and the
        smallest distance over all images is kept.
        """
        all_coords = np.array(self._coordinates)
        size_arr = np.array(self.size)
        system_vectors = self.lattice_vectors * size_arr[:, np.newaxis]

        # Generate translation vectors ONLY for periodic dimensions.
        pbc_dims = [d for d in range(self.dimensionality) if self.pbc[d]]
        translations = [np.zeros(self.dimensionality)]
        if pbc_dims:
            num_pbc_dims = len(pbc_dims)
            pbc_system_vectors = system_vectors[pbc_dims, :]

            # All 3^p - 1 non-zero shifts for p periodic dimensions.
            shift_options = [np.array([-1, 0, 1])] * num_pbc_dims
            shifts_grid = np.meshgrid(*shift_options, indexing="ij")
            all_shifts = np.stack(shifts_grid, axis=-1).reshape(-1, num_pbc_dims)
            all_shifts = all_shifts[np.any(all_shifts != 0, axis=1)]

            translations.extend(all_shifts @ pbc_system_vectors)

        translations_arr = np.array(translations, dtype=float)

        # Row by row to bound memory at O(N * num_translations * d).
        dist_matrix_sq = np.full((self.num_sites, self.num_sites), np.inf, dtype=float)
        for i in range(self.num_sites):
            displacements = all_coords - all_coords[i]
            image_displacements = (
                displacements[:, np.newaxis, :] - translations_arr[np.newaxis, :, :]
            )
            image_d_sq = np.sum(image_displacements**2, axis=2)
            dist_matrix_sq[i, :] = np.min(image_d_sq, axis=1)

        return cast(Coordinates, np.sqrt(dist_matrix_sq))

    def _build_neighbors(self, max_k: int = 2, **kwargs: Any) -> None:
        """Calculates neighbor relationships for the periodic lattice.

        This method calculates neighbor relationships by computing the full N x N
        distance matrix. It handles all boundary conditions (fully periodic,
        open, or mixed) by applying the Minimum Image Convention (MIC) only to
        the periodic dimensions.

        From this distance matrix, it identifies unique neighbor shells up to
        the specified `max_k` and replaces the neighbor maps with them. Map
        keys and neighbor entries are plain Python ints. The computed distance
        matrix is cached for future use.

        :param max_k: The maximum number of neighbor shells to
            calculate. Values below 1 are a no-op. Defaults to 2.
        :type max_k: int, optional
        :param tol: The numerical tolerance for distance
            comparisons (applied as ``tol**2`` to squared distances).
            Defaults to 1e-6.
        :type tol: float, optional
        """
        tol = kwargs.get("tol", 1e-6)
        if max_k < 1:
            return

        dist_matrix = self._get_distance_matrix_with_mic()
        self._distance_matrix = dist_matrix
        dist_matrix_sq = dist_matrix**2
        dist_shells_sq = self._identify_distance_shells(
            dist_matrix_sq.flatten(), max_k, tol
        )

        neighbor_maps: Dict[int, NeighborMap] = {}
        for k, target_d_sq in enumerate(dist_shells_sq, start=1):
            rows, cols = np.nonzero(
                np.isclose(dist_matrix_sq, target_d_sq, rtol=0, atol=tol**2)
            )
            current_k_map: Dict[int, List[int]] = {}
            # .tolist() yields Python ints; iterating the arrays directly would
            # leak np.int64 values, which `get_site_info` treats as identifiers.
            for i, j in zip(rows.tolist(), cols.tolist()):
                if i != j:
                    current_k_map.setdefault(i, []).append(j)
            for neighbors in current_k_map.values():
                neighbors.sort()
            neighbor_maps[k] = current_k_map

        self._neighbor_maps = neighbor_maps

    def _compute_distance_matrix(self) -> Coordinates:
        """Computes the distance matrix using the Minimum Image Convention."""
        return self._get_distance_matrix_with_mic()
+
735
+
736
class SquareLattice(TILattice):
    """A two-dimensional square lattice.

    A Bravais lattice with one site per unit cell and orthogonal primitive
    vectors of equal length, built on top of :class:`TILattice`.

    :param size: A tuple (Nx, Ny) giving the number of unit cells (i.e. sites)
        along x and y.
    :type size: Tuple[int, int]
    :param lattice_constant: Spacing between adjacent sites. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Periodic boundary conditions, either one boolean for both
        directions or one boolean per direction. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: When given, neighbor shells up to this order
        `k` are computed during construction. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: float = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the SquareLattice."""
        a = lattice_constant
        # Rows are the primitive vectors a1 = (a, 0) and a2 = (0, a); the
        # single basis site sits at the cell origin.
        super().__init__(
            dimensionality=2,
            lattice_vectors=np.array([[a, 0.0], [0.0, a]]),
            basis_coords=np.array([[0.0, 0.0]]),
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
783
+
784
+
785
class HoneycombLattice(TILattice):
    """A two-dimensional honeycomb lattice.

    A composite lattice: a triangular Bravais lattice decorated with a
    two-site basis (sublattices A and B).

    :param size: A tuple (Nx, Ny) giving the number of unit cells along the
        two primitive lattice vectors.
    :type size: Tuple[int, int]
    :param lattice_constant: The bond length, i.e. the distance between two
        nearest-neighbor sites. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: When given, neighbor shells up to this order
        `k` are computed during construction. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: float = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the HoneycombLattice."""
        a = lattice_constant
        half_sqrt3 = np.sqrt(3) / 2
        # Primitive vectors of the underlying triangular lattice (as rows).
        primitive = a * np.array([[1.5, half_sqrt3], [1.5, -half_sqrt3]])
        # Sublattice A at the cell origin, sublattice B one bond length along x.
        sublattices = a * np.array([[0.0, 0.0], [1.0, 0.0]])

        super().__init__(
            dimensionality=2,
            lattice_vectors=primitive,
            basis_coords=sublattices,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
832
+
833
+
834
class TriangularLattice(TILattice):
    """Two-dimensional triangular lattice.

    A Bravais lattice with a single-site basis. Each bulk site has six
    nearest neighbors.

    :param size: Number of unit cells ``(Nx, Ny)`` along the two primitive
        lattice vectors.
    :type size: Tuple[int, int]
    :param lattice_constant: Nearest-neighbor bond length. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Periodic boundary conditions, given either as a single flag
        or as one flag per direction. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: When given, neighbor relationships up to this
        order ``k`` are computed during construction. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: float = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Set up the triangular geometry and delegate to TILattice."""
        spacing = lattice_constant

        # Two primitive vectors of length `spacing`, 60 degrees apart.
        along_x = [1.0, 0.0]
        tilted = [0.5, np.sqrt(3) / 2]

        super().__init__(
            dimensionality=2,
            lattice_vectors=spacing * np.array([along_x, tilted]),
            # Bravais lattice: one site per unit cell, at the cell origin.
            basis_coords=np.zeros((1, 2)),
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
879
+
880
+
881
class ChainLattice(TILattice):
    """A 1D chain (simple Bravais lattice).

    :param size: A tuple `(N,)` specifying the number of sites in the chain
        (one site per unit cell).
    :type size: Tuple[int]
    :param lattice_constant: The distance between two adjacent sites. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Specifies if periodic boundary conditions are applied. Defaults to True.
    :type pbc: bool, optional
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int],
        lattice_constant: float = 1.0,
        pbc: bool = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the ChainLattice."""
        dimensionality = 1
        # dtype=float keeps the lattice vector floating point even when an
        # integer lattice_constant (e.g. ``1``) is passed.
        lattice_vectors = np.array([[lattice_constant]], dtype=float)
        # Single site per unit cell, at the cell origin.
        basis_coords = np.array([[0.0]])
        super().__init__(
            dimensionality=dimensionality,
            lattice_vectors=lattice_vectors,
            basis_coords=basis_coords,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
913
+
914
+
915
class DimerizedChainLattice(TILattice):
    """A 1D chain with an AB sublattice (dimerized chain).

    The unit cell contains two sites, A and B, separated by
    `lattice_constant`. The bond length is uniform, so the two sublattices
    differ only in their labeling, not in their geometry.

    :param size: A tuple `(N,)` specifying the number of **unit cells**.
        The total number of sites in the chain will be `2 * N`, as each
        unit cell contains two sites.
    :type size: Tuple[int]
    :param lattice_constant: The distance between two adjacent sites (bond length). Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Specifies if periodic boundary conditions are applied. Defaults to True.
    :type pbc: bool, optional
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int],
        lattice_constant: float = 1.0,
        pbc: bool = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the DimerizedChainLattice."""
        dimensionality = 1
        # The unit cell vector connects two A sites, spanning length 2*a.
        # dtype=float keeps it floating point for an integer lattice_constant.
        lattice_vectors = np.array([[2 * lattice_constant]], dtype=float)
        # Basis has site A at the origin and site B at distance `a`.
        basis_coords = np.array([[0.0], [lattice_constant]], dtype=float)

        super().__init__(
            dimensionality=dimensionality,
            lattice_vectors=lattice_vectors,
            basis_coords=basis_coords,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
954
+
955
+
956
class RectangularLattice(TILattice):
    """Two-dimensional rectangular lattice.

    Generalizes SquareLattice by allowing different spacings along x and y.
    The basis holds a single site per unit cell.

    :param size: Number of sites ``(Nx, Ny)`` along x and y.
    :type size: Tuple[int, int]
    :param lattice_constants: Spacings ``(ax, ay)`` between adjacent sites
        along x and y. Defaults to (1.0, 1.0).
    :type lattice_constants: Tuple[float, float], optional
    :param pbc: Periodic boundary conditions, given either as a single flag
        or as one flag per direction. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: When given, neighbor relationships up to this
        order ``k`` are computed during construction. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constants: Tuple[float, float] = (1.0, 1.0),
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Set up the rectangular geometry and delegate to TILattice."""
        spacing_x, spacing_y = lattice_constants

        # Orthogonal primitive vectors along x and y.
        axis_vectors = np.array([[spacing_x, 0.0], [0.0, spacing_y]])

        super().__init__(
            dimensionality=2,
            lattice_vectors=axis_vectors,
            basis_coords=np.zeros((1, 2)),
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
994
+
995
+
996
class CheckerboardLattice(TILattice):
    """Two-dimensional checkerboard lattice (square lattice with AB sublattices).

    The unit cell is a square of side ``sqrt(2) * lattice_constant`` rotated
    by 45 degrees. It holds two sites: A at the origin and B one bond along x.
    Together the two sublattices occupy the sites of a square lattice of
    spacing ``lattice_constant``.

    :param size: Number of unit cells ``(Nx, Ny)``; the lattice holds
        ``2 * Nx * Ny`` sites in total.
    :type size: Tuple[int, int]
    :param lattice_constant: Nearest-neighbor bond length. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Periodic boundary conditions, given either as a single flag
        or as one flag per direction. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: When given, neighbor relationships up to this
        order ``k`` are computed during construction. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: float = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Set up the checkerboard geometry and delegate to TILattice."""
        bond = lattice_constant

        # Primitive vectors (1, 1) and (1, -1): the 45-degree rotated square cell.
        rotated_cell = np.array([[1.0, 1.0], [1.0, -1.0]])
        # Sublattice A at the origin, sublattice B one bond along x.
        two_site_basis = np.array([[0.0, 0.0], [1.0, 0.0]])

        super().__init__(
            dimensionality=2,
            lattice_vectors=bond * rotated_cell,
            basis_coords=bond * two_site_basis,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
1033
+
1034
+
1035
class KagomeLattice(TILattice):
    """A 2D Kagome lattice.

    This is a lattice with a three-site basis on a triangular Bravais lattice.
    Each unit cell holds one site at the cell origin and one at the midpoint
    of each primitive vector. The result is a network of corner-sharing
    triangles with side ``lattice_constant``.

    :param size: A tuple (Nx, Ny) specifying the number of unit cells. Total sites will be 3*Nx*Ny.
    :type size: Tuple[int, int]
    :param lattice_constant: The nearest-neighbor bond length. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Specifies periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: float = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the KagomeLattice."""
        dimensionality = 2
        a = lattice_constant
        # Primitive vectors of the underlying triangular Bravais lattice:
        # a1 = (2a, 0) and a2 = (a, sqrt(3) a). Both have length 2a and are
        # 60 degrees apart, so the unit cell is a rhombus (not a rectangle).
        lattice_vectors = a * np.array([[2.0, 0.0], [1.0, np.sqrt(3)]])
        # Three-site basis: the origin plus the midpoints a1/2 and a2/2, which
        # makes the nearest-neighbor bond length equal to `a`.
        basis_coords = a * np.array([[0.0, 0.0], [1.0, 0.0], [0.5, np.sqrt(3) / 2.0]])
        super().__init__(
            dimensionality=dimensionality,
            lattice_vectors=lattice_vectors,
            basis_coords=basis_coords,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
1072
+
1073
+
1074
class LiebLattice(TILattice):
    """Two-dimensional Lieb lattice.

    A square Bravais lattice with a three-site basis: one site on each square
    corner plus one at the center of each edge.

    :param size: Number of unit cells ``(Nx, Ny)``; the lattice holds
        ``3 * Nx * Ny`` sites in total.
    :type size: Tuple[int, int]
    :param lattice_constant: Nearest-neighbor bond length (corner to edge
        center). Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Periodic boundary conditions, given either as a single flag
        or as one flag per direction. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: When given, neighbor relationships up to this
        order ``k`` are computed during construction. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: float = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Set up the Lieb geometry and delegate to TILattice."""
        # lattice_constant is the nearest-neighbor bond, so the square unit
        # cell spans two bonds per side.
        bond = lattice_constant
        side = 2 * bond
        cell_vectors = np.array([[side, 0.0], [0.0, side]])

        # Corner site, center of the horizontal edge, center of the vertical edge.
        sites_in_cell = np.array([[0.0, 0.0], [bond, 0.0], [0.0, bond]])

        super().__init__(
            dimensionality=2,
            lattice_vectors=cell_vectors,
            basis_coords=sites_in_cell,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
1128
+
1129
+
1130
class CubicLattice(TILattice):
    """A 3D simple cubic lattice.

    This is a simple Bravais lattice, the 3D generalization of SquareLattice.

    :param size: A tuple (Nx, Ny, Nz) specifying the number of sites
        (one site per unit cell).
    :type size: Tuple[int, int, int]
    :param lattice_constant: The distance between adjacent sites. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Specifies periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool, bool]], optional
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int, int],
        lattice_constant: float = 1.0,
        pbc: Union[bool, Tuple[bool, bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the CubicLattice."""
        dimensionality = 3
        a = lattice_constant
        # Orthogonal primitive vectors of length `a`. An explicit float dtype
        # prevents an integer-typed array when an integer lattice_constant
        # is passed (the previous literal used integer zeros).
        lattice_vectors = np.array(
            [[a, 0.0, 0.0], [0.0, a, 0.0], [0.0, 0.0, a]], dtype=float
        )
        # Single site per unit cell, at the cell origin.
        basis_coords = np.array([[0.0, 0.0, 0.0]])
        super().__init__(
            dimensionality=dimensionality,
            lattice_vectors=lattice_vectors,
            basis_coords=basis_coords,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
1165
+
1166
+
1167
class CustomizeLattice(AbstractLattice):
    """A general lattice built from an explicit list of sites and coordinates.

    This class is suitable for creating lattices with arbitrary geometries,
    such as finite clusters, disordered systems, or any custom structure
    that does not have translational symmetry. The lattice is defined simply
    by providing lists of identifiers and coordinates for each site.

    :param dimensionality: The spatial dimension of the lattice.
    :type dimensionality: int
    :param identifiers: A list of unique, hashable identifiers for the sites.
        The length must match `coordinates`.
    :type identifiers: List[SiteIdentifier]
    :param coordinates: A list of site coordinates. Each coordinate should be
        a list of floats or a NumPy array of shape ``(dimensionality,)``.
        Coordinates are copied on construction.
    :type coordinates: List[Union[List[float], Coordinates]]
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    :raises ValueError: If the lengths of `identifiers` and `coordinates` differ,
        if `identifiers` contains duplicates, or if a coordinate's shape is not
        ``(dimensionality,)``.
    """

    def __init__(
        self,
        dimensionality: int,
        identifiers: List[SiteIdentifier],
        coordinates: List[Union[List[float], Coordinates]],
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the CustomizeLattice."""
        super().__init__(dimensionality)
        if len(identifiers) != len(coordinates):
            raise ValueError(
                "Identifiers and coordinates lists must have the same length."
            )

        # Duplicates would silently collapse in `_ident_to_idx` (last one wins),
        # making an earlier site unreachable by its identifier.
        duplicates = self._duplicate_identifiers(identifiers)
        if duplicates:
            raise ValueError(f"Duplicate identifiers found: {duplicates}")

        # Validate (and copy) every coordinate before storing any state.
        coord_arrays = self._to_coordinate_arrays(
            coordinates, dimensionality, "Coordinate"
        )

        # The _build_lattice logic is simple enough to live in __init__.
        self._identifiers = list(identifiers)
        self._coordinates = coord_arrays
        self._rebuild_index_maps()

        logger.info(f"CustomizeLattice with {self.num_sites} sites created.")

        if precompute_neighbors is not None and precompute_neighbors > 0:
            self._build_neighbors(max_k=precompute_neighbors)

    @staticmethod
    def _duplicate_identifiers(identifiers: List[SiteIdentifier]) -> set:
        """Returns the set of identifiers that occur more than once in the input."""
        seen = set()
        duplicates = set()
        for ident in identifiers:
            if ident in seen:
                duplicates.add(ident)
            else:
                seen.add(ident)
        return duplicates

    @staticmethod
    def _to_coordinate_arrays(
        coordinates: List[Union[List[float], Coordinates]],
        dimensionality: int,
        label: str,
    ) -> List[Coordinates]:
        """Copies coordinates into NumPy arrays and checks each one's shape.

        :param coordinates: The raw coordinates to convert.
        :param dimensionality: The required length of each coordinate.
        :param label: Prefix used in the error message (e.g. "Coordinate").
        :return: A list of freshly allocated arrays, one per coordinate.
        :raises ValueError: If any coordinate's shape is not ``(dimensionality,)``.
        """
        arrays = [np.array(c) for c in coordinates]
        for i, arr in enumerate(arrays):
            if arr.shape != (dimensionality,):
                raise ValueError(
                    f"{label} at index {i} has shape {arr.shape}, "
                    f"expected ({dimensionality},)"
                )
        return arrays

    def _rebuild_index_maps(self) -> None:
        """Recomputes `_indices` and `_ident_to_idx` from `_identifiers`."""
        self._indices = list(range(len(self._identifiers)))
        self._ident_to_idx = {
            ident: idx for idx, ident in enumerate(self._identifiers)
        }

    def _build_lattice(self, *args: Any, **kwargs: Any) -> None:
        """For CustomizeLattice, lattice data is built during __init__."""
        pass

    def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None:
        """Calculates neighbor shells using KDTree radius queries.

        The distinct neighbor distances ("shells") are identified from the
        condensed squared pairwise distances returned by
        `scipy.spatial.distance.pdist` (N*(N-1)/2 entries, half the size of a
        full N x N matrix). A `scipy.spatial.KDTree` is then queried with a
        radius just above each shell distance. The k-th neighbors of a site
        are the sites inside radius k but not inside radius k-1. Sites at
        (numerically) zero distance, including the site itself, are never
        reported as neighbors.

        After the neighbors are identified, the full distance matrix is built
        from the same pairwise distances and cached in `_distance_matrix`.

        :param max_k: The maximum number of neighbor shells to
            calculate. Defaults to 1.
        :type max_k: int, optional
        :param tol: (keyword) The numerical tolerance for distance comparisons,
            also added to each query radius. Defaults to 1e-6.
        :type tol: float, optional
        """
        tol = kwargs.get("tol", 1e-6)
        logger.info(f"Building neighbors for CustomizeLattice up to k={max_k}...")
        if self.num_sites < 2:
            return

        all_coords = np.array(self._coordinates)

        # 1. Squared pairwise distances (condensed form) identify the shells.
        all_distances_sq = pdist(all_coords, metric="sqeuclidean")
        dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol)

        if not dist_shells_sq:
            logger.info("No distinct neighbor shells found.")
            return

        # 2. Build the KDTree for efficient radius queries.
        tree = KDTree(all_coords)
        self._neighbor_maps = {k: {} for k in range(1, len(dist_shells_sq) + 1)}

        # 3. Isolate each shell as "inside this radius minus inside the previous one".
        # Seed with the co-located sites (each site itself included) so they are
        # excluded from the first shell; a single batched query covers all sites.
        inner = [set(idx) for idx in tree.query_ball_point(all_coords, r=1e-12)]
        for k, target_d_sq in enumerate(dist_shells_sq, start=1):
            radius = np.sqrt(target_d_sq) + tol
            within_radius = tree.query_ball_point(
                all_coords, r=radius, return_sorted=True
            )

            current_k_map: Dict[int, List[int]] = {}
            for i, idx_list in enumerate(within_radius):
                new_neighbors = set(idx_list) - inner[i]
                if new_neighbors:
                    current_k_map[i] = sorted(new_neighbors)

            self._neighbor_maps[k] = current_k_map
            # Everything inside this radius is "already found" for shell k + 1.
            inner = [set(idx) for idx in within_radius]

        self._distance_matrix = np.sqrt(squareform(all_distances_sq))

        logger.info("Neighbor building complete using KDTree.")

    def _compute_distance_matrix(self) -> Coordinates:
        """Computes the distance matrix from the stored coordinates.

        This implementation uses scipy.pdist for a memory-efficient
        calculation of pairwise distances, which is then converted to a
        full square matrix. For fewer than two sites, an all-zero
        ``(num_sites, num_sites)`` matrix is returned (a single site is at
        distance 0 from itself).
        """
        if self.num_sites < 2:
            # np.zeros, not np.empty: for one site the lone entry must be 0,
            # not uninitialized memory.
            return cast(Coordinates, np.zeros((self.num_sites, self.num_sites)))

        all_coords = np.array(self._coordinates)
        # Use pdist for memory-efficiency, then build the full matrix.
        all_distances_sq = pdist(all_coords, metric="sqeuclidean")
        dist_matrix_sq = squareform(all_distances_sq)
        return cast(Coordinates, np.sqrt(dist_matrix_sq))

    def _reset_computations(self) -> None:
        """Resets all cached data that depends on the lattice structure."""
        self._neighbor_maps = {}
        self._distance_matrix = None

    @classmethod
    def from_lattice(cls, lattice: "AbstractLattice") -> "CustomizeLattice":
        """Creates a CustomizeLattice instance from any existing lattice object.

        This is useful for 'detaching' a procedurally generated lattice (like
        a SquareLattice) into a customizable one for further modifications,
        such as adding defects or extra sites.

        :param lattice: An instance of any AbstractLattice subclass.
        :type lattice: AbstractLattice
        :return: A new CustomizeLattice instance with the same sites.
        :rtype: CustomizeLattice
        """
        all_sites_info = list(lattice.sites())

        if not all_sites_info:
            return cls(
                dimensionality=lattice.dimensionality, identifiers=[], coordinates=[]
            )

        # sites() yields (index, identifier, coordinate) triples; unzip them.
        _, identifiers, coordinates = zip(*all_sites_info)

        return cls(
            dimensionality=lattice.dimensionality,
            identifiers=list(identifiers),
            coordinates=list(coordinates),
        )

    def add_sites(
        self,
        identifiers: List[SiteIdentifier],
        coordinates: List[Union[List[float], Coordinates]],
    ) -> None:
        """Adds new sites to the lattice.

        This operation modifies the lattice in-place. All inputs are validated
        before any state changes, so a failed call leaves the lattice
        untouched. After adding sites, any previously computed neighbor
        information is cleared and must be recalculated. Coordinates are
        copied.

        :param identifiers: A list of unique, hashable identifiers for the new sites.
        :type identifiers: List[SiteIdentifier]
        :param coordinates: A list of coordinates for the new sites.
        :type coordinates: List[Union[List[float], np.ndarray]]
        :raises ValueError: If input lists have mismatched lengths, if any new
            identifier already exists in the lattice or is repeated within
            `identifiers`, or if a coordinate's shape is not
            ``(dimensionality,)``.
        """
        if len(identifiers) != len(coordinates):
            raise ValueError(
                "Identifiers and coordinates lists must have the same length."
            )
        if not identifiers:
            return  # Nothing to add

        # Reject clashes with existing sites and repeats within the new batch
        # before making any changes.
        existing_ids = set(self._identifiers)
        new_ids = set(identifiers)
        clashing = new_ids.intersection(existing_ids) | self._duplicate_identifiers(
            identifiers
        )
        if clashing:
            raise ValueError(f"Duplicate identifiers found: {clashing}")

        # Validate every coordinate up front so a bad entry cannot leave the
        # identifier and coordinate lists half-extended.
        new_coords = self._to_coordinate_arrays(
            coordinates, self.dimensionality, "New coordinate"
        )

        self._coordinates.extend(new_coords)
        self._identifiers.extend(identifiers)

        # Rebuild index mappings and invalidate neighbor / distance caches.
        self._rebuild_index_maps()
        self._reset_computations()
        logger.info(
            f"{len(identifiers)} sites added. Lattice now has {self.num_sites} sites."
        )

    def remove_sites(self, identifiers: List[SiteIdentifier]) -> None:
        """Removes specified sites from the lattice.

        This operation modifies the lattice in-place. After removing sites,
        all site indices are re-calculated, and any previously computed
        neighbor information is cleared.

        :param identifiers: A list of identifiers for the sites to be removed.
        :type identifiers: List[SiteIdentifier]
        :raises ValueError: If any of the specified identifiers do not exist.
        """
        if not identifiers:
            return  # Nothing to remove

        ids_to_remove = set(identifiers)
        current_ids = set(self._identifiers)
        if not ids_to_remove.issubset(current_ids):
            raise ValueError(
                f"Non-existent identifiers provided for removal: {ids_to_remove - current_ids}"
            )

        # Keep the (identifier, coordinate) pairs that are not being removed,
        # preserving their original order.
        kept = [
            (ident, coord)
            for ident, coord in zip(self._identifiers, self._coordinates)
            if ident not in ids_to_remove
        ]
        self._identifiers = [ident for ident, _ in kept]
        self._coordinates = [coord for _, coord in kept]

        # Rebuild index mappings and invalidate caches.
        self._rebuild_index_maps()
        self._reset_computations()
        logger.info(
            f"{len(ids_to_remove)} sites removed. Lattice now has {self.num_sites} sites."
        )