tensorcircuit-nightly 1.0.2.dev20250108__py3-none-any.whl → 1.4.0.dev20251103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tensorcircuit-nightly might be problematic. Click here for more details.
- tensorcircuit/__init__.py +18 -2
- tensorcircuit/about.py +46 -0
- tensorcircuit/abstractcircuit.py +4 -0
- tensorcircuit/analogcircuit.py +413 -0
- tensorcircuit/applications/layers.py +1 -1
- tensorcircuit/applications/van.py +1 -1
- tensorcircuit/backends/abstract_backend.py +320 -7
- tensorcircuit/backends/cupy_backend.py +3 -1
- tensorcircuit/backends/jax_backend.py +102 -4
- tensorcircuit/backends/jax_ops.py +110 -1
- tensorcircuit/backends/numpy_backend.py +49 -3
- tensorcircuit/backends/pytorch_backend.py +92 -3
- tensorcircuit/backends/tensorflow_backend.py +102 -3
- tensorcircuit/basecircuit.py +157 -98
- tensorcircuit/circuit.py +115 -57
- tensorcircuit/cloud/local.py +1 -1
- tensorcircuit/cloud/quafu_provider.py +1 -1
- tensorcircuit/cloud/tencent.py +1 -1
- tensorcircuit/compiler/simple_compiler.py +2 -2
- tensorcircuit/cons.py +142 -21
- tensorcircuit/densitymatrix.py +43 -14
- tensorcircuit/experimental.py +387 -129
- tensorcircuit/fgs.py +282 -81
- tensorcircuit/gates.py +66 -22
- tensorcircuit/interfaces/__init__.py +1 -3
- tensorcircuit/interfaces/jax.py +189 -0
- tensorcircuit/keras.py +3 -3
- tensorcircuit/mpscircuit.py +154 -65
- tensorcircuit/quantum.py +868 -152
- tensorcircuit/quditcircuit.py +733 -0
- tensorcircuit/quditgates.py +618 -0
- tensorcircuit/results/counts.py +147 -20
- tensorcircuit/results/readout_mitigation.py +4 -1
- tensorcircuit/shadows.py +1 -1
- tensorcircuit/simplify.py +3 -1
- tensorcircuit/stabilizercircuit.py +479 -0
- tensorcircuit/templates/__init__.py +2 -0
- tensorcircuit/templates/blocks.py +2 -2
- tensorcircuit/templates/hamiltonians.py +174 -0
- tensorcircuit/templates/lattice.py +1789 -0
- tensorcircuit/timeevol.py +896 -0
- tensorcircuit/translation.py +10 -3
- tensorcircuit/utils.py +7 -0
- {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +73 -23
- tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
- {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +1 -1
- {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
- tensorcircuit_nightly-1.0.2.dev20250108.dist-info/RECORD +0 -115
- tests/__init__.py +0 -0
- tests/conftest.py +0 -67
- tests/test_backends.py +0 -1031
- tests/test_calibrating.py +0 -149
- tests/test_channels.py +0 -365
- tests/test_circuit.py +0 -1699
- tests/test_cloud.py +0 -219
- tests/test_compiler.py +0 -147
- tests/test_dmcircuit.py +0 -555
- tests/test_ensemble.py +0 -72
- tests/test_fgs.py +0 -310
- tests/test_gates.py +0 -156
- tests/test_interfaces.py +0 -429
- tests/test_keras.py +0 -160
- tests/test_miscs.py +0 -277
- tests/test_mpscircuit.py +0 -341
- tests/test_noisemodel.py +0 -156
- tests/test_qaoa.py +0 -86
- tests/test_qem.py +0 -152
- tests/test_quantum.py +0 -526
- tests/test_quantum_attr.py +0 -42
- tests/test_results.py +0 -347
- tests/test_shadows.py +0 -160
- tests/test_simplify.py +0 -46
- tests/test_templates.py +0 -218
- tests/test_torchnn.py +0 -99
- tests/test_van.py +0 -102
- {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info/licenses}/LICENSE +0 -0
|
@@ -0,0 +1,1789 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
"""
|
|
3
|
+
The lattice module for defining and manipulating lattice geometries.
|
|
4
|
+
"""
|
|
5
|
+
import logging
|
|
6
|
+
import abc
|
|
7
|
+
from typing import (
|
|
8
|
+
Any,
|
|
9
|
+
Dict,
|
|
10
|
+
Hashable,
|
|
11
|
+
Iterator,
|
|
12
|
+
List,
|
|
13
|
+
Optional,
|
|
14
|
+
Tuple,
|
|
15
|
+
Union,
|
|
16
|
+
TYPE_CHECKING,
|
|
17
|
+
cast,
|
|
18
|
+
Set,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
import itertools
|
|
22
|
+
import math
|
|
23
|
+
import numpy as np
|
|
24
|
+
from scipy.spatial import KDTree
|
|
25
|
+
|
|
26
|
+
from .. import backend
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# This block resolves a name resolution issue for the static type checker (mypy).
|
|
30
|
+
# GOAL:
|
|
31
|
+
# Keep `matplotlib` as an optional dependency, so it is only imported
|
|
32
|
+
# inside the `show()` method, not at the module level.
|
|
33
|
+
# PROBLEM:
|
|
34
|
+
# The type hint for the `ax` parameter in `show()`'s signature
|
|
35
|
+
# (`ax: Optional["matplotlib.axes.Axes"]`) needs to know what `matplotlib` is.
|
|
36
|
+
# Without this block, mypy would raise a "Name 'matplotlib' is not defined" error.
|
|
37
|
+
# SOLUTION:
|
|
38
|
+
# The `if TYPE_CHECKING:` block is ignored at runtime but processed by mypy.
|
|
39
|
+
# This makes the name `matplotlib` available to the type checker without
|
|
40
|
+
# creating a hard dependency for the user.
|
|
41
|
+
if TYPE_CHECKING:
|
|
42
|
+
import matplotlib.axes
|
|
43
|
+
from mpl_toolkits.mplot3d import Axes3D
|
|
44
|
+
|
|
45
|
+
# Type aliases shared by the lattice API.
# ``Tensor`` is ``Any`` because site coordinates are held as tensors of
# whichever backend is active (they are created via ``backend`` calls).
Tensor = Any
Coordinates = Tensor  # spatial coordinates of one or more sites

SiteIndex = int  # integer position of a site, 0 .. N-1
SiteIdentifier = Hashable  # unique, hashable label of a site

# Site index -> indices of its neighbors of one particular order k.
NeighborMap = Dict[SiteIndex, List[SiteIndex]]

# Module-level logger used by all lattice classes below.
logger = logging.getLogger(__name__)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class AbstractLattice(abc.ABC):
    """Abstract base class for describing lattice systems.

    This class defines the common interface for all lattice structures,
    providing access to fundamental properties like site information
    (count, coordinates, identifiers) and neighbor relationships.
    Subclasses are responsible for implementing the specific logic for
    generating the lattice points and calculating neighbor connections.

    :param dimensionality: The spatial dimension of the lattice (e.g., 1, 2, 3).
    :type dimensionality: int
    """

    def __init__(self, dimensionality: int):
        """Initializes the base lattice class with empty site data and caches."""
        self._dimensionality = dimensionality

        # Core data structures for storing site information; subclasses are
        # expected to populate them in ``_build_lattice``.
        # Integer indices [0, 1, ..., N-1].
        self._indices: List[SiteIndex] = []
        # Unique, hashable site identifiers, one per index.
        self._identifiers: List[SiteIdentifier] = []
        # Always initialize to an empty coordinate tensor with correct
        # dimensionality so that type checkers know this is indexable and
        # not Optional.
        self._coordinates: Coordinates = backend.zeros((0, dimensionality))

        # Maps identifiers to indices for efficient reverse lookups.
        self._ident_to_idx: Dict[SiteIdentifier, SiteIndex] = {}

        # Cached properties, computed on demand.
        # Neighbor order k -> {site index -> list of k-th neighbor indices}.
        self._neighbor_maps: Dict[int, NeighborMap] = {}
        # Full N x N distance matrix; None until first computed.
        self._distance_matrix: Optional[Coordinates] = None

    @property
    def num_sites(self) -> int:
        """Returns the total number of sites (N) in the lattice."""
        return len(self._indices)

    @property
    def dimensionality(self) -> int:
        """Returns the spatial dimension of the lattice."""
        return self._dimensionality

    def __len__(self) -> int:
        """Returns the total number of sites, enabling `len(lattice)`."""
        return self.num_sites

    # --- Public API for Accessing Lattice Information ---
    @property
    def distance_matrix(self) -> Coordinates:
        """Returns the full N x N distance matrix.

        The matrix is computed by ``_compute_distance_matrix`` on first access
        (unless a distance-based neighbor search already cached it) and is
        reused on subsequent calls. This computation can be expensive for
        large lattices.
        """
        if self._distance_matrix is None:
            self._distance_matrix = self._compute_distance_matrix()
        return self._distance_matrix

    def _validate_index(self, index: SiteIndex) -> None:
        """A private helper to check if a site index is within the valid range.

        :raises IndexError: If ``index`` is not in ``[0, num_sites)``.
        """
        if not (0 <= index < self.num_sites):
            raise IndexError(
                f"Site index {index} out of range (0-{self.num_sites - 1})"
            )

    def get_coordinates(self, index: SiteIndex) -> Coordinates:
        """Gets the spatial coordinates of a site by its integer index.

        :param index: The integer index of the site.
        :type index: SiteIndex
        :raises IndexError: If the site index is out of range.
        :return: The spatial coordinates (a row of the backend coordinate tensor).
        :rtype: Coordinates
        """
        self._validate_index(index)
        return self._coordinates[index]

    def get_identifier(self, index: SiteIndex) -> SiteIdentifier:
        """Gets the abstract identifier of a site by its integer index.

        :param index: The integer index of the site.
        :type index: SiteIndex
        :raises IndexError: If the site index is out of range.
        :return: The unique, hashable identifier of the site.
        :rtype: SiteIdentifier
        """
        self._validate_index(index)
        return self._identifiers[index]

    def get_index(self, identifier: SiteIdentifier) -> SiteIndex:
        """Gets the integer index of a site by its unique identifier.

        :param identifier: The unique identifier of the site.
        :type identifier: SiteIdentifier
        :raises ValueError: If the identifier is not found in the lattice.
        :return: The corresponding integer index of the site.
        :rtype: SiteIndex
        """
        try:
            return self._ident_to_idx[identifier]
        except KeyError as e:
            raise ValueError(
                f"Identifier {identifier} not found in the lattice."
            ) from e

    def get_site_info(
        self, index_or_identifier: Union[SiteIndex, SiteIdentifier]
    ) -> Tuple[SiteIndex, SiteIdentifier, Coordinates]:
        """Gets all information for a single site.

        This method provides a convenient way to retrieve all relevant data for a
        site (its index, identifier, and coordinates) by using either its
        integer index or its unique identifier. Python ints and NumPy integer
        scalars are treated as indices; anything else is looked up as an
        identifier.

        :param index_or_identifier: The integer
            index or the unique identifier of the site to look up.
        :type index_or_identifier: Union[SiteIndex, SiteIdentifier]
        :raises IndexError: If the given index is out of bounds.
        :raises ValueError: If the given identifier is not found in the lattice.
        :return: A tuple containing:
            - The site's integer index.
            - The site's unique identifier.
            - The site's coordinates (backend tensor row).
        :rtype: Tuple[SiteIndex, SiteIdentifier, Coordinates]
        """
        # Accept NumPy integer scalars as indices too: index arrays coming out
        # of NumPy/backends are not ``int`` instances.
        if isinstance(index_or_identifier, (int, np.integer)):
            idx = int(index_or_identifier)
            self._validate_index(idx)
            return idx, self._identifiers[idx], self._coordinates[idx]
        ident = index_or_identifier
        idx = self.get_index(ident)
        return idx, ident, self._coordinates[idx]

    def sites(self) -> Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]]:
        """Returns an iterator over all sites in the lattice.

        This provides a convenient way to loop through all sites, for example:
        `for idx, ident, coords in my_lattice.sites(): ...`

        :return: An iterator where each item is a tuple containing the site's
            index, identifier, and coordinates.
        :rtype: Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]]
        """
        for i in range(self.num_sites):
            yield i, self._identifiers[i], self._coordinates[i]

    def _ensure_neighbors(self, k: int, what: str) -> bool:
        """Builds neighbor maps up to order ``k`` if missing; reports availability.

        :param k: The requested neighbor order.
        :param what: Label used in the log message (e.g. "Neighbors").
        :return: True if ``self._neighbor_maps`` holds an entry for ``k``
            afterwards, False otherwise (always False for ``k < 1``, in which
            case no build is triggered and existing caches are left intact).
        """
        if k < 1:
            return False
        if k not in self._neighbor_maps:
            logger.info(
                "%s for k=%s not pre-computed. Building now up to max_k=%s.",
                what,
                k,
                k,
            )
            self._build_neighbors(max_k=k)
        return k in self._neighbor_maps

    def get_neighbors(self, index: SiteIndex, k: int = 1) -> List[SiteIndex]:
        """Gets the list of k-th nearest neighbor indices for a given site.

        If neighbors of order ``k`` have not been computed yet, they are built
        on demand via ``_build_neighbors(max_k=k)``.

        :param index: The integer index of the center site.
        :type index: SiteIndex
        :param k: The order of the neighbors, where k=1 corresponds
            to nearest neighbors (NN), k=2 to next-nearest neighbors (NNN),
            and so on. Defaults to 1.
        :type k: int, optional
        :return: A list of integer indices for the neighboring sites.
            Returns an empty list if ``k < 1``, if no k-th neighbor shell
            exists after building, or if the site has no such neighbors.
            The index is not range-checked; an unknown index also yields
            an empty list.
        :rtype: List[SiteIndex]
        """
        if not self._ensure_neighbors(k, "Neighbors"):
            return []
        return self._neighbor_maps[k].get(index, [])

    def get_neighbor_pairs(
        self, k: int = 1, unique: bool = True
    ) -> List[Tuple[SiteIndex, SiteIndex]]:
        """Gets all pairs of k-th nearest neighbors, representing bonds.

        Neighbors of order ``k`` are built on demand if not yet computed.

        :param k: The order of the neighbors to consider.
            Defaults to 1.
        :type k: int, optional
        :param unique: If True, returns only one representation
            for each pair (i, j) such that i < j, avoiding duplicates
            like (j, i). If False, returns all directed pairs.
            Defaults to True.
        :type unique: bool, optional
        :return: A sorted list of tuples, where each
            tuple is a pair of neighbor indices. Empty if ``k < 1`` or no
            k-th neighbor shell exists.
        :rtype: List[Tuple[SiteIndex, SiteIndex]]
        """
        if not self._ensure_neighbors(k, "Neighbor pairs"):
            return []

        pairs = [
            (i, j)
            for i, neighbors in self._neighbor_maps[k].items()
            for j in neighbors
            if not unique or i < j
        ]
        return sorted(pairs)

    def get_all_pairs(self) -> List[Tuple[SiteIndex, SiteIndex]]:
        """
        Returns a list of all unique pairs of site indices (i, j) where i < j.

        This method provides all-to-all connectivity, useful for Hamiltonians
        where every site interacts with every other site.

        Note on Differentiability:
        This method provides a static list of index pairs and is not differentiable
        itself. However, it is designed to be used in combination with the
        ``distance_matrix`` property. By using the pairs from this
        method to index into the ``distance_matrix``, one can construct differentiable
        objective functions based on all-pair interactions, effectively bypassing the
        non-differentiable ``get_neighbor_pairs`` method for geometry optimization tasks.

        :return: A list of tuples, where each tuple is a unique pair of site
            indices, in lexicographic order. Empty for fewer than two sites.
        :rtype: List[Tuple[SiteIndex, SiteIndex]]
        """
        # combinations() over an ascending range already yields pairs in
        # lexicographic order, and yields nothing for fewer than two sites.
        return list(itertools.combinations(range(self.num_sites), 2))

    @abc.abstractmethod
    def _build_lattice(self, *args: Any, **kwargs: Any) -> None:
        """
        Abstract method for subclasses to generate the lattice data.

        A concrete implementation of this method in a subclass is responsible
        for populating the following internal attributes:
        - self._indices
        - self._identifiers
        - self._coordinates
        - self._ident_to_idx
        """
        pass

    @abc.abstractmethod
    def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None:
        """
        Abstract method for subclasses to calculate neighbor relationships.

        A concrete implementation of this method should calculate the neighbor
        relationships up to `max_k` and populate the `self._neighbor_maps`
        dictionary. The keys of the dictionary should be the neighbor order (k),
        and the values should be a dictionary mapping site indices to their
        list of k-th neighbors.
        """
        pass

    def _squared_distance_matrix(self) -> Coordinates:
        """Returns the N x N matrix of squared Euclidean distances between sites.

        Plain coordinate differences are used (no periodic images).
        """
        all_coords = self._coordinates
        # (N, 1, D) - (1, N, D) -> (N, N, D) -> (N, N)
        displacements = backend.expand_dims(all_coords, 1) - backend.expand_dims(
            all_coords, 0
        )
        return backend.sum(displacements**2, axis=-1)

    def _compute_distance_matrix(self) -> Coordinates:
        """
        Default generic distance matrix computation (no periodic images).

        Subclasses can override this when a specialized rule is required
        (e.g., applying Minimum Image Convention for PBC in TILattice).

        NOTE(review): the diagonal of the squared-distance matrix is exactly
        zero and the derivative of sqrt is unbounded at zero, so reverse-mode
        gradients through this matrix may contain NaN/inf. Confirm before
        relying on it for geometry optimization.
        """
        # An empty lattice needs explicit handling; a 1-site lattice falls
        # through and yields a 1 x 1 zero matrix.
        if self.num_sites == 0:
            return backend.zeros((0, 0))
        return backend.sqrt(self._squared_distance_matrix())

    def show(
        self,
        show_indices: bool = False,
        show_identifiers: bool = False,
        show_bonds_k: Optional[int] = None,
        ax: Optional["matplotlib.axes.Axes"] = None,
        bond_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Visualizes the lattice structure using Matplotlib.

        This method supports 1D, 2D, and 3D plotting. For 1D lattices, sites
        are plotted along the x-axis.

        :param show_indices: If True, displays the integer index
            next to each site. Defaults to False.
        :type show_indices: bool, optional
        :param show_identifiers: If True, displays the unique
            identifier next to each site (takes precedence over
            ``show_indices``). Defaults to False.
        :type show_identifiers: bool, optional
        :param show_bonds_k: Specifies which order of
            neighbor bonds to draw (e.g., 1 for NN, 2 for NNN). If None,
            no bonds are drawn. Unlike ``get_neighbor_pairs``, this method does
            not compute missing neighbors: if they have not been calculated,
            a warning is logged and no bonds are drawn. Defaults to None.
        :type show_bonds_k: Optional[int], optional
        :param ax: An existing Matplotlib Axes object to plot on.
            If None, a new Figure and Axes are created automatically and
            displayed with ``plt.show()``. Defaults to None.
        :type ax: Optional["matplotlib.axes.Axes"], optional
        :param bond_kwargs: A dictionary of keyword arguments for customizing bond appearance,
            passed directly to the Matplotlib plot function. Defaults to None.
        :type bond_kwargs: Optional[Dict[str, Any]], optional

        :param kwargs: Additional keyword arguments to be passed directly to the
            `matplotlib.pyplot.scatter` function for customizing site appearance.
            In 3D the marker size ``s`` is halved.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            logger.warning(
                "Matplotlib is required for visualization. "
                "Please install it using 'pip install matplotlib'."
            )
            return

        if self.num_sites == 0:
            logger.info("Lattice is empty, nothing to show.")
            return
        if self.dimensionality not in [1, 2, 3]:
            logger.warning(
                f"show() is not implemented for {self.dimensionality}D lattices."
            )
            return

        # Only call plt.show() for a figure created here; a caller-supplied
        # Axes belongs to the caller, who decides when to render it.
        fig_created_internally = ax is None
        if ax is None:
            if self.dimensionality == 3:
                fig = plt.figure(figsize=(8, 8))
                ax = fig.add_subplot(111, projection="3d")
            else:
                _, ax = plt.subplots(figsize=(8, 8))

        # Convert once to a host NumPy array; all drawing below uses it so
        # that backend tensors are never handed to Matplotlib directly.
        coords = np.asarray(backend.numpy(self._coordinates))

        # Prepare arguments for the scatter plot, allowing user overrides.
        scatter_args: Dict[str, Any] = {"s": 100, "zorder": 2}
        scatter_args.update(kwargs)
        if self.dimensionality == 1:
            ax.scatter(coords[:, 0], np.zeros_like(coords[:, 0]), **scatter_args)  # type: ignore
        elif self.dimensionality == 2:
            ax.scatter(coords[:, 0], coords[:, 1], **scatter_args)  # type: ignore
        else:
            scatter_args["s"] = scatter_args.get("s", 100) // 2
            ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], **scatter_args)  # type: ignore

        if show_indices or show_identifiers:
            # Offset labels by 2% of the largest coordinate extent so they do
            # not overlap the markers; fall back to a fixed offset when all
            # sites share one position (zero extent).
            extent = float(np.max(np.ptp(coords, axis=0)))
            offset = 0.02 * extent if extent > 0 else 0.1

            for i in range(self.num_sites):
                label = str(self._identifiers[i]) if show_identifiers else str(i)
                if self.dimensionality == 1:
                    ax.text(coords[i, 0], offset, label, fontsize=9, ha="center")
                elif self.dimensionality == 2:
                    ax.text(
                        coords[i, 0] + offset,
                        coords[i, 1] + offset,
                        label,
                        fontsize=9,
                        zorder=3,
                    )
                else:
                    ax_3d = cast("Axes3D", ax)
                    ax_3d.text(
                        coords[i, 0],
                        coords[i, 1],
                        coords[i, 2] + offset,
                        label,
                        fontsize=9,
                        zorder=3,
                    )

        if show_bonds_k is not None:
            if show_bonds_k not in self._neighbor_maps:
                logger.warning(
                    f"Cannot draw bonds. k={show_bonds_k} neighbors have not been calculated."
                )
            else:
                try:
                    bonds = self.get_neighbor_pairs(k=show_bonds_k, unique=True)
                    plot_bond_kwargs: Dict[str, Any] = {
                        "color": "k",
                        "linestyle": "-",
                        "alpha": 0.6,
                        "zorder": 1,
                    }
                    if bond_kwargs:
                        plot_bond_kwargs.update(bond_kwargs)

                    for i, j in bonds:
                        p1, p2 = coords[i], coords[j]
                        if self.dimensionality == 1:
                            ax.plot([p1[0], p2[0]], [0, 0], **plot_bond_kwargs)  # type: ignore
                        elif self.dimensionality == 2:
                            ax.plot([p1[0], p2[0]], [p1[1], p2[1]], **plot_bond_kwargs)  # type: ignore
                        else:
                            cast("Axes3D", ax).plot(
                                [p1[0], p2[0]],
                                [p1[1], p2[1]],
                                [p1[2], p2[2]],
                                **plot_bond_kwargs,
                            )
                except ValueError as e:
                    logger.info(f"Could not draw bonds: {e}")

        ax.set_title(f"{self.__class__.__name__} ({self.num_sites} sites)")
        if self.dimensionality == 2:
            ax.set_aspect("equal", adjustable="box")
        ax.set_xlabel("x")
        if self.dimensionality > 1:
            ax.set_ylabel("y")
        if self.dimensionality > 2 and hasattr(ax, "set_zlabel"):
            ax.set_zlabel("z")
        ax.grid(True)

        if fig_created_internally:
            plt.show()

    def _identify_distance_shells(
        self,
        all_distances_sq: Union[Coordinates, List[float]],
        max_k: int,
        tol: float = 1e-6,
    ) -> List[float]:
        """Identifies unique distance shells from a list of squared distances.

        The squared distances are filtered to drop (near-)zero self-distances,
        sorted ascending and scanned; a new shell starts whenever the
        *distance* (square root) exceeds the previous shell's distance by more
        than ``tol``. Scanning stops once ``max_k`` shells have been found.

        :param all_distances_sq: A flat list or array
            of all squared distances between pairs of sites.
        :type all_distances_sq: Union[Coordinates, List[float]]
        :param max_k: The maximum number of neighbor shells to identify.
        :type max_k: int
        :param tol: The numerical tolerance (on distances, not squared
            distances) to consider two distances equal.
        :type tol: float
        :return: A sorted list of squared distances representing the shells,
            with at most ``max_k`` entries (empty when ``max_k < 1``). Entries
            are scalars taken from the sorted backend tensor.
        :rtype: List[float]
        """
        if max_k < 1:
            return []

        # A small threshold to filter out zero distances (site to itself).
        ZERO_THRESHOLD_SQ = 1e-12

        # Normalize list inputs to a backend tensor so that backend.sizen and
        # boolean masking below work uniformly.
        all_distances_sq = backend.convert_to_tensor(all_distances_sq)
        if backend.sizen(all_distances_sq) == 0:
            return []

        # Filter out self-distances and sort the remaining squared distances.
        sorted_dist = backend.sort(
            all_distances_sq[all_distances_sq > ZERO_THRESHOLD_SQ]
        )
        if backend.sizen(sorted_dist) == 0:
            return []

        dist_shells = [sorted_dist[0]]
        for d_sq in sorted_dist[1:]:
            if len(dist_shells) >= max_k:
                break
            if backend.sqrt(d_sq) - backend.sqrt(dist_shells[-1]) > tol:
                dist_shells.append(d_sq)

        return dist_shells

    def _build_neighbors_by_distance_matrix(
        self, max_k: int = 2, tol: float = 1e-6
    ) -> None:
        """A generic, distance-based neighbor finding method.

        This method calculates the full N x N (non-periodic) distance matrix to
        find neighbor shells. It is computationally expensive for large N
        (O(N^2)) and is best suited for non-periodic or custom-defined lattices.
        It replaces ``self._neighbor_maps`` with maps for the identified orders
        and caches the resulting distance matrix in ``self._distance_matrix``.
        Lattices with fewer than two sites are left untouched.

        :param max_k: The maximum number of neighbor shells to
            calculate. Defaults to 2.
        :type max_k: int, optional
        :param tol: The numerical tolerance for distance
            comparisons. Defaults to 1e-6.
        :type tol: float, optional
        """
        if self.num_sites < 2:
            return

        dist_matrix_sq = self._squared_distance_matrix()

        # Flatten the matrix to a list of all squared distances to identify shells.
        all_distances_sq = backend.reshape(dist_matrix_sq, [-1])
        dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol)

        self._neighbor_maps = self._build_neighbor_map_from_distances(
            dist_matrix_sq, dist_shells_sq, tol
        )
        self._distance_matrix = backend.sqrt(dist_matrix_sq)

    def _build_neighbor_map_from_distances(
        self,
        dist_matrix_sq: Coordinates,
        dist_shells_sq: List[float],
        tol: float = 1e-6,
    ) -> Dict[int, NeighborMap]:
        """
        Builds a neighbor map from a squared distance matrix and identified shells.

        Sites i != j whose squared distance lies within ``tol`` of the k-th
        shell's squared distance are k-th neighbors. Neighbor indices are
        stored as plain Python ints, sorted ascending; sites without
        neighbors of a given order have no entry in that order's map.

        :return: Mapping from order k (1-based, one per shell) to its NeighborMap.
        """
        neighbor_maps: Dict[int, NeighborMap] = {}
        for k, target_d_sq in enumerate(dist_shells_sq, start=1):
            # For each shell, find all pairs of sites (i, j) with that distance.
            is_close_matrix = backend.abs(dist_matrix_sq - target_d_sq) < tol
            rows, cols = backend.where(is_close_matrix)

            current_k_map: NeighborMap = {}
            for row, col in zip(backend.numpy(rows), backend.numpy(cols)):
                # Convert NumPy integer scalars to int so that callers (e.g.
                # get_site_info, which dispatches on ``int``) see real indices.
                i, j = int(row), int(col)
                if i != j:
                    current_k_map.setdefault(i, []).append(j)

            for neighbors in current_k_map.values():
                neighbors.sort()
            neighbor_maps[k] = current_k_map
        return neighbor_maps
