tensorcircuit-nightly 1.0.2.dev20250108__py3-none-any.whl → 1.4.0.dev20251103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tensorcircuit-nightly might be problematic. Click here for more details.

Files changed (76)
  1. tensorcircuit/__init__.py +18 -2
  2. tensorcircuit/about.py +46 -0
  3. tensorcircuit/abstractcircuit.py +4 -0
  4. tensorcircuit/analogcircuit.py +413 -0
  5. tensorcircuit/applications/layers.py +1 -1
  6. tensorcircuit/applications/van.py +1 -1
  7. tensorcircuit/backends/abstract_backend.py +320 -7
  8. tensorcircuit/backends/cupy_backend.py +3 -1
  9. tensorcircuit/backends/jax_backend.py +102 -4
  10. tensorcircuit/backends/jax_ops.py +110 -1
  11. tensorcircuit/backends/numpy_backend.py +49 -3
  12. tensorcircuit/backends/pytorch_backend.py +92 -3
  13. tensorcircuit/backends/tensorflow_backend.py +102 -3
  14. tensorcircuit/basecircuit.py +157 -98
  15. tensorcircuit/circuit.py +115 -57
  16. tensorcircuit/cloud/local.py +1 -1
  17. tensorcircuit/cloud/quafu_provider.py +1 -1
  18. tensorcircuit/cloud/tencent.py +1 -1
  19. tensorcircuit/compiler/simple_compiler.py +2 -2
  20. tensorcircuit/cons.py +142 -21
  21. tensorcircuit/densitymatrix.py +43 -14
  22. tensorcircuit/experimental.py +387 -129
  23. tensorcircuit/fgs.py +282 -81
  24. tensorcircuit/gates.py +66 -22
  25. tensorcircuit/interfaces/__init__.py +1 -3
  26. tensorcircuit/interfaces/jax.py +189 -0
  27. tensorcircuit/keras.py +3 -3
  28. tensorcircuit/mpscircuit.py +154 -65
  29. tensorcircuit/quantum.py +868 -152
  30. tensorcircuit/quditcircuit.py +733 -0
  31. tensorcircuit/quditgates.py +618 -0
  32. tensorcircuit/results/counts.py +147 -20
  33. tensorcircuit/results/readout_mitigation.py +4 -1
  34. tensorcircuit/shadows.py +1 -1
  35. tensorcircuit/simplify.py +3 -1
  36. tensorcircuit/stabilizercircuit.py +479 -0
  37. tensorcircuit/templates/__init__.py +2 -0
  38. tensorcircuit/templates/blocks.py +2 -2
  39. tensorcircuit/templates/hamiltonians.py +174 -0
  40. tensorcircuit/templates/lattice.py +1789 -0
  41. tensorcircuit/timeevol.py +896 -0
  42. tensorcircuit/translation.py +10 -3
  43. tensorcircuit/utils.py +7 -0
  44. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +73 -23
  45. tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
  46. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +1 -1
  47. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
  48. tensorcircuit_nightly-1.0.2.dev20250108.dist-info/RECORD +0 -115
  49. tests/__init__.py +0 -0
  50. tests/conftest.py +0 -67
  51. tests/test_backends.py +0 -1031
  52. tests/test_calibrating.py +0 -149
  53. tests/test_channels.py +0 -365
  54. tests/test_circuit.py +0 -1699
  55. tests/test_cloud.py +0 -219
  56. tests/test_compiler.py +0 -147
  57. tests/test_dmcircuit.py +0 -555
  58. tests/test_ensemble.py +0 -72
  59. tests/test_fgs.py +0 -310
  60. tests/test_gates.py +0 -156
  61. tests/test_interfaces.py +0 -429
  62. tests/test_keras.py +0 -160
  63. tests/test_miscs.py +0 -277
  64. tests/test_mpscircuit.py +0 -341
  65. tests/test_noisemodel.py +0 -156
  66. tests/test_qaoa.py +0 -86
  67. tests/test_qem.py +0 -152
  68. tests/test_quantum.py +0 -526
  69. tests/test_quantum_attr.py +0 -42
  70. tests/test_results.py +0 -347
  71. tests/test_shadows.py +0 -160
  72. tests/test_simplify.py +0 -46
  73. tests/test_templates.py +0 -218
  74. tests/test_torchnn.py +0 -99
  75. tests/test_van.py +0 -102
  76. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info/licenses}/LICENSE +0 -0