|
|
618
|
+
|
|
619
|
+
|
|
620
|
+
class TILattice(AbstractLattice):
|
|
621
|
+
"""Describes a periodic lattice with translational invariance.
|
|
622
|
+
|
|
623
|
+
This class serves as a base for any lattice defined by a repeating unit
|
|
624
|
+
cell. The geometry is specified by lattice vectors, the coordinates of
|
|
625
|
+
basis sites within a unit cell, and the total size of the lattice in
|
|
626
|
+
terms of unit cells.
|
|
627
|
+
|
|
628
|
+
The site identifier for this class is a tuple in the format of
|
|
629
|
+
`(uc_coord_1, ..., uc_coord_d, basis_index)`, where `uc_coord` represents
|
|
630
|
+
the integer coordinate of the unit cell and `basis_index` is the index
|
|
631
|
+
of the site within that unit cell's basis.
|
|
632
|
+
|
|
633
|
+
:param dimensionality: The spatial dimension of the lattice.
|
|
634
|
+
:type dimensionality: int
|
|
635
|
+
:param lattice_vectors: The lattice vectors defining the unit
|
|
636
|
+
cell, given as row vectors. Shape: (dimensionality, dimensionality).
|
|
637
|
+
For example, in 2D: `np.array([[ax, ay], [bx, by]])`.
|
|
638
|
+
:type lattice_vectors: np.ndarray
|
|
639
|
+
:param basis_coords: The Cartesian coordinates of the basis sites
|
|
640
|
+
within the unit cell. Shape: (num_basis_sites, dimensionality).
|
|
641
|
+
For a simple Bravais lattice, this would be `np.array([[0, 0]])`.
|
|
642
|
+
:type basis_coords: np.ndarray
|
|
643
|
+
:param size: A tuple specifying the number of unit cells
|
|
644
|
+
to generate in each lattice vector direction (e.g., (Nx, Ny)).
|
|
645
|
+
:type size: Tuple[int, ...]
|
|
646
|
+
:param pbc: Specifies whether
|
|
647
|
+
periodic boundary conditions are applied. Can be a single boolean
|
|
648
|
+
for all dimensions or a tuple of booleans for each dimension
|
|
649
|
+
individually. Defaults to True.
|
|
650
|
+
:type pbc: Union[bool, Tuple[bool, ...]], optional
|
|
651
|
+
:param precompute_neighbors: If specified, pre-computes neighbor relationships
|
|
652
|
+
up to the given order `k` upon initialization. Defaults to None.
|
|
653
|
+
:type precompute_neighbors: Optional[int], optional
|
|
654
|
+
|
|
655
|
+
"""
|
|
656
|
+
|
|
657
|
+
def __init__(
    self,
    dimensionality: int,
    lattice_vectors: Coordinates,
    basis_coords: Coordinates,
    size: Tuple[int, ...],
    pbc: Union[bool, Tuple[bool, ...]] = True,
    precompute_neighbors: Optional[int] = None,
):
    """Initializes the Translationally Invariant Lattice.

    Converts the geometry to backend tensors and validates the shapes against
    ``dimensionality``. It normalizes ``pbc`` to one boolean per dimension,
    generates every site, and optionally pre-computes neighbor shells.

    :raises ValueError: If ``lattice_vectors`` does not have shape
        ``(dimensionality, dimensionality)``, if ``basis_coords`` is not a 2D
        array with ``dimensionality`` columns, or if ``size`` / a tuple
        ``pbc`` does not have exactly one entry per dimension.
    """
    super().__init__(dimensionality)

    self.lattice_vectors = backend.convert_to_tensor(lattice_vectors)
    self.basis_coords = backend.convert_to_tensor(basis_coords)

    if self.lattice_vectors.shape != (dimensionality, dimensionality):
        raise ValueError(
            f"Lattice vectors shape {self.lattice_vectors.shape} does not match "
            f"expected ({dimensionality}, {dimensionality})"
        )
    # Check the rank first: indexing ``shape[1]`` on e.g. a 1D input would
    # otherwise surface as an IndexError instead of a descriptive ValueError.
    if len(self.basis_coords.shape) != 2:
        raise ValueError(
            f"Basis coordinates must have shape (num_basis_sites, {dimensionality}), "
            f"got shape {self.basis_coords.shape}"
        )
    if self.basis_coords.shape[1] != dimensionality:
        raise ValueError(
            f"Basis coordinates dimension {self.basis_coords.shape[1]} does not "
            f"match lattice dimensionality {dimensionality}"
        )
    if len(size) != dimensionality:
        raise ValueError(
            f"Size tuple length {len(size)} does not match dimensionality {dimensionality}"
        )

    self.num_basis = self.basis_coords.shape[0]
    self.size = size
    # Normalize pbc to a per-dimension tuple of booleans.
    if isinstance(pbc, bool):
        self.pbc = tuple([pbc] * dimensionality)
    else:
        if len(pbc) != dimensionality:
            raise ValueError(
                f"PBC tuple length {len(pbc)} does not match dimensionality {dimensionality}"
            )
        self.pbc = tuple(pbc)

    self._build_lattice()
    if precompute_neighbors is not None and precompute_neighbors > 0:
        # Lazy %-formatting: the message is only built if INFO is enabled.
        logger.info("Pre-computing neighbors up to k=%s...", precompute_neighbors)
        self._build_neighbors(max_k=precompute_neighbors)
|
|
702
|
+
|
|
703
|
+
def _build_lattice(self) -> None:
    """
    Generates all site information for the periodic lattice in a vectorized manner.

    Populates ``self._coordinates`` (shape ``(num_cells * num_basis, dimensionality)``),
    ``self._indices``, ``self._identifiers`` and ``self._ident_to_idx``. Sites are
    ordered with the unit-cell coordinates in row-major ("ij") order and the basis
    index varying fastest. Row ``i`` of the coordinates therefore belongs to
    ``self._identifiers[i]``.
    """
    lattice_vecs = self.lattice_vectors
    basis = self.basis_coords
    # If the lattice vectors and basis coordinates disagree in dtype, promote
    # both to a common floating dtype. Casting straight to the basis dtype
    # would truncate non-integer lattice vectors whenever the basis is given
    # as integers (e.g. ``np.array([[0, 0]])``). Matching dtypes are left
    # untouched.
    lv_dt = str(backend.dtype(lattice_vecs))  # type: ignore
    bc_dt = str(backend.dtype(basis))  # type: ignore
    if lv_dt != bc_dt:
        common_dt = "float64" if "float64" in (lv_dt, bc_dt) else "float32"
        lattice_vecs = backend.cast(lattice_vecs, common_dt)
        basis = backend.cast(basis, common_dt)

    # Generate a grid of all integer unit cell coordinates.
    ranges = [backend.arange(s) for s in self.size]
    grid = backend.meshgrid(*ranges, indexing="ij")
    all_cell_coords = backend.reshape(
        backend.stack(grid, axis=-1), (-1, self.dimensionality)
    )
    all_cell_coords = backend.cast(all_cell_coords, lattice_vecs.dtype)

    # Real-space origin of every unit cell: integer coords times lattice vectors.
    cell_vectors = backend.tensordot(all_cell_coords, lattice_vecs, axes=[[1], [0]])
    cell_vectors = backend.cast(cell_vectors, basis.dtype)

    # Combine cell vectors with basis coordinates to get all site positions
    # via broadcasting: (num_cells, 1, D) + (1, num_basis, D) -> (num_cells, num_basis, D)
    all_coords = backend.expand_dims(cell_vectors, 1) + backend.expand_dims(basis, 0)
    self._coordinates = backend.reshape(all_coords, (-1, self.dimensionality))

    # Tuple identifiers in the same order as the coordinate rows,
    # e.g. identifier = (uc_x, uc_y, basis_idx).
    basis_range = range(self.num_basis)
    self._identifiers = [
        cell_coord_tuple + (basis_index,)
        for cell_coord_tuple in itertools.product(*(range(s) for s in self.size))
        for basis_index in basis_range
    ]
    self._indices = list(range(len(self._identifiers)))
    self._ident_to_idx = {
        identifier: index for index, identifier in enumerate(self._identifiers)
    }