@@ -0,0 +1,1789 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ The lattice module for defining and manipulating lattice geometries.
4
+ """
5
+ import logging
6
+ import abc
7
+ from typing import (
8
+ Any,
9
+ Dict,
10
+ Hashable,
11
+ Iterator,
12
+ List,
13
+ Optional,
14
+ Tuple,
15
+ Union,
16
+ TYPE_CHECKING,
17
+ cast,
18
+ Set,
19
+ )
20
+
21
+ import itertools
22
+ import math
23
+ import numpy as np
24
+ from scipy.spatial import KDTree
25
+
26
+ from .. import backend
27
+
28
+
29
# ``matplotlib`` is an optional dependency: it is imported lazily inside
# ``AbstractLattice.show()`` rather than at module level. Static type checkers
# (e.g. mypy) still need to resolve the string annotations that mention it,
# such as ``"matplotlib.axes.Axes"`` and ``"Axes3D"``. ``TYPE_CHECKING`` is
# False at runtime, so the imports below are seen only by the type checker and
# never create a hard runtime dependency.
if TYPE_CHECKING:
    import matplotlib.axes
    from mpl_toolkits.mplot3d import Axes3D

logger = logging.getLogger(__name__)

# --- Type aliases used throughout this module ---
# Backend-agnostic tensor type (left as ``Any`` on purpose).
Tensor = Any
# Integer position of a site in the lattice, 0 .. N-1.
SiteIndex = int
# Hashable, user-facing label of a site.
SiteIdentifier = Hashable
# Coordinate tensor of one site or of many sites.
Coordinates = Tensor

# For one neighbor order k: site index -> indices of its k-th neighbors.
NeighborMap = Dict[SiteIndex, List[SiteIndex]]
53
+
54
+
55
class AbstractLattice(abc.ABC):
    """Abstract base class for describing lattice systems.

    This class defines the common interface for all lattice structures,
    providing access to fundamental properties like site information
    (count, coordinates, identifiers) and neighbor relationships.
    Subclasses are responsible for implementing the specific logic for
    generating the lattice points and calculating neighbor connections.

    :param dimensionality: The spatial dimension of the lattice (e.g., 1, 2, 3).
    :type dimensionality: int
    """

    def __init__(self, dimensionality: int):
        """Initializes the base lattice class."""
        self._dimensionality = dimensionality

        # Core data structures for storing site information.
        # Integer indices [0, 1, ..., N-1].
        self._indices: List[SiteIndex] = []
        # Unique, hashable site identifiers, aligned position-by-position
        # with ``_indices``.
        self._identifiers: List[SiteIdentifier] = []
        # Always initialize to an empty coordinate tensor with the correct
        # dimensionality so that type checkers know this is indexable and
        # not Optional.
        self._coordinates: Coordinates = backend.zeros((0, dimensionality))

        # Maps identifiers to indices for efficient reverse lookups.
        self._ident_to_idx: Dict[SiteIdentifier, SiteIndex] = {}

        # Cached properties, computed on demand.
        # Neighbor maps keyed by neighbor order k.
        self._neighbor_maps: Dict[int, NeighborMap] = {}
        # Full N x N distance matrix; None until first computed.
        self._distance_matrix: Optional[Coordinates] = None

    @property
    def num_sites(self) -> int:
        """Returns the total number of sites (N) in the lattice."""
        return len(self._indices)

    @property
    def dimensionality(self) -> int:
        """Returns the spatial dimension of the lattice."""
        return self._dimensionality

    def __len__(self) -> int:
        """Returns the total number of sites, enabling `len(lattice)`."""
        return self.num_sites

    # --- Public API for Accessing Lattice Information ---
    @property
    def distance_matrix(self) -> Coordinates:
        """
        Returns the full N x N distance matrix.

        The matrix is computed on the first access (via
        ``_compute_distance_matrix``) and then cached for subsequent calls.
        This computation can be expensive for large lattices.
        """
        if self._distance_matrix is None:
            self._distance_matrix = self._compute_distance_matrix()
        return self._distance_matrix

    def _validate_index(self, index: SiteIndex) -> None:
        """A private helper to check if a site index is within the valid range.

        :raises IndexError: If ``index`` is not in ``[0, num_sites)``.
        """
        if not (0 <= index < self.num_sites):
            raise IndexError(
                f"Site index {index} out of range (0-{self.num_sites - 1})"
            )

    def get_coordinates(self, index: SiteIndex) -> Coordinates:
        """Gets the spatial coordinates of a site by its integer index.

        :param index: The integer index of the site.
        :type index: SiteIndex
        :raises IndexError: If the site index is out of range.
        :return: The spatial coordinates of the site (a row of the backend
            coordinate tensor).
        :rtype: Coordinates
        """
        self._validate_index(index)
        return self._coordinates[index]

    def get_identifier(self, index: SiteIndex) -> SiteIdentifier:
        """Gets the abstract identifier of a site by its integer index.

        :param index: The integer index of the site.
        :type index: SiteIndex
        :raises IndexError: If the site index is out of range.
        :return: The unique, hashable identifier of the site.
        :rtype: SiteIdentifier
        """
        self._validate_index(index)
        return self._identifiers[index]

    def get_index(self, identifier: SiteIdentifier) -> SiteIndex:
        """Gets the integer index of a site by its unique identifier.

        :param identifier: The unique identifier of the site.
        :type identifier: SiteIdentifier
        :raises ValueError: If the identifier is not found in the lattice.
        :return: The corresponding integer index of the site.
        :rtype: SiteIndex
        """
        try:
            return self._ident_to_idx[identifier]
        except KeyError as e:
            raise ValueError(
                f"Identifier {identifier} not found in the lattice."
            ) from e

    def get_site_info(
        self, index_or_identifier: Union[SiteIndex, SiteIdentifier]
    ) -> Tuple[SiteIndex, SiteIdentifier, Coordinates]:
        """Gets all information for a single site.

        This method provides a convenient way to retrieve all relevant data for a
        site (its index, identifier, and coordinates) by using either its
        integer index or its unique identifier. Python ``int`` values are
        treated as indices; anything else is treated as an identifier.

        :param index_or_identifier: The integer
            index or the unique identifier of the site to look up.
        :type index_or_identifier: Union[SiteIndex, SiteIdentifier]
        :raises IndexError: If the given index is out of bounds.
        :raises ValueError: If the given identifier is not found in the lattice.
        :return: A tuple containing:
            - The site's integer index.
            - The site's unique identifier.
            - The site's coordinates.
        :rtype: Tuple[SiteIndex, SiteIdentifier, Coordinates]
        """
        if isinstance(index_or_identifier, int):  # SiteIndex is an int
            idx = index_or_identifier
            self._validate_index(idx)
            return idx, self._identifiers[idx], self._coordinates[idx]
        ident = index_or_identifier
        idx = self.get_index(ident)
        return idx, ident, self._coordinates[idx]

    def sites(self) -> Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]]:
        """Returns an iterator over all sites in the lattice.

        This provides a convenient way to loop through all sites, for example:
        `for idx, ident, coords in my_lattice.sites(): ...`

        :return: An iterator where each item is a tuple containing the site's
            index, identifier, and coordinates.
        :rtype: Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]]
        """
        for i in range(self.num_sites):
            yield i, self._identifiers[i], self._coordinates[i]

    def get_neighbors(self, index: SiteIndex, k: int = 1) -> List[SiteIndex]:
        """Gets the list of k-th nearest neighbor indices for a given site.

        If the k-th neighbors have not been computed yet, they are built on
        demand by calling ``_build_neighbors(max_k=k)``.

        :param index: The integer index of the center site.
        :type index: SiteIndex
        :param k: The order of the neighbors, where k=1 corresponds
            to nearest neighbors (NN), k=2 to next-nearest neighbors (NNN),
            and so on. Defaults to 1.
        :type k: int, optional
        :return: A new list of integer indices for the neighboring sites.
            Mutating it does not affect the lattice's cached neighbor data.
            Returns an empty list if the k-th shell could not be built or
            if the site has no such neighbors.
        :rtype: List[SiteIndex]
        """
        if k not in self._neighbor_maps:
            logger.info(
                "Neighbors for k=%s not pre-computed. Building now up to max_k=%s.",
                k,
                k,
            )
            self._build_neighbors(max_k=k)

        if k not in self._neighbor_maps:
            return []

        # Return a copy so callers cannot corrupt the cached neighbor lists.
        return list(self._neighbor_maps[k].get(index, []))

    def get_neighbor_pairs(
        self, k: int = 1, unique: bool = True
    ) -> List[Tuple[SiteIndex, SiteIndex]]:
        """Gets all pairs of k-th nearest neighbors, representing bonds.

        Neighbors are built on demand if order ``k`` has not been computed.

        :param k: The order of the neighbors to consider.
            Defaults to 1.
        :type k: int, optional
        :param unique: If True, returns only one representation
            for each pair (i, j) such that i < j, avoiding duplicates
            like (j, i). If False, returns all directed pairs.
            Defaults to True.
        :type unique: bool, optional
        :return: A sorted list of tuples, where each
            tuple is a pair of neighbor indices.
        :rtype: List[Tuple[SiteIndex, SiteIndex]]
        """
        if k not in self._neighbor_maps:
            logger.info(
                "Neighbor pairs for k=%s not pre-computed. Building now up to max_k=%s.",
                k,
                k,
            )
            self._build_neighbors(max_k=k)

        if k not in self._neighbor_maps:
            return []

        pairs = [
            (i, j)
            for i, neighbors in self._neighbor_maps[k].items()
            for j in neighbors
            if not unique or i < j
        ]
        return sorted(pairs)

    def get_all_pairs(self) -> List[Tuple[SiteIndex, SiteIndex]]:
        """
        Returns a list of all unique pairs of site indices (i, j) where i < j.

        This method provides all-to-all connectivity, useful for Hamiltonians
        where every site interacts with every other site.

        Note on Differentiability:
        This method provides a static list of index pairs and is not differentiable
        itself. However, it is designed to be used in combination with the
        ``distance_matrix`` property. By using the pairs from this method to
        index into the ``distance_matrix``, one can construct objective
        functions based on all-pair interactions, bypassing the
        non-differentiable ``get_neighbor_pairs`` method for geometry
        optimization tasks.

        :return: A list of tuples, where each tuple is a unique pair of site
            indices, in lexicographic order.
        :rtype: List[Tuple[SiteIndex, SiteIndex]]
        """
        # combinations() over an increasing range already yields (i, j) with
        # i < j in lexicographic order, and yields nothing for < 2 sites.
        return list(itertools.combinations(range(self.num_sites), 2))

    @abc.abstractmethod
    def _build_lattice(self, *args: Any, **kwargs: Any) -> None:
        """
        Abstract method for subclasses to generate the lattice data.

        A concrete implementation of this method in a subclass is responsible
        for populating the following internal attributes:
        - self._indices
        - self._identifiers
        - self._coordinates
        - self._ident_to_idx
        """
        pass

    @abc.abstractmethod
    def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None:
        """
        Abstract method for subclasses to calculate neighbor relationships.

        A concrete implementation of this method should calculate the neighbor
        relationships up to `max_k` and populate the `self._neighbor_maps`
        dictionary. The keys of the dictionary should be the neighbor order (k),
        and the values should be a dictionary mapping site indices to their
        list of k-th neighbors.
        """
        pass

    def _compute_distance_matrix(self) -> Coordinates:
        """
        Default generic distance matrix computation (no periodic images).

        Subclasses can override this when a specialized rule is required
        (e.g., applying Minimum Image Convention for PBC in TILattice).

        :return: An (N, N) tensor of pairwise Euclidean distances; a (0, 0)
            tensor for an empty lattice.
        """
        # An empty lattice has an empty distance matrix.
        if self.num_sites == 0:
            return backend.zeros((0, 0))

        # Vectorized pairwise Euclidean distances:
        # (N, 1, D) - (1, N, D) -> (N, N, D) -> (N, N)
        all_coords = self._coordinates
        displacements = backend.expand_dims(all_coords, 1) - backend.expand_dims(
            all_coords, 0
        )
        dist_matrix_sq = backend.sum(displacements**2, axis=-1)
        # NOTE(review): the diagonal of dist_matrix_sq is always 0 and sqrt has
        # an unbounded derivative at 0; under autodiff backends this may yield
        # NaN gradients w.r.t. the coordinates. Confirm before relying on this
        # for geometry optimization (a masked "safe sqrt" may be needed).
        return backend.sqrt(dist_matrix_sq)

    def show(
        self,
        show_indices: bool = False,
        show_identifiers: bool = False,
        show_bonds_k: Optional[int] = None,
        ax: Optional["matplotlib.axes.Axes"] = None,
        bond_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Visualizes the lattice structure using Matplotlib.

        This method supports 1D, 2D, and 3D plotting. For 1D lattices, sites
        are plotted along the x-axis.

        :param show_indices: If True, displays the integer index
            next to each site. Defaults to False.
        :type show_indices: bool, optional
        :param show_identifiers: If True, displays the unique
            identifier next to each site (takes precedence over
            ``show_indices``). Defaults to False.
        :type show_identifiers: bool, optional
        :param show_bonds_k: Specifies which order of
            neighbor bonds to draw (e.g., 1 for NN, 2 for NNN). If None,
            no bonds are drawn. If the specified neighbors have not been
            calculated, a warning is logged. Defaults to None.
        :type show_bonds_k: Optional[int], optional
        :param ax: An existing Matplotlib Axes object to plot on.
            If None, a new Figure and Axes are created automatically and
            ``plt.show()`` is called at the end. Defaults to None.
        :type ax: Optional["matplotlib.axes.Axes"], optional
        :param bond_kwargs: A dictionary of keyword arguments for customizing bond appearance,
            passed directly to the Matplotlib plot function. Defaults to None.
        :type bond_kwargs: Optional[Dict[str, Any]], optional
        :param kwargs: Additional keyword arguments to be passed directly to the
            `matplotlib.pyplot.scatter` function for customizing site appearance.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            logger.warning(
                "Matplotlib is required for visualization. "
                "Please install it using 'pip install matplotlib'."
            )
            return

        if self.num_sites == 0:
            logger.info("Lattice is empty, nothing to show.")
            return
        if self.dimensionality not in [1, 2, 3]:
            logger.warning(
                "show() is not implemented for %sD lattices.", self.dimensionality
            )
            return

        # Only call plt.show() if this method created the figure itself;
        # user-provided Axes are left for the caller to display.
        fig_created_internally = ax is None
        if ax is None:
            if self.dimensionality == 3:
                fig = plt.figure(figsize=(8, 8))
                ax = fig.add_subplot(111, projection="3d")
            else:
                _, ax = plt.subplots(figsize=(8, 8))

        # Convert once to a host NumPy array through the backend's own
        # conversion; all plotting below (sites, labels, bonds) uses this copy.
        coords = np.asarray(backend.numpy(self._coordinates))

        # Prepare arguments for the scatter plot, allowing user overrides.
        scatter_args: Dict[str, Any] = {"s": 100, "zorder": 2}
        scatter_args.update(kwargs)
        if self.dimensionality == 1:
            ax.scatter(coords[:, 0], np.zeros_like(coords[:, 0]), **scatter_args)  # type: ignore
        elif self.dimensionality == 2:
            ax.scatter(coords[:, 0], coords[:, 1], **scatter_args)  # type: ignore
        else:  # 3D
            # Halve the marker size in 3D plots.
            scatter_args["s"] = scatter_args.get("s", 100) // 2
            ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], **scatter_args)  # type: ignore

        if show_indices or show_identifiers:
            # Offset labels by a small fraction of the lattice extent so they do
            # not overlap the site markers; fall back to a fixed offset when the
            # extent is zero (e.g. a single site or coincident sites).
            span = float(np.max(np.ptp(coords, axis=0)))
            offset = 0.02 * span if span > 0 else 0.1
            for i in range(self.num_sites):
                label = str(self._identifiers[i]) if show_identifiers else str(i)
                if self.dimensionality == 1:
                    ax.text(coords[i, 0], offset, label, fontsize=9, ha="center")
                elif self.dimensionality == 2:
                    ax.text(
                        coords[i, 0] + offset,
                        coords[i, 1] + offset,
                        label,
                        fontsize=9,
                        zorder=3,
                    )
                else:
                    ax_3d = cast("Axes3D", ax)
                    ax_3d.text(
                        coords[i, 0],
                        coords[i, 1],
                        coords[i, 2] + offset,
                        label,
                        fontsize=9,
                        zorder=3,
                    )

        if show_bonds_k is not None:
            if show_bonds_k not in self._neighbor_maps:
                logger.warning(
                    "Cannot draw bonds. k=%s neighbors have not been calculated.",
                    show_bonds_k,
                )
            else:
                try:
                    bonds = self.get_neighbor_pairs(k=show_bonds_k, unique=True)
                    plot_bond_kwargs: Dict[str, Any] = {
                        "color": "k",
                        "linestyle": "-",
                        "alpha": 0.6,
                        "zorder": 1,
                    }
                    if bond_kwargs:
                        plot_bond_kwargs.update(bond_kwargs)

                    for i, j in bonds:
                        p1, p2 = coords[i], coords[j]
                        if self.dimensionality == 3:
                            ax_3d = cast("Axes3D", ax)
                            ax_3d.plot(
                                [p1[0], p2[0]],
                                [p1[1], p2[1]],
                                [p1[2], p2[2]],
                                **plot_bond_kwargs,
                            )
                        elif self.dimensionality == 2:
                            ax.plot([p1[0], p2[0]], [p1[1], p2[1]], **plot_bond_kwargs)  # type: ignore
                        else:
                            ax.plot([p1[0], p2[0]], [0, 0], **plot_bond_kwargs)  # type: ignore
                except ValueError as e:
                    logger.info("Could not draw bonds: %s", e)

        ax.set_title(f"{self.__class__.__name__} ({self.num_sites} sites)")
        if self.dimensionality == 2:
            ax.set_aspect("equal", adjustable="box")
        ax.set_xlabel("x")
        if self.dimensionality > 1:
            ax.set_ylabel("y")
        if self.dimensionality > 2 and hasattr(ax, "set_zlabel"):
            ax.set_zlabel("z")
        ax.grid(True)

        # Display the plot only if the figure was created within this function.
        if fig_created_internally:
            plt.show()

    def _identify_distance_shells(
        self,
        all_distances_sq: Union[Coordinates, List[float]],
        max_k: int,
        tol: float = 1e-6,
    ) -> List[float]:
        """Identifies unique distance shells from a list of squared distances.

        This helper function takes a flat list of squared distances, sorts them,
        and identifies the first `max_k` unique distance shells based on a
        numerical tolerance applied to the (non-squared) distances.

        :param all_distances_sq: A list or array
            of all squared distances between pairs of sites.
        :type all_distances_sq: Union[np.ndarray, List[float]]
        :param max_k: The maximum number of neighbor shells to identify.
            Non-positive values yield no shells.
        :type max_k: int
        :param tol: The numerical tolerance to consider two distances equal.
        :type tol: float
        :return: A sorted list of squared distances representing the shells.
        :rtype: List[float]
        """
        # Without this guard the first shell was always returned, even for
        # max_k <= 0.
        if max_k <= 0:
            return []

        # A small threshold to filter out zero distances (site to itself).
        ZERO_THRESHOLD_SQ = 1e-12

        all_distances_sq = backend.convert_to_tensor(all_distances_sq)
        if backend.sizen(all_distances_sq) == 0:
            return []

        # Filter out self-distances and sort the remaining squared distances.
        sorted_dist = backend.sort(
            all_distances_sq[all_distances_sq > ZERO_THRESHOLD_SQ]
        )

        if backend.sizen(sorted_dist) == 0:
            return []

        dist_shells = [sorted_dist[0]]

        for d_sq in sorted_dist[1:]:
            if len(dist_shells) >= max_k:
                break
            # A new shell starts when the distance exceeds the last shell's
            # distance by more than ``tol``.
            if backend.sqrt(d_sq) - backend.sqrt(dist_shells[-1]) > tol:
                dist_shells.append(d_sq)

        return dist_shells

    def _build_neighbors_by_distance_matrix(
        self, max_k: int = 2, tol: float = 1e-6
    ) -> None:
        """A generic, distance-based neighbor finding method.

        This method calculates the full N x N distance matrix to find neighbor
        shells. It is computationally expensive for large N (O(N^2)) and is
        best suited for non-periodic or custom-defined lattices. It replaces
        ``self._neighbor_maps`` and caches ``self._distance_matrix``. Lattices
        with fewer than two sites are left untouched.

        :param max_k: The maximum number of neighbor shells to
            calculate. Defaults to 2.
        :type max_k: int, optional
        :param tol: The numerical tolerance for distance
            comparisons. Defaults to 1e-6.
        :type tol: float, optional
        """
        if self.num_sites < 2:
            return

        all_coords = self._coordinates
        # Vectorized computation of the squared distance matrix:
        # (N, 1, D) - (1, N, D) -> (N, N, D) -> (N, N)
        displacements = backend.expand_dims(all_coords, 1) - backend.expand_dims(
            all_coords, 0
        )
        dist_matrix_sq = backend.sum(displacements**2, axis=-1)

        # Flatten the matrix to a list of all squared distances to identify shells.
        all_distances_sq = backend.reshape(dist_matrix_sq, [-1])
        dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol)

        self._neighbor_maps = self._build_neighbor_map_from_distances(
            dist_matrix_sq, dist_shells_sq, tol
        )
        self._distance_matrix = backend.sqrt(dist_matrix_sq)

    def _build_neighbor_map_from_distances(
        self,
        dist_matrix_sq: Coordinates,
        dist_shells_sq: List[float],
        tol: float = 1e-6,
    ) -> Dict[int, NeighborMap]:
        """
        Builds a neighbor map from a squared distance matrix and identified shells.

        This is a generic helper function to reduce code duplication. Shell k
        (1-based) contains every pair (i, j), i != j, whose squared distance is
        within ``tol`` of ``dist_shells_sq[k - 1]``.

        :param dist_matrix_sq: The (N, N) squared distance matrix.
        :param dist_shells_sq: Squared distances of the shells, nearest first.
        :param tol: Tolerance applied to squared distances.
        :return: Mapping ``k -> {site index: sorted list of k-th neighbor
            indices}``, with plain Python ``int`` indices.
        """
        neighbor_maps: Dict[int, NeighborMap] = {}
        for k, target_d_sq in enumerate(dist_shells_sq, start=1):
            current_k_map: NeighborMap = {}
            # For each shell, find all pairs of sites (i, j) with that distance.
            is_close_matrix = backend.abs(dist_matrix_sq - target_d_sq) < tol
            rows, cols = backend.where(is_close_matrix)

            for row, col in zip(backend.numpy(rows), backend.numpy(cols)):
                # Convert NumPy integer scalars to plain ints: consumers such
                # as get_site_info() dispatch on isinstance(..., int).
                i, j = int(row), int(col)
                if i == j:
                    continue
                current_k_map.setdefault(i, []).append(j)

            for neighbors in current_k_map.values():
                neighbors.sort()

            neighbor_maps[k] = current_k_map
        return neighbor_maps
618
+
619
+
620
+ class TILattice(AbstractLattice):
621
+ """Describes a periodic lattice with translational invariance.
622
+
623
+ This class serves as a base for any lattice defined by a repeating unit
624
+ cell. The geometry is specified by lattice vectors, the coordinates of
625
+ basis sites within a unit cell, and the total size of the lattice in
626
+ terms of unit cells.
627
+
628
+ The site identifier for this class is a tuple in the format of
629
+ `(uc_coord_1, ..., uc_coord_d, basis_index)`, where `uc_coord` represents
630
+ the integer coordinate of the unit cell and `basis_index` is the index
631
+ of the site within that unit cell's basis.
632
+
633
+ :param dimensionality: The spatial dimension of the lattice.
634
+ :type dimensionality: int
635
+ :param lattice_vectors: The lattice vectors defining the unit
636
+ cell, given as row vectors. Shape: (dimensionality, dimensionality).
637
+ For example, in 2D: `np.array([[ax, ay], [bx, by]])`.
638
+ :type lattice_vectors: np.ndarray
639
+ :param basis_coords: The Cartesian coordinates of the basis sites
640
+ within the unit cell. Shape: (num_basis_sites, dimensionality).
641
+ For a simple Bravais lattice, this would be `np.array([[0, 0]])`.
642
+ :type basis_coords: np.ndarray
643
+ :param size: A tuple specifying the number of unit cells
644
+ to generate in each lattice vector direction (e.g., (Nx, Ny)).
645
+ :type size: Tuple[int, ...]
646
+ :param pbc: Specifies whether
647
+ periodic boundary conditions are applied. Can be a single boolean
648
+ for all dimensions or a tuple of booleans for each dimension
649
+ individually. Defaults to True.
650
+ :type pbc: Union[bool, Tuple[bool, ...]], optional
651
+ :param precompute_neighbors: If specified, pre-computes neighbor relationships
652
+ up to the given order `k` upon initialization. Defaults to None.
653
+ :type precompute_neighbors: Optional[int], optional
654
+
655
+ """
656
+
657
def __init__(
    self,
    dimensionality: int,
    lattice_vectors: Coordinates,
    basis_coords: Coordinates,
    size: Tuple[int, ...],
    pbc: Union[bool, Tuple[bool, ...]] = True,
    precompute_neighbors: Optional[int] = None,
):
    """Initializes the Translationally Invariant Lattice.

    :raises ValueError: If ``lattice_vectors`` is not of shape
        (dimensionality, dimensionality), if ``basis_coords`` is not a 2D
        array of shape (num_basis_sites, dimensionality), or if the lengths
        of ``size`` or a tuple ``pbc`` do not match ``dimensionality``.
    """
    super().__init__(dimensionality)

    self.lattice_vectors = backend.convert_to_tensor(lattice_vectors)
    self.basis_coords = backend.convert_to_tensor(basis_coords)

    if self.lattice_vectors.shape != (dimensionality, dimensionality):
        raise ValueError(
            f"Lattice vectors shape {self.lattice_vectors.shape} does not match "
            f"expected ({dimensionality}, {dimensionality})"
        )
    # Check the rank before reading ``shape[1]``: a 1D input would otherwise
    # fail with an opaque IndexError instead of a descriptive ValueError.
    if len(self.basis_coords.shape) != 2:
        raise ValueError(
            f"Basis coordinates must have shape (num_basis_sites, {dimensionality}), "
            f"got shape {tuple(self.basis_coords.shape)}"
        )
    if self.basis_coords.shape[1] != dimensionality:
        raise ValueError(
            f"Basis coordinates dimension {self.basis_coords.shape[1]} does not "
            f"match lattice dimensionality {dimensionality}"
        )
    if len(size) != dimensionality:
        raise ValueError(
            f"Size tuple length {len(size)} does not match dimensionality {dimensionality}"
        )

    self.num_basis = self.basis_coords.shape[0]
    self.size = size
    # Normalize pbc to one boolean per dimension.
    if isinstance(pbc, bool):
        self.pbc = (pbc,) * dimensionality
    else:
        if len(pbc) != dimensionality:
            raise ValueError(
                f"PBC tuple length {len(pbc)} does not match dimensionality {dimensionality}"
            )
        self.pbc = tuple(pbc)

    self._build_lattice()
    if precompute_neighbors is not None and precompute_neighbors > 0:
        logger.info("Pre-computing neighbors up to k=%s...", precompute_neighbors)
        self._build_neighbors(max_k=precompute_neighbors)
702
+
703
def _build_lattice(self) -> None:
    """
    Populates coordinates, indices and identifiers for every site of the
    periodic lattice.

    Coordinates are produced with vectorized backend operations. Identifiers
    are tuples ``(*unit_cell_coords, basis_index)``, enumerated in the same
    row-major ("ij") order as the coordinate grid, so entry ``n`` of
    ``self._identifiers`` describes row ``n`` of ``self._coordinates``.
    """
    dim = self.dimensionality

    # Integer coordinates of every unit cell, flattened to (num_cells, D)
    # and cast to the dtype of the lattice vectors.
    axis_ranges = [backend.arange(extent) for extent in self.size]
    mesh = backend.meshgrid(*axis_ranges, indexing="ij")
    cell_grid = backend.cast(
        backend.reshape(backend.stack(mesh, axis=-1), (-1, dim)),
        self.lattice_vectors.dtype,
    )

    # Cartesian origin of each cell: sum_i n_i * lattice_vectors[i].
    origins = backend.cast(
        backend.tensordot(cell_grid, self.lattice_vectors, axes=[[1], [0]]),
        self.basis_coords.dtype,
    )

    # (num_cells, 1, D) + (1, num_basis, D) -> (num_cells, num_basis, D), then
    # flatten so that the basis sites of a cell are stored contiguously.
    positions = backend.expand_dims(origins, 1) + backend.expand_dims(
        self.basis_coords, 0
    )
    self._coordinates = backend.reshape(positions, (-1, dim))

    # Identifiers such as (uc_x, uc_y, basis_idx), in the same order as rows.
    cells = itertools.product(*(range(extent) for extent in self.size))
    self._identifiers = [
        cell + (basis_index,)
        for cell in cells
        for basis_index in range(self.num_basis)
    ]
    self._indices = list(range(len(self._identifiers)))
    self._ident_to_idx = {
        ident: position for position, ident in enumerate(self._identifiers)
    }
746
+
747
def _get_distance_matrix_with_mic_vectorized(self) -> Coordinates:
    """
    Compute the full N x N site distance matrix under the Minimum Image
    Convention (MIC) for periodic boundary conditions.

    All site pairs are processed at once with broadcasting, with no Python
    loop over sites. This keeps the computation traceable by automatic
    differentiation and JIT frameworks (e.g. JAX, TensorFlow).

    For each pair (i, j), the displacement ``r_i - r_j`` is compared against
    every periodic image obtained by translating with the integer combinations
    in {-1, 0, 1} of the system vectors ``size[d] * lattice_vectors[d]``.
    Shifts are applied only along dimensions whose ``pbc`` flag is True. The
    smallest image distance is kept.

    Memory: the intermediate tensor of image displacements has shape
    (N, N, 3**D, D), so memory grows as O(N^2 * 3**D * D). This can be
    prohibitive for large lattices.

    All floating-point operands (system vectors, image shifts and site
    coordinates) are cast to one common dtype. That dtype is the dtype of
    ``lattice_vectors`` when it is float32/float64, and float32 otherwise. The
    cast keeps backends without implicit type promotion (notably TensorFlow)
    from failing on mixed precision.

    :return: Distance matrix with shape (N, N) where entry (i, j) is the
        minimum distance between sites i and j under periodic boundary
        conditions.
    :rtype: Coordinates
    """
    # Pick a single floating dtype for every operand below.
    target_dt = str(backend.dtype(self.lattice_vectors))  # type: ignore
    if target_dt not in ("float32", "float64"):
        # fallback for unusual (e.g. integer) dtypes
        target_dt = "float32"

    size_arr = backend.cast(backend.convert_to_tensor(self.size), target_dt)
    lattice_vecs = backend.cast(
        backend.convert_to_tensor(self.lattice_vectors), target_dt
    )
    # Row d is the full extent of the system along lattice vector d.
    system_vectors = lattice_vecs * backend.expand_dims(size_arr, axis=1)

    # 1.0 for periodic dimensions, 0.0 for open ones.
    pbc_mask = backend.cast(backend.convert_to_tensor(self.pbc), target_dt)

    # All 3^D image shifts (-1, 0, 1) per dimension. The literal list would
    # otherwise get the backend's default float dtype, so it is cast explicitly.
    shift_values = backend.cast(
        backend.convert_to_tensor([-1.0, 0.0, 1.0]), target_dt
    )
    shifts_grid = backend.meshgrid(
        *([shift_values] * self.dimensionality), indexing="ij"
    )
    all_shifts = backend.reshape(
        backend.stack(shifts_grid, axis=-1), (-1, self.dimensionality)
    )

    # Only apply shifts along periodic dimensions.
    masked_shifts = all_shifts * pbc_mask

    # Translation vector of each periodic image, shape (3^D, D).
    translations_arr = backend.tensordot(
        masked_shifts, system_vectors, axes=[[1], [0]]
    )

    coords = backend.cast(self._coordinates, target_dt)

    # Shape: (N, 1, D) - (1, N, D) -> (N, N, D)
    displacements = backend.expand_dims(coords, 1) - backend.expand_dims(coords, 0)

    # Shape: (N, N, 1, D) - (1, 1, 3^D, D) -> (N, N, 3^D, D)
    image_displacements = backend.expand_dims(
        displacements, 2
    ) - backend.expand_dims(backend.expand_dims(translations_arr, 0), 0)

    image_d_sq = backend.sum(image_displacements**2, axis=3)

    # Minimum Image Convention: keep the closest periodic image.
    min_dist_sq = backend.min(image_d_sq, axis=2)

    # `where` *selects* 0 for zero-distance entries (e.g. the diagonal) rather
    # than multiplying by a mask. In reverse-mode AD, the infinite derivative
    # of sqrt at 0 is therefore not propagated back into min_dist_sq.
    safe_dist_matrix_sq = backend.where(min_dist_sq > 0, min_dist_sq, 0.0)
    return backend.sqrt(safe_dist_matrix_sq)
826
+
827
def _build_neighbors(self, max_k: int = 2, **kwargs: Any) -> None:
    """Compute neighbor shells for the periodic lattice.

    The full site-to-site distance matrix is obtained with the Minimum Image
    Convention and cached on ``self._distance_matrix``. Its squared entries
    are grouped into at most ``max_k`` distinct distance shells (nearest,
    next-nearest, ...), and the neighbor maps are rebuilt from those shells.

    :param max_k: Highest neighbor order to compute (1 = nearest neighbors,
        2 = next-nearest, ...). Defaults to 2.
    :type max_k: int, optional
    :param kwargs: Optional extras. ``tol`` (float) is the tolerance used
        when deciding whether two distances belong to the same shell; it
        defaults to 1e-6.
    """
    tolerance = kwargs.get("tol", 1e-6)

    distances = self._get_distance_matrix_with_mic_vectorized()
    self._distance_matrix = distances

    # Shells are identified on squared distances.
    squared = distances**2
    shells_sq = self._identify_distance_shells(
        backend.reshape(squared, [-1]), max_k, tolerance
    )
    self._neighbor_maps = self._build_neighbor_map_from_distances(
        squared, shells_sq, tolerance
    )
853
+
854
def _compute_distance_matrix(self) -> Coordinates:
    """Return the MIC distance matrix, or an empty (0, 0) tensor when there are no sites."""
    has_sites = self.num_sites != 0
    if has_sites:
        return self._get_distance_matrix_with_mic_vectorized()
    return backend.zeros((0, 0))
859
+
860
+
861
class SquareLattice(TILattice):
    """A 2D square lattice.

    Concrete translationally invariant lattice on a simple square grid. It is
    a Bravais lattice, so every unit cell holds exactly one site.

    :param size: A tuple (Nx, Ny) giving the number of unit cells (which
        equals the number of sites) along x and y.
    :type size: Tuple[int, int]
    :param lattice_constant: Spacing between adjacent sites. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Periodic boundary conditions. Pass either one boolean for all
        dimensions or one boolean per dimension. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: If given, neighbor relationships up to this
        order ``k`` are computed at construction time. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: Union[float, Any] = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the SquareLattice."""
        # Stacking Python floats together with backend tensors errors on TF,
        # so both the spacing and the zero entry become tensors of one dtype.
        spacing = backend.convert_to_tensor(lattice_constant)
        origin = backend.cast(backend.convert_to_tensor(0.0), backend.dtype(spacing))

        # Orthogonal primitive vectors (a, 0) and (0, a).
        x_vec = backend.stack([spacing, origin])
        y_vec = backend.stack([origin, spacing])

        super().__init__(
            dimensionality=2,
            lattice_vectors=backend.stack([x_vec, y_vec]),
            # Bravais lattice: a single site at the cell origin.
            basis_coords=backend.stack([backend.stack([origin, origin])]),
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
912
+
913
+
914
class HoneycombLattice(TILattice):
    """A 2D honeycomb lattice.

    A composite lattice: a two-site basis (sublattices A and B) placed on an
    underlying triangular Bravais lattice.

    :param size: A tuple (Nx, Ny) giving the number of unit cells along the
        two lattice vector directions.
    :type size: Tuple[int, int]
    :param lattice_constant: The bond length, i.e. the distance between two
        nearest-neighbor sites. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: If given, neighbor relationships up to this
        order ``k`` are computed at construction time. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: Union[float, Any] = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the HoneycombLattice."""
        bond = backend.convert_to_tensor(lattice_constant)
        origin = bond * 0.0
        half_height = math.sqrt(3.0) / 2.0

        # Triangular Bravais vectors: (3a/2, +sqrt(3)a/2) and (3a/2, -sqrt(3)a/2).
        a1 = backend.stack([bond * 1.5, bond * half_height])
        a2 = backend.stack([bond * 1.5, -bond * half_height])

        # Sublattice A at the cell origin, sublattice B one bond length along x.
        site_a = backend.stack([origin, origin])
        site_b = backend.stack([bond * 1.0, origin])

        super().__init__(
            dimensionality=2,
            lattice_vectors=backend.stack([a1, a2]),
            basis_coords=backend.stack([site_a, site_b]),
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
970
+
971
+
972
class TriangularLattice(TILattice):
    """A 2D triangular lattice.

    This is a Bravais lattice in which each site has 6 nearest neighbors.

    :param size: A tuple (Nx, Ny) specifying the number of
        unit cells along the two lattice vector directions.
    :type size: Tuple[int, int]
    :param lattice_constant: The bond length, i.e., the
        distance between two nearest neighbor sites. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Specifies periodic
        boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional

    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: Union[float, Any] = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the TriangularLattice."""
        dimensionality = 2
        a_t = backend.convert_to_tensor(lattice_constant)
        zero = a_t * 0.0
        # sqrt(3)/2 is kept as a plain Python float. Multiplying a tensor by a
        # Python scalar preserves the tensor's dtype on every backend, whereas
        # a separately created tensor (default float32) breaks TF when the
        # lattice constant is float64. This mirrors HoneycombLattice.
        rt3_over_2 = math.sqrt(3.0) / 2.0

        # Primitive vectors: (a, 0) and (a/2, sqrt(3)a/2).
        lattice_vectors = backend.stack(
            [
                backend.stack([a_t * 1.0, zero]),
                backend.stack([a_t * 0.5, a_t * rt3_over_2]),
            ]
        )
        # A triangular lattice is a Bravais lattice with a single-site basis.
        basis_coords = backend.stack([backend.stack([zero, zero])])

        super().__init__(
            dimensionality=dimensionality,
            lattice_vectors=lattice_vectors,
            basis_coords=basis_coords,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
1028
+
1029
+
1030
class ChainLattice(TILattice):
    """A 1D chain, i.e. the simplest Bravais lattice.

    :param size: A tuple ``(N,)`` giving the number of sites in the chain.
    :type size: Tuple[int]
    :param lattice_constant: Spacing between adjacent sites. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Whether periodic boundary conditions are applied. Defaults to True.
    :type pbc: bool, optional
    :param precompute_neighbors: If given, neighbor relationships up to this
        order ``k`` are computed at construction time. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int],
        lattice_constant: Union[float, Any] = 1.0,
        pbc: bool = True,
        precompute_neighbors: Optional[int] = None,
    ):
        spacing = backend.convert_to_tensor(lattice_constant)
        origin = spacing * 0.0

        super().__init__(
            dimensionality=1,
            # One primitive vector of length a.
            lattice_vectors=backend.stack([backend.stack([spacing])]),
            # Single site per cell, at the cell origin.
            basis_coords=backend.stack([backend.stack([origin])]),
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
1067
+
1068
+
1069
class DimerizedChainLattice(TILattice):
    """A 1D chain with an AB sublattice (dimerized chain).

    Each unit cell holds two sites, A and B. All bonds have the same length.

    :param size: A tuple ``(N,)`` giving the number of **unit cells**. Because
        every cell contains two sites, the chain has ``2 * N`` sites in total.
    :type size: Tuple[int]
    :param lattice_constant: Bond length between adjacent sites. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Whether periodic boundary conditions are applied. Defaults to True.
    :type pbc: bool, optional
    :param precompute_neighbors: If given, neighbor relationships up to this
        order ``k`` are computed at construction time. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int],
        lattice_constant: Union[float, Any] = 1.0,
        pbc: bool = True,
        precompute_neighbors: Optional[int] = None,
    ):
        bond = backend.convert_to_tensor(lattice_constant)
        origin = bond * 0.0
        # The cell spans two bonds because it contains two sites.
        cell_length = 2 * bond

        super().__init__(
            dimensionality=1,
            lattice_vectors=backend.stack([backend.stack([cell_length])]),
            # Site A at the cell origin, site B one bond further along.
            basis_coords=backend.stack(
                [backend.stack([origin]), backend.stack([bond])]
            ),
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
1110
+
1111
+
1112
class RectangularLattice(TILattice):
    """A 2D rectangular lattice.

    Generalizes SquareLattice by allowing different lattice constants along
    x and y.

    :param size: A tuple (Nx, Ny) giving the number of sites along x and y.
    :type size: Tuple[int, int]
    :param lattice_constants: Site spacings along x and y, i.e. ``(ax, ay)``.
        Defaults to (1.0, 1.0).
    :type lattice_constants: Tuple[float, float], optional
    :param pbc: Periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: If given, neighbor relationships up to this
        order ``k`` are computed at construction time. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constants: Union[Tuple[float, float], Any] = (1.0, 1.0),
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        x_const, y_const = lattice_constants
        # The x spacing fixes the dtype; the y spacing and zero entry follow it.
        ax_t = backend.convert_to_tensor(x_const)
        common_dt = backend.dtype(ax_t)
        ay_t = backend.cast(backend.convert_to_tensor(y_const), common_dt)
        origin = backend.cast(backend.convert_to_tensor(0.0), common_dt)

        # Orthogonal primitive vectors (ax, 0) and (0, ay).
        x_vec = backend.stack([ax_t, origin])
        y_vec = backend.stack([origin, ay_t])

        super().__init__(
            dimensionality=2,
            lattice_vectors=backend.stack([x_vec, y_vec]),
            basis_coords=backend.stack([backend.stack([origin, origin])]),
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
1158
+
1159
+
1160
class CheckerboardLattice(TILattice):
    """A 2D checkerboard lattice (a square lattice with an AB sublattice).

    The unit cell is a square rotated by 45 degrees and contains two sites.

    :param size: A tuple (Nx, Ny) giving the number of unit cells. The lattice
        has 2*Nx*Ny sites in total.
    :type size: Tuple[int, int]
    :param lattice_constant: Bond length between nearest neighbors. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: If given, neighbor relationships up to this
        order ``k`` are computed at construction time. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: Union[float, Any] = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        bond = backend.convert_to_tensor(lattice_constant)

        # Rotated-square cell: a1 = (a, a), a2 = (a, -a).
        a1 = backend.stack([bond * 1.0, bond * 1.0])
        a2 = backend.stack([bond * 1.0, bond * -1.0])

        # Site A at the origin, site B one bond along x.
        origin = bond * 0.0
        site_a = backend.stack([origin, origin])
        site_b = backend.stack([bond * 1.0, origin])

        super().__init__(
            dimensionality=2,
            lattice_vectors=backend.stack([a1, a2]),
            basis_coords=backend.stack([site_a, site_b]),
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
1207
+
1208
+
1209
class KagomeLattice(TILattice):
    """A 2D Kagome lattice.

    This is a lattice with a three-site basis on a triangular Bravais lattice.

    :param size: A tuple (Nx, Ny) specifying the number of unit cells. Total sites will be 3*Nx*Ny.
    :type size: Tuple[int, int]
    :param lattice_constant: The bond length. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Specifies periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: Union[float, Any] = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        dimensionality = 2
        a_t = backend.convert_to_tensor(lattice_constant)
        zero = a_t * 0.0
        # sqrt(3) as a plain Python float. Passing the literal 3.0 to
        # backend.sqrt fails on PyTorch (torch.sqrt requires a Tensor). On TF it
        # yields a float32 tensor that cannot be mixed with a float64 lattice
        # constant. Scaling a tensor by a Python float keeps its dtype on all
        # backends, which matches HoneycombLattice.
        sqrt3 = math.sqrt(3.0)

        # Underlying triangular Bravais lattice: (2a, 0) and (a, sqrt(3)a).
        lattice_vectors = backend.stack(
            [
                backend.stack([a_t * 2.0, a_t * 0.0]),
                backend.stack([a_t * 1.0, a_t * sqrt3]),
            ]
        )
        # Three-site basis at the corners of an equilateral triangle of side a.
        basis_coords = backend.stack(
            [
                backend.stack([zero, zero]),
                backend.stack([a_t * 1.0, zero]),
                backend.stack([a_t * 0.5, a_t * sqrt3 / 2.0]),
            ]
        )

        super().__init__(
            dimensionality=dimensionality,
            lattice_vectors=lattice_vectors,
            basis_coords=basis_coords,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
1260
+
1261
+
1262
class LiebLattice(TILattice):
    """A 2D Lieb lattice.

    A three-site basis on a square Bravais lattice. Sites sit on the corners
    of each square and at the centers of its edges.

    :param size: A tuple (Nx, Ny) giving the number of unit cells. The lattice
        has 3*Nx*Ny sites in total.
    :type size: Tuple[int, int]
    :param lattice_constant: The bond length. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: If given, neighbor relationships up to this
        order ``k`` are computed at construction time. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: Union[float, Any] = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the LiebLattice."""
        bond = backend.convert_to_tensor(lattice_constant)
        origin = bond * 0.0
        # The square cell side spans two bonds (corner -> edge center -> corner).
        cell_side = 2 * bond

        square_vectors = backend.stack(
            [backend.stack([cell_side, origin]), backend.stack([origin, cell_side])]
        )
        sites = [
            backend.stack([origin, origin]),  # corner
            backend.stack([bond, origin]),  # center of the x edge
            backend.stack([origin, bond]),  # center of the y edge
        ]

        super().__init__(
            dimensionality=2,
            lattice_vectors=square_vectors,
            basis_coords=backend.stack(sites),
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
1313
+
1314
+
1315
class CubicLattice(TILattice):
    """A 3D simple cubic lattice.

    A Bravais lattice; the three-dimensional analogue of SquareLattice.

    :param size: A tuple (Nx, Ny, Nz) giving the number of sites per direction.
    :type size: Tuple[int, int, int]
    :param lattice_constant: Spacing between adjacent sites. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool, bool]], optional
    :param precompute_neighbors: If given, neighbor relationships up to this
        order ``k`` are computed at construction time. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int, int],
        lattice_constant: Union[float, Any] = 1.0,
        pbc: Union[bool, Tuple[bool, bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        spacing = backend.convert_to_tensor(lattice_constant)
        origin = spacing * 0.0

        # Row r of the lattice-vector matrix is `spacing` on the diagonal and
        # zero elsewhere, i.e. a * identity(3).
        rows = [
            backend.stack([spacing if col == row else origin for col in range(3)])
            for row in range(3)
        ]

        super().__init__(
            dimensionality=3,
            lattice_vectors=backend.stack(rows),
            basis_coords=backend.stack([backend.stack([origin, origin, origin])]),
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
1360
+
1361
+
1362
+ class CustomizeLattice(AbstractLattice):
1363
+ """A general lattice built from an explicit list of sites and coordinates.
1364
+
1365
+ This class is suitable for creating lattices with arbitrary geometries,
1366
+ such as finite clusters, disordered systems, or any custom structure
1367
+ that does not have translational symmetry. The lattice is defined simply
1368
+ by providing lists of identifiers and coordinates for each site.
1369
+
1370
+ :param dimensionality: The spatial dimension of the lattice.
1371
+ :type dimensionality: int
1372
+ :param identifiers: A list of unique, hashable
1373
+ identifiers for the sites. The length must match `coordinates`.
1374
+ :type identifiers: List[SiteIdentifier]
1375
+ :param coordinates: A list of site
1376
+ coordinates. Each coordinate should be a list of floats or a
1377
+ NumPy array.
1378
+ :type coordinates: List[Union[List[float], Coordinates]]
1379
+ :raises ValueError: If the lengths of `identifiers` and `coordinates` lists
1380
+ do not match, or if a coordinate's dimension is incorrect.
1381
+ :param precompute_neighbors: If specified, pre-computes neighbor relationships
1382
+ up to the given order `k` upon initialization. Defaults to None.
1383
+ :type precompute_neighbors: Optional[int], optional
1384
+
1385
+ """
1386
+
1387
def __init__(
    self,
    dimensionality: int,
    identifiers: List[SiteIdentifier],
    coordinates: Any,
    precompute_neighbors: Optional[int] = None,
):
    """Initializes the CustomizeLattice.

    :param dimensionality: Spatial dimension D of the lattice.
    :param identifiers: Hashable site identifiers, one per coordinate row.
    :param coordinates: Site coordinates convertible by
        ``backend.convert_to_tensor``. Expected shape is (num_sites, D).
    :param precompute_neighbors: If a positive int, neighbor relationships up
        to this order are computed immediately.
    :raises ValueError: If the number of identifiers differs from the number
        of coordinate rows, if ``coordinates`` is not two-dimensional, or if
        its second dimension differs from ``dimensionality``.
    """
    super().__init__(dimensionality)

    self._coordinates = backend.convert_to_tensor(coordinates)
    if len(identifiers) == 0:
        # An empty lattice still needs a well-formed (0, D) coordinate tensor.
        self._coordinates = backend.reshape(
            self._coordinates, (0, self.dimensionality)
        )

    coord_shape = backend.shape_tuple(self._coordinates)
    num_coords = coord_shape[0] if len(coord_shape) > 0 else 0
    if len(identifiers) != num_coords:
        raise ValueError(
            "The number of identifiers must match the number of coordinates. "
            f"Got {len(identifiers)} identifiers and "
            f"{num_coords} coordinates."
        )
    # Without this check a flat (N,) input would fail below with an opaque
    # IndexError on coord_shape[1] instead of the documented ValueError.
    if len(coord_shape) != 2:
        raise ValueError(
            "Coordinates must be a 2D tensor of shape (num_sites, dimensionality), "
            f"but got a tensor of shape {coord_shape}."
        )

    self._identifiers = list(identifiers)
    self._indices = list(range(len(identifiers)))
    self._ident_to_idx = {ident: idx for idx, ident in enumerate(identifiers)}

    if self.num_sites > 0 and coord_shape[1] != dimensionality:
        raise ValueError(
            f"Coordinates tensor has dimension {coord_shape[1]}, "
            f"but expected dimensionality is {dimensionality}."
        )

    logger.info(f"CustomizeLattice with {self.num_sites} sites created.")

    if precompute_neighbors is not None and precompute_neighbors > 0:
        self._build_neighbors(max_k=precompute_neighbors)
1427
+
1428
def _build_lattice(self, *args: Any, **kwargs: Any) -> None:
    """No-op hook.

    CustomizeLattice receives its identifiers and coordinates directly in
    ``__init__``, so there is nothing to generate here. Any arguments are
    accepted and ignored.
    """
    return None
1431
+
1432
def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None:
    """
    Compute neighbor shells, dispatching to a KDTree or distance-matrix path.

    Two strategies are available:

    1. KDTree (``use_kdtree=True``): uses scipy's KDTree on NumPy copies of
       the coordinates. It is not differentiable, and for small lattices the
       KDTree helper itself may fall back to the distance-matrix path.
    2. Distance matrix (``use_kdtree=False``, the default): builds all pairwise
       distances with backend ops. This is O(N^2) in memory but stays
       differentiable and backend-agnostic.

    Lattices with fewer than two sites have no neighbors, so the method
    returns immediately.

    :param max_k: Maximum number of neighbor shells to compute.
    :type max_k: int
    :param kwargs: Optional extras:
        - use_kdtree (bool): Select the KDTree path. Defaults to False.
        - tol (float): Tolerance for grouping distances into shells.
          Defaults to 1e-6.
    """
    if self.num_sites < 2:
        return

    tol = kwargs.get("tol", 1e-6)

    # The differentiable path is the default; KDTree is opt-in.
    if kwargs.get("use_kdtree", False):
        logger.info(
            f"Using KDTree method for {self.num_sites} sites up to k={max_k}"
        )
        self._build_neighbors_kdtree(max_k, tol)
        return

    logger.info(
        f"Using differentiable distance matrix method for {self.num_sites} sites up to k={max_k}"
    )
    self._build_neighbors_by_distance_matrix(max_k, tol)
1468
+
1469
+ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None:
1470
+ """
1471
+ Build neighbors using KDTree for optimal performance.
1472
+
1473
+ This method provides O(N log N) performance for neighbor finding but breaks
1474
+ differentiability due to scipy dependency. Use this method when:
1475
+ - Performance is critical
1476
+ - Differentiability is not required
1477
+ - Large lattices (N > 1000)
1478
+
1479
+ Note: This method uses numpy arrays directly and may not be compatible
1480
+ with all backend types (JAX, TensorFlow, etc.).
1481
+ """
1482
+
1483
+ # For small lattices or cases with potential duplicate coordinates,
1484
+ # fall back to distance matrix method for robustness
1485
+ if self.num_sites < 200:
1486
+ logger.info(
1487
+ "Small lattice detected, falling back to distance matrix method for robustness"
1488
+ )
1489
+ self._build_neighbors_by_distance_matrix(max_k, tol)
1490
+ return
1491
+
1492
+ # Convert coordinates to numpy for KDTree
1493
+ coords_np = backend.numpy(self._coordinates)
1494
+
1495
+ # Build KDTree
1496
+ logger.info("Building KDTree...")
1497
+ tree = KDTree(coords_np)
1498
+ # Find all distances for shell identification - use comprehensive sampling
1499
+ logger.info("Identifying distance shells...")
1500
+ distances_for_shells: List[float] = []
1501
+
1502
+ # For robust shell identification, query all pairwise distances for smaller lattices
1503
+ # or use dense sampling for larger ones
1504
+ if self.num_sites <= 100:
1505
+ # For small lattices, compute all pairwise distances for accuracy
1506
+ for i in range(self.num_sites):
1507
+ query_k = min(self.num_sites - 1, max_k * 20)
1508
+ if query_k > 0:
1509
+ dists, _ = tree.query(
1510
+ coords_np[i], k=query_k + 1
1511
+ ) # +1 to exclude self
1512
+ if isinstance(dists, np.ndarray):
1513
+ distances_for_shells.extend(dists[1:]) # Skip distance to self
1514
+ else:
1515
+ distances_for_shells.append(dists) # Single distance
1516
+ else:
1517
+ # For larger lattices, use adaptive sampling but ensure we capture all shells
1518
+ sample_size = min(1000, self.num_sites // 2) # More conservative sampling
1519
+ for i in range(0, self.num_sites, max(1, self.num_sites // sample_size)):
1520
+ query_k = min(max_k * 20 + 50, self.num_sites - 1)
1521
+ if query_k > 0:
1522
+ dists, _ = tree.query(
1523
+ coords_np[i], k=query_k + 1
1524
+ ) # +1 to exclude self
1525
+ if isinstance(dists, np.ndarray):
1526
+ distances_for_shells.extend(dists[1:]) # Skip distance to self
1527
+ else:
1528
+ distances_for_shells.append(dists) # Single distance
1529
+
1530
+ # Filter out zero distances (duplicate coordinates) before shell identification
1531
+ ZERO_THRESHOLD = 1e-12
1532
+ distances_for_shells = [d for d in distances_for_shells if d > ZERO_THRESHOLD]
1533
+
1534
+ if not distances_for_shells:
1535
+ logger.warning("No valid distances found for shell identification")
1536
+ self._neighbor_maps = {}
1537
+ return
1538
+
1539
+ # Use the same shell identification logic as distance matrix method
1540
+ distances_for_shells_sq = [d * d for d in distances_for_shells]
1541
+ dist_shells_sq = self._identify_distance_shells(
1542
+ distances_for_shells_sq, max_k, tol
1543
+ )
1544
+ dist_shells = [np.sqrt(d_sq) for d_sq in dist_shells_sq]
1545
+
1546
+ logger.info(f"Found {len(dist_shells)} distance shells: {dist_shells[:5]}...")
1547
+
1548
+ # Initialize neighbor maps
1549
+ self._neighbor_maps = {k: {} for k in range(1, len(dist_shells) + 1)}
1550
+
1551
+ # Build neighbor lists for each site
1552
+ for i in range(self.num_sites):
1553
+ # Query enough neighbors to capture all shells
1554
+ query_k = min(max_k * 20 + 50, self.num_sites - 1)
1555
+ if query_k > 0:
1556
+ distances, indices = tree.query(
1557
+ coords_np[i], k=query_k + 1
1558
+ ) # +1 for self
1559
+
1560
+ # Skip the first entry (distance to self)
1561
+ # Handle both single value and array cases
1562
+ if isinstance(distances, np.ndarray) and len(distances) > 1:
1563
+ distances_slice = distances[1:]
1564
+ indices_slice = (
1565
+ indices[1:]
1566
+ if isinstance(indices, np.ndarray)
1567
+ else np.array([], dtype=int)
1568
+ )
1569
+ else:
1570
+ # Single value or empty case - no neighbors to process
1571
+ distances_slice = np.array([])
1572
+ indices_slice = np.array([], dtype=int)
1573
+
1574
+ # Filter out zero distances (duplicate coordinates)
1575
+ valid_pairs = [
1576
+ (d, idx)
1577
+ for d, idx in zip(distances_slice, indices_slice)
1578
+ if d > ZERO_THRESHOLD
1579
+ ]
1580
+
1581
+ # Assign neighbors to shells
1582
+ for shell_idx, shell_dist in enumerate(dist_shells):
1583
+ k = shell_idx + 1
1584
+ shell_neighbors = []
1585
+
1586
+ for dist, neighbor_idx in valid_pairs:
1587
+ if abs(dist - shell_dist) <= tol:
1588
+ shell_neighbors.append(int(neighbor_idx))
1589
+ elif dist > shell_dist + tol:
1590
+ break # Distances are sorted, no more matches
1591
+
1592
+ if shell_neighbors:
1593
+ self._neighbor_maps[k][i] = sorted(shell_neighbors)
1594
+
1595
+ # Set distance matrix to None - will compute on demand
1596
+ self._distance_matrix = None
1597
+
1598
+ def _reset_computations(self) -> None:
1599
+ """Resets all cached data that depends on the lattice structure."""
1600
+ self._neighbor_maps = {}
1601
+ self._distance_matrix = None
1602
+
1603
+ @classmethod
1604
+ def from_lattice(cls, lattice: "AbstractLattice") -> "CustomizeLattice":
1605
+ """Creates a CustomizeLattice instance from any existing lattice object.
1606
+
1607
+ This is useful for 'detaching' a procedurally generated lattice (like
1608
+ a SquareLattice) into a customizable one for further modifications,
1609
+ such as adding defects or extra sites.
1610
+
1611
+ :param lattice: An instance of any AbstractLattice subclass.
1612
+ :type lattice: AbstractLattice
1613
+ :return: A new CustomizeLattice instance with the same sites.
1614
+ :rtype: CustomizeLattice
1615
+ """
1616
+ all_sites_info = list(lattice.sites())
1617
+
1618
+ if not all_sites_info:
1619
+ return cls(
1620
+ dimensionality=lattice.dimensionality, identifiers=[], coordinates=[]
1621
+ )
1622
+
1623
+ # Unzip the list of tuples into separate lists of identifiers and coordinates
1624
+ _, identifiers, _ = zip(*all_sites_info)
1625
+
1626
+ # Detach-and-copy coordinates while remaining in tensor form to avoid
1627
+ # host roundtrips and device/dtype changes; this keeps CustomizeLattice
1628
+ # decoupled from the original graph but backend-friendly.
1629
+ # Some backends (e.g., NumPy) don't implement stop_gradient; fall back.
1630
+ try:
1631
+ coords_detached = backend.stop_gradient(lattice._coordinates)
1632
+ except NotImplementedError:
1633
+ coords_detached = lattice._coordinates
1634
+ coords_tensor = backend.copy(coords_detached)
1635
+
1636
+ return cls(
1637
+ dimensionality=lattice.dimensionality,
1638
+ identifiers=list(identifiers),
1639
+ coordinates=coords_tensor,
1640
+ )
1641
+
1642
+ def add_sites(
1643
+ self,
1644
+ identifiers: List[SiteIdentifier],
1645
+ coordinates: Any,
1646
+ ) -> None:
1647
+ """Adds new sites to the lattice.
1648
+
1649
+ This operation modifies the lattice in-place. After adding sites, any
1650
+ previously computed neighbor information is cleared and must be
1651
+ recalculated.
1652
+
1653
+ :param identifiers: A list of unique identifiers for the new sites.
1654
+ :type identifiers: List[SiteIdentifier]
1655
+ :param coordinates: The coordinates for the new sites. Can be a list of lists,
1656
+ a NumPy array, or a backend-compatible tensor (e.g., jax.numpy.ndarray).
1657
+ :type coordinates: Any
1658
+ """
1659
+ if not identifiers:
1660
+ return
1661
+
1662
+ new_coords_tensor = backend.convert_to_tensor(coordinates)
1663
+
1664
+ if len(identifiers) != backend.shape_tuple(new_coords_tensor)[0]:
1665
+ raise ValueError(
1666
+ "Identifiers and coordinates lists must have the same length."
1667
+ )
1668
+
1669
+ if backend.shape_tuple(new_coords_tensor)[1] != self.dimensionality:
1670
+ raise ValueError(
1671
+ f"New coordinate tensor has dimension {backend.shape_tuple(new_coords_tensor)[1]}, "
1672
+ f"but expected dimensionality is {self.dimensionality}."
1673
+ )
1674
+
1675
+ # Ensure that the new identifiers are unique and do not already exist.
1676
+ existing_ids = set(self._identifiers)
1677
+ new_ids = set(identifiers)
1678
+ if not new_ids.isdisjoint(existing_ids):
1679
+ raise ValueError(
1680
+ f"Duplicate identifiers found: {new_ids.intersection(existing_ids)}"
1681
+ )
1682
+
1683
+ self._coordinates = backend.concat(
1684
+ [self._coordinates, new_coords_tensor], axis=0
1685
+ )
1686
+ self._identifiers.extend(identifiers)
1687
+
1688
+ self._indices = list(range(len(self._identifiers)))
1689
+ self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)}
1690
+
1691
+ self._reset_computations()
1692
+ logger.info(
1693
+ f"{len(identifiers)} sites added. Lattice now has {self.num_sites} sites."
1694
+ )
1695
+
1696
+ def remove_sites(self, identifiers: List[SiteIdentifier]) -> None:
1697
+ """Removes specified sites from the lattice.
1698
+
1699
+ This operation modifies the lattice in-place. After removing sites,
1700
+ all site indices are re-calculated, and any previously computed
1701
+ neighbor information is cleared.
1702
+
1703
+ :param identifiers: A list of identifiers for the sites to be removed.
1704
+ :type identifiers: List[SiteIdentifier]
1705
+ """
1706
+ if not identifiers:
1707
+ return
1708
+
1709
+ ids_to_remove = set(identifiers)
1710
+ current_ids = set(self._identifiers)
1711
+ if not ids_to_remove.issubset(current_ids):
1712
+ raise ValueError(
1713
+ f"Non-existent identifiers provided for removal: {ids_to_remove - current_ids}"
1714
+ )
1715
+
1716
+ # Find the indices of the sites that we want to keep.
1717
+ indices_to_keep = [
1718
+ idx
1719
+ for idx, ident in enumerate(self._identifiers)
1720
+ if ident not in ids_to_remove
1721
+ ]
1722
+
1723
+ new_identifiers = [self._identifiers[i] for i in indices_to_keep]
1724
+
1725
+ self._coordinates = backend.gather1d(
1726
+ self._coordinates,
1727
+ backend.cast(backend.convert_to_tensor(indices_to_keep), "int32"),
1728
+ )
1729
+
1730
+ self._identifiers = new_identifiers
1731
+
1732
+ self._indices = list(range(len(self._identifiers)))
1733
+ self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)}
1734
+
1735
+ self._reset_computations()
1736
+ logger.info(
1737
+ f"{len(ids_to_remove)} sites removed. Lattice now has {self.num_sites} sites."
1738
+ )
1739
+
1740
+
1741
+ def get_compatible_layers(bonds: List[Tuple[int, int]]) -> List[List[Tuple[int, int]]]:
1742
+ """
1743
+ Partitions a list of pairs (bonds) into compatible layers for parallel
1744
+ gate application using a greedy edge-coloring algorithm.
1745
+
1746
+ This function takes a list of pairs, representing connections like
1747
+ nearest-neighbor (NN) or next-nearest-neighbor (NNN) bonds, and
1748
+ partitions them into the minimum number of sets ("layers") where no two
1749
+ pairs in a set share an index. This is a general utility for scheduling
1750
+ non-overlapping operations.
1751
+
1752
+ :Example:
1753
+
1754
+ >>> from tensorcircuit.templates.lattice import SquareLattice
1755
+ >>> sq_lattice = SquareLattice(size=(2, 2), pbc=False)
1756
+ >>> nn_bonds = sq_lattice.get_neighbor_pairs(k=1, unique=True)
1757
+
1758
+ >>> gate_layers = get_compatible_layers(nn_bonds)
1759
+ >>> print(gate_layers)
1760
+ [[[0, 1], [2, 3]], [[0, 2], [1, 3]]]
1761
+
1762
+ :param bonds: A list of tuples, where each tuple represents a bond (i, j)
1763
+ of site indices to be scheduled.
1764
+ :type bonds: List[Tuple[int, int]]
1765
+ :return: A list of layers. Each layer is a list of tuples, where each
1766
+ tuple represents a bond. All bonds within a layer are non-overlapping.
1767
+ :rtype: List[List[Tuple[int, int]]]
1768
+ """
1769
+ uncolored_edges: Set[Tuple[int, int]] = {(min(bond), max(bond)) for bond in bonds}
1770
+
1771
+ layers: List[List[Tuple[int, int]]] = []
1772
+
1773
+ while uncolored_edges:
1774
+ current_layer: List[Tuple[int, int]] = []
1775
+ qubits_in_this_layer: Set[int] = set()
1776
+
1777
+ edges_to_process = sorted(list(uncolored_edges))
1778
+
1779
+ for edge in edges_to_process:
1780
+ i, j = edge
1781
+ if i not in qubits_in_this_layer and j not in qubits_in_this_layer:
1782
+ current_layer.append(edge)
1783
+ qubits_in_this_layer.add(i)
1784
+ qubits_in_this_layer.add(j)
1785
+
1786
+ uncolored_edges -= set(current_layer)
1787
+ layers.append(sorted(current_layer))
1788
+
1789
+ return layers