|
|
746
|
+
|
|
747
|
+
def _get_distance_matrix_with_mic_vectorized(self) -> Coordinates:
    """
    Computes the full N x N distance matrix using a fully vectorized approach
    that correctly applies the Minimum Image Convention (MIC) for periodic
    boundary conditions.

    For every pair of sites, the displacement is compared against all
    ``3**dimensionality`` images obtained by shifting by -1, 0 or +1 copies
    of the whole system along each direction. Non-periodic directions are
    never shifted. The shortest image is kept.

    All site pairs are processed simultaneously rather than row-by-row. This
    suits automatic differentiation and JIT compilation (e.g. JAX,
    TensorFlow). The trade-off is memory: the intermediate image tensor has
    shape ``(N, N, 3**dimensionality, dimensionality)``, so memory scales as
    O(N^2). Avoid this for very large lattices (N > 10^4 sites).

    :return: Distance matrix with shape (N, N) where entry (i,j) is the
        minimum distance between sites i and j under periodic boundary conditions.
    :rtype: Coordinates
    """
    # Candidate image shifts along each lattice direction.
    shift_values = backend.convert_to_tensor([-1.0, 0.0, 1.0])

    # Run the whole computation in a single floating dtype. TensorFlow and
    # PyTorch do not promote mixed float32/float64 operands in ``tensordot``
    # or arithmetic. For example, float64 lattice vectors meet the float32
    # default of ``convert_to_tensor([-1.0, 0.0, 1.0])`` there. Use float64
    # when any participating tensor is float64, which matches what NumPy/JAX
    # promotion produced before. Otherwise use float32, which also covers
    # integer-valued geometries.
    participating_dtypes = {
        str(backend.dtype(t))  # type: ignore
        for t in (self.lattice_vectors, self._coordinates, shift_values)
    }
    target_dt = "float64" if "float64" in participating_dtypes else "float32"

    size_arr = backend.cast(backend.convert_to_tensor(self.size), target_dt)
    lattice_vecs = backend.cast(
        backend.convert_to_tensor(self.lattice_vectors), target_dt
    )
    # Row i spans the whole system along lattice direction i.
    system_vectors = lattice_vecs * backend.expand_dims(size_arr, axis=1)

    pbc_mask = backend.convert_to_tensor(self.pbc)

    # Generate all 3^d possible image shifts (-1, 0, 1) for all dimensions
    shift_options = [backend.cast(shift_values, target_dt)] * self.dimensionality
    shifts_grid = backend.meshgrid(*shift_options, indexing="ij")
    all_shifts = backend.reshape(
        backend.stack(shifts_grid, axis=-1), (-1, self.dimensionality)
    )

    # Only apply shifts to periodic dimensions
    masked_shifts = all_shifts * backend.cast(pbc_mask, all_shifts.dtype)

    # Calculate all translation vectors due to PBC
    translations_arr = backend.tensordot(
        masked_shifts, system_vectors, axes=[[1], [0]]
    )

    # Vectorized computation of all displacements between any two sites
    # Shape: (N, 1, D) - (1, N, D) -> (N, N, D)
    coords = backend.cast(self._coordinates, target_dt)
    displacements = backend.expand_dims(coords, 1) - backend.expand_dims(coords, 0)

    # Consider all periodic images for each displacement
    # Shape: (N, N, 1, D) - (1, 1, num_translations, D) -> (N, N, num_translations, D)
    image_displacements = backend.expand_dims(
        displacements, 2
    ) - backend.expand_dims(backend.expand_dims(translations_arr, 0), 0)

    # Sum of squares for distances
    image_d_sq = backend.sum(image_displacements**2, axis=3)

    # Find the minimum distance among all images (Minimum Image Convention)
    min_dist_sq = backend.min(image_d_sq, axis=2)

    safe_dist_matrix_sq = backend.where(min_dist_sq > 0, min_dist_sq, 0.0)
    return backend.sqrt(safe_dist_matrix_sq)
|
|
826
|
+
|
|
827
|
+
def _build_neighbors(self, max_k: int = 2, **kwargs: Any) -> None:
    """Calculates neighbor relationships for the periodic lattice.

    First computes the full distance matrix under the Minimum Image
    Convention (MIC), so periodic boundaries are respected. It caches that
    matrix on ``self._distance_matrix``. Next it identifies the distinct
    squared-distance shells (nearest, next-nearest, ...). Finally it builds
    the neighbor maps from those shells into ``self._neighbor_maps``. This
    applies to any periodic geometry described by a TILattice.

    :param max_k: The maximum order of neighbors to compute (e.g., k=1 for
        nearest neighbors, k=2 for next-nearest, etc.). Defaults to 2.
    :type max_k: int, optional
    :param kwargs: Additional keyword arguments. May include:
        - ``tol`` (float): The numerical tolerance used to determine if two
          distances are equal when identifying shells. Defaults to 1e-6.
    """
    tolerance = kwargs.get("tol", 1e-6)

    distances = self._get_distance_matrix_with_mic_vectorized()
    self._distance_matrix = distances

    # Shells are identified on squared distances.
    squared = distances**2
    flat_squared = backend.reshape(squared, [-1])
    shells_sq = self._identify_distance_shells(flat_squared, max_k, tolerance)

    self._neighbor_maps = self._build_neighbor_map_from_distances(
        squared, shells_sq, tolerance
    )
|
|
853
|
+
|
|
854
|
+
def _compute_distance_matrix(self) -> Coordinates:
    """Computes the distance matrix using the Minimum Image Convention.

    An empty lattice yields a ``(0, 0)`` backend zeros tensor. Otherwise the
    vectorized MIC computation is delegated to.
    """
    if self.num_sites != 0:
        return self._get_distance_matrix_with_mic_vectorized()
    return backend.zeros((0, 0))
|
|
859
|
+
|
|
860
|
+
|
|
861
|
+
class SquareLattice(TILattice):
    """A 2D square lattice.

    A translationally invariant Bravais lattice on a simple square grid,
    with one site per unit cell.

    :param size: A tuple (Nx, Ny) specifying the number of
        unit cells (sites) in the x and y directions.
    :type size: Tuple[int, int]
    :param lattice_constant: The distance between two adjacent
        sites. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Specifies periodic boundary conditions. Can be a single boolean
        for all dimensions or a tuple of booleans for each dimension
        individually. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: Union[float, Any] = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the SquareLattice."""
        # Build every entry as a backend tensor of one dtype before stacking:
        # TensorFlow refuses to stack a Python float with a backend tensor.
        a = backend.convert_to_tensor(lattice_constant)
        origin = backend.cast(backend.convert_to_tensor(0.0), backend.dtype(a))

        # Orthogonal primitive vectors (a, 0) and (0, a), one per row.
        x_row = backend.stack([a, origin])
        y_row = backend.stack([origin, a])
        lattice_vectors = backend.stack([x_row, y_row])

        # Bravais lattice: a single basis site at the cell origin.
        basis_coords = backend.stack([backend.stack([origin, origin])])

        super().__init__(
            dimensionality=2,
            lattice_vectors=lattice_vectors,
            basis_coords=basis_coords,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
|
|
912
|
+
|
|
913
|
+
|
|
914
|
+
class HoneycombLattice(TILattice):
    """A 2D honeycomb lattice.

    A composite lattice consisting of a two-site basis (sublattices A and B)
    on an underlying triangular Bravais lattice.

    :param size: A tuple (Nx, Ny) specifying the number of unit
        cells along the two lattice vector directions.
    :type size: Tuple[int, int]
    :param lattice_constant: The bond length, i.e., the distance
        between two nearest neighbor sites. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Specifies periodic
        boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional

    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: Union[float, Any] = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the HoneycombLattice."""
        bond = backend.convert_to_tensor(lattice_constant)
        origin = bond * 0.0

        # Primitive vectors of the underlying triangular Bravais lattice:
        # (3a/2, +sqrt(3)a/2) and (3a/2, -sqrt(3)a/2).
        half_rt3 = math.sqrt(3.0) / 2.0
        along = bond * 1.5
        up = backend.stack([along, bond * half_rt3])
        down = backend.stack([along, -bond * half_rt3])
        lattice_vectors = backend.stack([up, down])

        # Sublattice A at the origin and B one bond length along x.
        site_a = backend.stack([origin, origin])
        site_b = backend.stack([bond * 1.0, origin])
        basis_coords = backend.stack([site_a, site_b])

        super().__init__(
            dimensionality=2,
            lattice_vectors=lattice_vectors,
            basis_coords=basis_coords,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
|
|
970
|
+
|
|
971
|
+
|
|
972
|
+
class TriangularLattice(TILattice):
    """A 2D triangular lattice.

    This is a Bravais lattice where each site has 6 nearest neighbors.

    :param size: A tuple (Nx, Ny) specifying the number of
        unit cells along the two lattice vector directions.
    :type size: Tuple[int, int]
    :param lattice_constant: The bond length, i.e., the
        distance between two nearest neighbor sites. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Specifies periodic
        boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional

    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: Union[float, Any] = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the TriangularLattice."""
        dimensionality = 2
        a_t = backend.convert_to_tensor(lattice_constant)
        zero = a_t * 0.0

        # sqrt(3)/2 as a plain Python float, as HoneycombLattice does, so it
        # adopts a_t's dtype when multiplied. A separately converted backend
        # tensor would be float32 by default in TensorFlow and would clash
        # with a float64 lattice constant.
        rt3_over_2 = math.sqrt(3.0) / 2.0

        # Primitive vectors (a, 0) and (a/2, sqrt(3)a/2) at 60 degrees.
        lattice_vectors = backend.stack(
            [
                backend.stack([a_t * 1.0, zero]),
                backend.stack([a_t * 0.5, a_t * rt3_over_2]),
            ]
        )
        # A triangular lattice is a Bravais lattice with a single-site basis.
        basis_coords = backend.stack([backend.stack([zero, zero])])

        super().__init__(
            dimensionality=dimensionality,
            lattice_vectors=lattice_vectors,
            basis_coords=basis_coords,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
|
|
1028
|
+
|
|
1029
|
+
|
|
1030
|
+
class ChainLattice(TILattice):
    """A 1D chain (simple Bravais lattice).

    :param size: A tuple `(N,)` specifying the number of sites in the chain.
    :type size: Tuple[int]
    :param lattice_constant: The distance between two adjacent sites. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Specifies if periodic boundary conditions are applied. Defaults to True.
    :type pbc: bool, optional
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int],
        lattice_constant: Union[float, Any] = 1.0,
        pbc: bool = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the ChainLattice as a 1x1 lattice-vector, single-site basis lattice."""
        spacing = backend.convert_to_tensor(lattice_constant)

        # One lattice vector of length ``spacing`` along the only axis.
        lattice_vectors = backend.stack([backend.stack([spacing])])

        # Single basis site at the origin (zero with spacing's dtype).
        origin = spacing * 0.0
        basis_coords = backend.stack([backend.stack([origin])])

        super().__init__(
            dimensionality=1,
            lattice_vectors=lattice_vectors,
            basis_coords=basis_coords,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
|
|
1067
|
+
|
|
1068
|
+
|
|
1069
|
+
class DimerizedChainLattice(TILattice):
    """A 1D chain with an AB sublattice (dimerized chain).

    Each unit cell holds two sites, A and B. All bonds share the same length.

    :param size: A tuple `(N,)` specifying the number of **unit cells**.
        The total number of sites in the chain will be `2 * N`, as each
        unit cell contains two sites.
    :type size: Tuple[int]
    :param lattice_constant: The distance between two adjacent sites (bond length). Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Specifies if periodic boundary conditions are applied. Defaults to True.
    :type pbc: bool, optional
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int],
        lattice_constant: Union[float, Any] = 1.0,
        pbc: bool = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the DimerizedChainLattice."""
        bond = backend.convert_to_tensor(lattice_constant)

        # Two sites per cell, so the cell spans two bond lengths.
        cell_length = 2 * bond
        lattice_vectors = backend.stack([backend.stack([cell_length])])

        # Site A at the cell origin, site B one bond length further.
        origin = bond * 0.0
        site_a = backend.stack([origin])
        site_b = backend.stack([bond])
        basis_coords = backend.stack([site_a, site_b])

        super().__init__(
            dimensionality=1,
            lattice_vectors=lattice_vectors,
            basis_coords=basis_coords,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
|
|
1110
|
+
|
|
1111
|
+
|
|
1112
|
+
class RectangularLattice(TILattice):
    """A 2D rectangular lattice.

    Generalizes SquareLattice by allowing different lattice constants along
    x and y.

    :param size: A tuple (Nx, Ny) specifying the number of sites in x and y.
    :type size: Tuple[int, int]
    :param lattice_constants: The distance between adjacent sites
        in the x and y directions, e.g., (ax, ay). Defaults to (1.0, 1.0).
    :type lattice_constants: Tuple[float, float], optional
    :param pbc: Specifies periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constants: Union[Tuple[float, float], Any] = (1.0, 1.0),
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the RectangularLattice."""
        x_const, y_const = lattice_constants

        # The x constant fixes the dtype; y and the zero entry are cast to it
        # so all stacked entries agree.
        x_len = backend.convert_to_tensor(x_const)
        common_dt = backend.dtype(x_len)
        y_len = backend.cast(backend.convert_to_tensor(y_const), common_dt)
        origin = backend.cast(backend.convert_to_tensor(0.0), common_dt)

        # Orthogonal lattice vectors (ax, 0) and (0, ay).
        lattice_vectors = backend.stack(
            [backend.stack([x_len, origin]), backend.stack([origin, y_len])]
        )
        # Bravais lattice: a single basis site at the cell origin.
        basis_coords = backend.stack([backend.stack([origin, origin])])

        super().__init__(
            dimensionality=2,
            lattice_vectors=lattice_vectors,
            basis_coords=basis_coords,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
|
|
1158
|
+
|
|
1159
|
+
|
|
1160
|
+
class CheckerboardLattice(TILattice):
    """A 2D checkerboard lattice (a square lattice with an AB sublattice).

    The unit cell is a square rotated by 45 degrees and contains two sites.

    :param size: A tuple (Nx, Ny) specifying the number of unit cells. Total sites will be 2*Nx*Ny.
    :type size: Tuple[int, int]
    :param lattice_constant: The bond length between nearest neighbors. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Specifies periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: Union[float, Any] = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the CheckerboardLattice."""
        bond = backend.convert_to_tensor(lattice_constant)

        # 45-degree rotated square cell spanned by (a, a) and (a, -a).
        diag_plus = backend.stack([bond * 1.0, bond * 1.0])
        diag_minus = backend.stack([bond * 1.0, bond * -1.0])
        lattice_vectors = backend.stack([diag_plus, diag_minus])

        # Sublattice A at the origin, sublattice B one bond length along x.
        origin = bond * 0.0
        basis_coords = backend.stack(
            [
                backend.stack([origin, origin]),
                backend.stack([bond * 1.0, origin]),
            ]
        )

        super().__init__(
            dimensionality=2,
            lattice_vectors=lattice_vectors,
            basis_coords=basis_coords,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
|
|
1207
|
+
|
|
1208
|
+
|
|
1209
|
+
class KagomeLattice(TILattice):
    """A 2D Kagome lattice.

    This is a lattice with a three-site basis on a triangular Bravais lattice.

    :param size: A tuple (Nx, Ny) specifying the number of unit cells. Total sites will be 3*Nx*Ny.
    :type size: Tuple[int, int]
    :param lattice_constant: The bond length. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Specifies periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: Union[float, Any] = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the KagomeLattice."""
        dimensionality = 2
        a_t = backend.convert_to_tensor(lattice_constant)
        # sqrt(3) as a plain Python float. ``backend.sqrt(3.0)`` fails on the
        # PyTorch backend because ``torch.sqrt`` requires a Tensor. On
        # TensorFlow it yields a float32 tensor that clashes with float64
        # lattice constants. A Python scalar adopts a_t's dtype instead.
        rt3 = math.sqrt(3.0)

        # The Kagome lattice is based on a triangular Bravais lattice.
        lattice_vectors = backend.stack(
            [
                backend.stack([a_t * 2.0, a_t * 0.0]),
                backend.stack([a_t * 1.0, a_t * rt3]),
            ]
        )
        # It has a three-site basis, forming the corners of the triangles.
        zero = a_t * 0.0
        basis_coords = backend.stack(
            [
                backend.stack([zero, zero]),
                backend.stack([a_t * 1.0, zero]),
                backend.stack([a_t * 0.5, a_t * rt3 / 2.0]),
            ]
        )

        super().__init__(
            dimensionality=dimensionality,
            lattice_vectors=lattice_vectors,
            basis_coords=basis_coords,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
|
|
1260
|
+
|
|
1261
|
+
|
|
1262
|
+
class LiebLattice(TILattice):
    """A 2D Lieb lattice.

    A three-site basis on a square Bravais lattice. The sites sit at the
    corner and at the centers of the two edges leaving that corner.

    :param size: A tuple (Nx, Ny) specifying the number of unit cells. Total sites will be 3*Nx*Ny.
    :type size: Tuple[int, int]
    :param lattice_constant: The bond length. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Specifies periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: Union[float, Any] = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the LiebLattice."""
        bond = backend.convert_to_tensor(lattice_constant)
        # The square cell side spans two bonds (corner -> edge center -> corner).
        side = 2 * bond
        origin = bond * 0.0

        # Square Bravais lattice vectors (2a, 0) and (0, 2a).
        lattice_vectors = backend.stack(
            [backend.stack([side, origin]), backend.stack([origin, side])]
        )

        corner = backend.stack([origin, origin])
        x_edge_center = backend.stack([bond, origin])
        y_edge_center = backend.stack([origin, bond])
        basis_coords = backend.stack([corner, x_edge_center, y_edge_center])

        super().__init__(
            dimensionality=2,
            lattice_vectors=lattice_vectors,
            basis_coords=basis_coords,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
|
|
1313
|
+
|
|
1314
|
+
|
|
1315
|
+
class CubicLattice(TILattice):
    """A 3D cubic lattice.

    The 3D generalization of SquareLattice: a simple Bravais lattice with a
    single-site basis.

    :param size: A tuple (Nx, Ny, Nz) specifying the number of sites.
    :type size: Tuple[int, int, int]
    :param lattice_constant: The distance between adjacent sites. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Specifies periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool, bool]], optional
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int, int],
        lattice_constant: Union[float, Any] = 1.0,
        pbc: Union[bool, Tuple[bool, bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the CubicLattice."""
        a = backend.convert_to_tensor(lattice_constant)
        origin = a * 0.0

        # Row i is a * e_i: the constant on the diagonal, zero elsewhere.
        rows = [
            backend.stack([a if axis == row else origin for axis in range(3)])
            for row in range(3)
        ]
        lattice_vectors = backend.stack(rows)

        # Bravais lattice: a single basis site at the cell origin.
        basis_coords = backend.stack([backend.stack([origin, origin, origin])])

        super().__init__(
            dimensionality=3,
            lattice_vectors=lattice_vectors,
            basis_coords=basis_coords,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
|
|
1360
|
+
|
|
1361
|
+
|
|
1362
|
+
class CustomizeLattice(AbstractLattice):
    """A general lattice built from an explicit list of sites and coordinates.

    This class is suitable for creating lattices with arbitrary geometries,
    such as finite clusters, disordered systems, or any custom structure
    that does not have translational symmetry. The lattice is defined simply
    by providing lists of identifiers and coordinates for each site.

    :param dimensionality: The spatial dimension of the lattice.
    :type dimensionality: int
    :param identifiers: A list of unique, hashable
        identifiers for the sites. The length must match `coordinates`.
    :type identifiers: List[SiteIdentifier]
    :param coordinates: The site coordinates, one row per site, i.e. of shape
        ``(len(identifiers), dimensionality)``. May be a list of lists of
        floats, a NumPy array, or a backend tensor.
    :type coordinates: List[Union[List[float], Coordinates]]
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    :raises ValueError: If the lengths of `identifiers` and `coordinates` do
        not match, if `identifiers` contains duplicates, or if the coordinates
        are not a 2D array whose second axis equals `dimensionality`.
    """

    def __init__(
        self,
        dimensionality: int,
        identifiers: List[SiteIdentifier],
        coordinates: Any,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the CustomizeLattice.

        Converts `coordinates` to a backend tensor and validates it against
        `identifiers`. Builds the identifier/index lookup tables and, when
        `precompute_neighbors` is a positive int, computes neighbor shells
        up to that order.
        """
        super().__init__(dimensionality)

        self._coordinates = backend.convert_to_tensor(coordinates)
        if len(identifiers) == 0:
            # An empty input such as ``[]`` converts to a shape-(0,) tensor;
            # give it shape (0, dimensionality) so `add_sites` can concatenate
            # new rows onto it.
            self._coordinates = backend.reshape(
                self._coordinates, (0, self.dimensionality)
            )

        coords_shape = backend.shape_tuple(self._coordinates)
        if len(identifiers) != coords_shape[0]:
            raise ValueError(
                "The number of identifiers must match the number of coordinates. "
                f"Got {len(identifiers)} identifiers and "
                f"{coords_shape[0]} coordinates."
            )

        # Duplicates would silently collapse in `_ident_to_idx`, leaving an
        # identifier that maps to only one of several sites.
        duplicates = self._find_duplicates(identifiers)
        if duplicates:
            raise ValueError(f"Duplicate identifiers found: {duplicates}")

        self._identifiers = list(identifiers)
        self._indices = list(range(len(identifiers)))
        self._ident_to_idx = {ident: idx for idx, ident in enumerate(identifiers)}

        if self.num_sites > 0:
            self._check_coordinate_shape(coords_shape, "Coordinates tensor")

        logger.info(f"CustomizeLattice with {self.num_sites} sites created.")

        if precompute_neighbors is not None and precompute_neighbors > 0:
            self._build_neighbors(max_k=precompute_neighbors)

    @staticmethod
    def _find_duplicates(identifiers: List[SiteIdentifier]) -> set:
        """Return the set of identifiers that occur more than once."""
        seen: set = set()
        duplicates: set = set()
        for ident in identifiers:
            if ident in seen:
                duplicates.add(ident)
            else:
                seen.add(ident)
        return duplicates

    def _check_coordinate_shape(self, shape: Tuple[int, ...], label: str) -> None:
        """Raise ValueError unless `shape` is ``(n, self.dimensionality)``.

        :param shape: Shape tuple of a coordinate tensor.
        :param label: Name of the tensor used as the error-message prefix.
        :raises ValueError: If the tensor is not 2D or has the wrong width.
        """
        if len(shape) != 2:
            raise ValueError(
                f"{label} must be 2D with shape (num_sites, {self.dimensionality}), "
                f"got shape {tuple(shape)}."
            )
        if shape[1] != self.dimensionality:
            raise ValueError(
                f"{label} has dimension {shape[1]}, "
                f"but expected dimensionality is {self.dimensionality}."
            )

    def _build_lattice(self, *args: Any, **kwargs: Any) -> None:
        """For CustomizeLattice, lattice data is built during __init__."""
        pass

    def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None:
        """
        Calculates neighbor relationships using either KDTree or distance matrix methods.

        This method supports two modes:
        1. KDTree mode (use_kdtree=True): Fast, O(N log N) performance for large lattices
           but breaks differentiability due to scipy dependency
        2. Distance matrix mode (use_kdtree=False): Slower O(N^2) but fully differentiable
           and backend-agnostic

        Lattices with fewer than two sites have no pairs, and the method returns
        without computing anything.

        :param max_k: Maximum number of neighbor shells to compute
        :type max_k: int
        :param kwargs: Additional arguments including:
            - use_kdtree (bool): Whether to use KDTree optimization. Defaults to False.
            - tol (float): Distance tolerance for neighbor identification. Defaults to 1e-6.
        """
        tol = kwargs.get("tol", 1e-6)
        # Default to the differentiable distance-matrix method; KDTree is opt-in.
        use_kdtree = kwargs.get("use_kdtree", False)

        if self.num_sites < 2:
            return

        if use_kdtree:
            logger.info(
                f"Using KDTree method for {self.num_sites} sites up to k={max_k}"
            )
            self._build_neighbors_kdtree(max_k, tol)
        else:
            logger.info(
                f"Using differentiable distance matrix method for {self.num_sites} sites up to k={max_k}"
            )
            self._build_neighbors_by_distance_matrix(max_k, tol)

    @staticmethod
    def _query_kdtree(tree: Any, point: Any, query_k: int) -> Tuple[Any, Any]:
        """Query the `query_k` nearest points to `point`, dropping the first hit.

        The tree is queried for ``query_k + 1`` points. The first result is
        normally the query site itself at distance 0, so it is discarded.

        :return: ``(distances, indices)`` as 1D NumPy arrays sorted by
            distance; both are empty when ``query_k < 1``.
        """
        if query_k < 1:
            return np.array([]), np.array([], dtype=int)
        dists, idxs = tree.query(point, k=query_k + 1)
        # k >= 2 already yields arrays; atleast_1d guards the scalar form.
        return np.atleast_1d(dists)[1:], np.atleast_1d(idxs)[1:]

    def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None:
        """
        Build neighbors using KDTree for optimal performance.

        This method provides O(N log N) performance for neighbor finding but breaks
        differentiability due to scipy dependency. Use this method when:
        - Performance is critical
        - Differentiability is not required
        - Large lattices (N > 1000)

        Lattices with fewer than 200 sites are delegated to the distance matrix
        method. Otherwise:

        - Shell distances are estimated from a strided sample of sites.
        - Each site's neighbors are taken from its ``max_k * 20 + 50`` nearest
          points (capped at ``N - 1``).

        Neighbors beyond that query window are not found. Zero distances
        (duplicate coordinates) are ignored.

        Note: This method uses numpy arrays directly and may not be compatible
        with all backend types (JAX, TensorFlow, etc.).
        """
        # For small lattices or cases with potential duplicate coordinates,
        # fall back to distance matrix method for robustness
        if self.num_sites < 200:
            logger.info(
                "Small lattice detected, falling back to distance matrix method for robustness"
            )
            self._build_neighbors_by_distance_matrix(max_k, tol)
            return

        coords_np = backend.numpy(self._coordinates)

        logger.info("Building KDTree...")
        tree = KDTree(coords_np)

        # Distances at or below this are treated as duplicate coordinates.
        zero_threshold = 1e-12
        # Nearest-neighbor window per site, shared by shell identification and
        # neighbor assignment.
        query_k = min(max_k * 20 + 50, self.num_sites - 1)

        logger.info("Identifying distance shells...")
        # num_sites >= 200 here, so a strided sample of at least ~100 sites is
        # used to collect candidate shell distances.
        sample_size = min(1000, self.num_sites // 2)
        stride = max(1, self.num_sites // sample_size)
        distances_for_shells: List[float] = []
        for i in range(0, self.num_sites, stride):
            dists, _ = self._query_kdtree(tree, coords_np[i], query_k)
            distances_for_shells.extend(d for d in dists if d > zero_threshold)

        if not distances_for_shells:
            logger.warning("No valid distances found for shell identification")
            self._neighbor_maps = {}
            return

        # Reuse the shell-identification logic of the distance matrix method,
        # which works on squared distances.
        dist_shells_sq = self._identify_distance_shells(
            [d * d for d in distances_for_shells], max_k, tol
        )
        dist_shells = [np.sqrt(d_sq) for d_sq in dist_shells_sq]

        logger.info(f"Found {len(dist_shells)} distance shells: {dist_shells[:5]}...")

        self._neighbor_maps = {k: {} for k in range(1, len(dist_shells) + 1)}

        for i in range(self.num_sites):
            dists, idxs = self._query_kdtree(tree, coords_np[i], query_k)
            valid_pairs = [
                (d, idx) for d, idx in zip(dists, idxs) if d > zero_threshold
            ]

            for k, shell_dist in enumerate(dist_shells, start=1):
                shell_neighbors = []
                for dist, neighbor_idx in valid_pairs:
                    if abs(dist - shell_dist) <= tol:
                        shell_neighbors.append(int(neighbor_idx))
                    elif dist > shell_dist + tol:
                        break  # Distances are sorted, no more matches
                if shell_neighbors:
                    self._neighbor_maps[k][i] = sorted(shell_neighbors)

        # The distance matrix is not built here; it is left for on-demand use.
        self._distance_matrix = None

    def _reset_computations(self) -> None:
        """Resets all cached data that depends on the lattice structure."""
        self._neighbor_maps = {}
        self._distance_matrix = None

    @classmethod
    def from_lattice(cls, lattice: "AbstractLattice") -> "CustomizeLattice":
        """Creates a CustomizeLattice instance from any existing lattice object.

        This is useful for 'detaching' a procedurally generated lattice (like
        a SquareLattice) into a customizable one for further modifications,
        such as adding defects or extra sites.

        :param lattice: An instance of any AbstractLattice subclass.
        :type lattice: AbstractLattice
        :return: A new CustomizeLattice instance with the same sites.
        :rtype: CustomizeLattice
        """
        all_sites_info = list(lattice.sites())

        if not all_sites_info:
            return cls(
                dimensionality=lattice.dimensionality, identifiers=[], coordinates=[]
            )

        # Each entry unpacks into three fields; the middle one is the identifier.
        _, identifiers, _ = zip(*all_sites_info)

        # Detach-and-copy coordinates while remaining in tensor form to avoid
        # host roundtrips and device/dtype changes; this keeps CustomizeLattice
        # decoupled from the original graph but backend-friendly.
        # Some backends (e.g., NumPy) don't implement stop_gradient; fall back.
        try:
            coords_detached = backend.stop_gradient(lattice._coordinates)
        except NotImplementedError:
            coords_detached = lattice._coordinates
        coords_tensor = backend.copy(coords_detached)

        return cls(
            dimensionality=lattice.dimensionality,
            identifiers=list(identifiers),
            coordinates=coords_tensor,
        )

    def add_sites(
        self,
        identifiers: List[SiteIdentifier],
        coordinates: Any,
    ) -> None:
        """Adds new sites to the lattice.

        This operation modifies the lattice in-place. After adding sites, any
        previously computed neighbor information is cleared and must be
        recalculated.

        :param identifiers: A list of unique identifiers for the new sites.
        :type identifiers: List[SiteIdentifier]
        :param coordinates: The coordinates for the new sites. Can be a list of lists,
            a NumPy array, or a backend-compatible tensor (e.g., jax.numpy.ndarray).
        :type coordinates: Any
        :raises ValueError: If lengths mismatch, coordinates are not 2D with
            the lattice's dimensionality, or identifiers are duplicated
            (within `identifiers` or against existing sites).
        """
        if not identifiers:
            return

        new_coords_tensor = backend.convert_to_tensor(coordinates)
        new_shape = backend.shape_tuple(new_coords_tensor)

        if len(identifiers) != new_shape[0]:
            raise ValueError(
                "Identifiers and coordinates lists must have the same length."
            )

        self._check_coordinate_shape(new_shape, "New coordinate tensor")

        # Ensure that the new identifiers are unique and do not already exist.
        repeated = self._find_duplicates(identifiers)
        if repeated:
            raise ValueError(f"Duplicate identifiers found: {repeated}")
        existing_ids = set(self._identifiers)
        new_ids = set(identifiers)
        if not new_ids.isdisjoint(existing_ids):
            raise ValueError(
                f"Duplicate identifiers found: {new_ids.intersection(existing_ids)}"
            )

        self._coordinates = backend.concat(
            [self._coordinates, new_coords_tensor], axis=0
        )
        self._identifiers.extend(identifiers)

        self._indices = list(range(len(self._identifiers)))
        self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)}

        self._reset_computations()
        logger.info(
            f"{len(identifiers)} sites added. Lattice now has {self.num_sites} sites."
        )

    def remove_sites(self, identifiers: List[SiteIdentifier]) -> None:
        """Removes specified sites from the lattice.

        This operation modifies the lattice in-place. After removing sites,
        all site indices are re-calculated, and any previously computed
        neighbor information is cleared.

        :param identifiers: A list of identifiers for the sites to be removed.
        :type identifiers: List[SiteIdentifier]
        :raises ValueError: If any identifier is not present in the lattice.
        """
        if not identifiers:
            return

        ids_to_remove = set(identifiers)
        current_ids = set(self._identifiers)
        if not ids_to_remove.issubset(current_ids):
            raise ValueError(
                f"Non-existent identifiers provided for removal: {ids_to_remove - current_ids}"
            )

        # Find the indices of the sites that we want to keep (original order).
        indices_to_keep = [
            idx
            for idx, ident in enumerate(self._identifiers)
            if ident not in ids_to_remove
        ]

        new_identifiers = [self._identifiers[i] for i in indices_to_keep]

        self._coordinates = backend.gather1d(
            self._coordinates,
            backend.cast(backend.convert_to_tensor(indices_to_keep), "int32"),
        )

        self._identifiers = new_identifiers

        self._indices = list(range(len(self._identifiers)))
        self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)}

        self._reset_computations()
        logger.info(
            f"{len(ids_to_remove)} sites removed. Lattice now has {self.num_sites} sites."
        )
|
|
1739
|
+
|
|
1740
|
+
|
|
1741
|
+
def get_compatible_layers(bonds: List[Tuple[int, int]]) -> List[List[Tuple[int, int]]]:
    """
    Partitions a list of pairs (bonds) into compatible layers for parallel
    gate application using a greedy edge-coloring algorithm.

    This function takes a list of pairs, representing connections like
    nearest-neighbor (NN) or next-nearest-neighbor (NNN) bonds, and
    partitions them into sets ("layers") where no two pairs in a set share
    an index. This is a general utility for scheduling non-overlapping
    operations.

    Bonds are normalized to ``(min, max)`` tuples and deduplicated, so
    ``(i, j)`` and ``(j, i)`` are scheduled only once. Bonds are visited in
    sorted order, and each layer greedily takes every remaining bond that
    does not touch an index already used in that layer. This usually yields
    few layers, but the greedy strategy is not guaranteed to find the
    minimum number.

    :Example:

    >>> from tensorcircuit.templates.lattice import SquareLattice
    >>> sq_lattice = SquareLattice(size=(2, 2), pbc=False)
    >>> nn_bonds = sq_lattice.get_neighbor_pairs(k=1, unique=True)

    >>> gate_layers = get_compatible_layers(nn_bonds)
    >>> print(gate_layers)
    [[(0, 1), (2, 3)], [(0, 2), (1, 3)]]

    :param bonds: A list of tuples, where each tuple represents a bond (i, j)
        of site indices to be scheduled.
    :type bonds: List[Tuple[int, int]]
    :return: A list of layers. Each layer is a sorted list of ``(min, max)``
        tuples, and all bonds within a layer are non-overlapping. An empty
        input yields an empty list.
    :rtype: List[List[Tuple[int, int]]]
    """
    # Sort once: each pass keeps the leftover bonds as an ordered subsequence,
    # so later layers also see the bonds in sorted order.
    remaining: List[Tuple[int, int]] = sorted(
        {(min(bond), max(bond)) for bond in bonds}
    )

    layers: List[List[Tuple[int, int]]] = []

    while remaining:
        current_layer: List[Tuple[int, int]] = []
        qubits_in_this_layer: Set[int] = set()
        leftover: List[Tuple[int, int]] = []

        for edge in remaining:
            i, j = edge
            if i in qubits_in_this_layer or j in qubits_in_this_layer:
                leftover.append(edge)
            else:
                current_layer.append(edge)
                qubits_in_this_layer.add(i)
                qubits_in_this_layer.add(j)

        # current_layer is already sorted because `remaining` is sorted.
        layers.append(current_layer)
        remaining = leftover

    return layers
|