tensorcircuit-nightly 1.2.1.dev20250721__py3-none-any.whl → 1.2.1.dev20250722__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tensorcircuit-nightly might be problematic. Click here for more details.
- tensorcircuit/__init__.py +1 -1
- tensorcircuit/templates/__init__.py +1 -0
- tensorcircuit/templates/lattice.py +1448 -0
- {tensorcircuit_nightly-1.2.1.dev20250721.dist-info → tensorcircuit_nightly-1.2.1.dev20250722.dist-info}/METADATA +4 -1
- {tensorcircuit_nightly-1.2.1.dev20250721.dist-info → tensorcircuit_nightly-1.2.1.dev20250722.dist-info}/RECORD +9 -7
- tests/test_lattice.py +1666 -0
- {tensorcircuit_nightly-1.2.1.dev20250721.dist-info → tensorcircuit_nightly-1.2.1.dev20250722.dist-info}/WHEEL +0 -0
- {tensorcircuit_nightly-1.2.1.dev20250721.dist-info → tensorcircuit_nightly-1.2.1.dev20250722.dist-info}/licenses/LICENSE +0 -0
- {tensorcircuit_nightly-1.2.1.dev20250721.dist-info → tensorcircuit_nightly-1.2.1.dev20250722.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1448 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
"""
|
|
3
|
+
The lattice module for defining and manipulating lattice geometries.
|
|
4
|
+
"""
|
|
5
|
+
import logging
|
|
6
|
+
import abc
|
|
7
|
+
from typing import (
|
|
8
|
+
Any,
|
|
9
|
+
Dict,
|
|
10
|
+
Hashable,
|
|
11
|
+
Iterator,
|
|
12
|
+
List,
|
|
13
|
+
Optional,
|
|
14
|
+
Tuple,
|
|
15
|
+
Union,
|
|
16
|
+
TYPE_CHECKING,
|
|
17
|
+
cast,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
# Module-level logger used by the lattice classes below for info messages,
# warnings (e.g. unsupported plot dimensionality) and errors (missing matplotlib).
logger = logging.getLogger(__name__)
|
|
21
|
+
import numpy as np
|
|
22
|
+
|
|
23
|
+
from scipy.spatial import KDTree
|
|
24
|
+
from scipy.spatial.distance import pdist, squareform
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# This block resolves a name resolution issue for the static type checker (mypy).
|
|
28
|
+
# GOAL:
|
|
29
|
+
# Keep `matplotlib` as an optional dependency, so it is only imported
|
|
30
|
+
# inside the `show()` method, not at the module level.
|
|
31
|
+
# PROBLEM:
|
|
32
|
+
# The type hint for the `ax` parameter in `show()`'s signature
|
|
33
|
+
# (`ax: Optional["matplotlib.axes.Axes"]`) needs to know what `matplotlib` is.
|
|
34
|
+
# Without this block, mypy would raise a "Name 'matplotlib' is not defined" error.
|
|
35
|
+
# SOLUTION:
|
|
36
|
+
# The `if TYPE_CHECKING:` block is ignored at runtime but processed by mypy.
|
|
37
|
+
# This makes the name `matplotlib` available to the type checker without
|
|
38
|
+
# creating a hard dependency for the user.
|
|
39
|
+
if TYPE_CHECKING:
|
|
40
|
+
import matplotlib.axes
|
|
41
|
+
from mpl_toolkits.mplot3d import Axes3D
|
|
42
|
+
|
|
43
|
+
# --- Type aliases shared by the lattice classes ---
# Integer position of a site in the lattice's internal per-site lists.
SiteIndex = int
# Hashable, user-facing label of a site; TILattice uses tuples of the form
# (uc_coord_1, ..., uc_coord_d, basis_index).
SiteIdentifier = Hashable
# Generic NumPy array; used both for a site's coordinates and for N x N
# distance matrices.
Coordinates = np.ndarray[Any, Any]
# For a single neighbor order k: site index -> list of its k-th neighbor indices.
NeighborMap = Dict[SiteIndex, List[SiteIndex]]
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class AbstractLattice(abc.ABC):
    """Abstract base class for describing lattice systems.

    This class defines the common interface for all lattice structures,
    providing access to fundamental properties like site information
    (count, coordinates, identifiers) and neighbor relationships.
    Subclasses are responsible for implementing the specific logic for
    generating the lattice points and calculating neighbor connections.

    :param dimensionality: The spatial dimension of the lattice (e.g., 1, 2, 3).
    :type dimensionality: int
    """

    def __init__(self, dimensionality: int):
        """Initializes the base lattice class."""
        self._dimensionality = dimensionality

        # --- Internal Data Structures (to be populated by subclasses) ---
        self._indices: List[SiteIndex] = []
        self._identifiers: List[SiteIdentifier] = []
        self._coordinates: List[Coordinates] = []
        self._ident_to_idx: Dict[SiteIdentifier, SiteIndex] = {}
        # Neighbor order k -> {site index: list of k-th neighbor indices}.
        self._neighbor_maps: Dict[int, NeighborMap] = {}
        # Lazily computed N x N distance matrix (see `distance_matrix`).
        self._distance_matrix: Optional[Coordinates] = None

    @property
    def num_sites(self) -> int:
        """Returns the total number of sites (N) in the lattice."""
        return len(self._indices)

    @property
    def dimensionality(self) -> int:
        """Returns the spatial dimension of the lattice."""
        return self._dimensionality

    def __len__(self) -> int:
        """Returns the total number of sites, enabling `len(lattice)`."""
        return self.num_sites

    # --- Public API for Accessing Lattice Information ---
    @property
    def distance_matrix(self) -> Coordinates:
        """
        Returns the full N x N distance matrix.
        The matrix is computed on the first access and then cached for
        subsequent calls. This computation can be expensive for large lattices.
        """
        if self._distance_matrix is None:
            logger.info("Distance matrix not cached. Computing now...")
            self._distance_matrix = self._compute_distance_matrix()
        return self._distance_matrix

    def _validate_index(self, index: SiteIndex) -> None:
        """A private helper to check if a site index is within the valid range.

        :raises IndexError: If ``index`` is not in ``[0, num_sites)``.
        """
        if not (0 <= index < self.num_sites):
            raise IndexError(
                f"Site index {index} out of range (0-{self.num_sites - 1})"
            )

    def get_coordinates(self, index: SiteIndex) -> Coordinates:
        """Gets the spatial coordinates of a site by its integer index.

        :param index: The integer index of the site.
        :type index: SiteIndex
        :raises IndexError: If the site index is out of range.
        :return: The spatial coordinates as a NumPy array.
        :rtype: Coordinates
        """
        self._validate_index(index)
        return self._coordinates[index]

    def get_identifier(self, index: SiteIndex) -> SiteIdentifier:
        """Gets the abstract identifier of a site by its integer index.

        :param index: The integer index of the site.
        :type index: SiteIndex
        :raises IndexError: If the site index is out of range.
        :return: The unique, hashable identifier of the site.
        :rtype: SiteIdentifier
        """
        self._validate_index(index)
        return self._identifiers[index]

    def get_index(self, identifier: SiteIdentifier) -> SiteIndex:
        """Gets the integer index of a site by its unique identifier.

        :param identifier: The unique identifier of the site.
        :type identifier: SiteIdentifier
        :raises ValueError: If the identifier is not found in the lattice.
        :return: The corresponding integer index of the site.
        :rtype: SiteIndex
        """
        try:
            return self._ident_to_idx[identifier]
        except KeyError as e:
            raise ValueError(
                f"Identifier {identifier} not found in the lattice."
            ) from e

    def get_site_info(
        self, index_or_identifier: Union[SiteIndex, SiteIdentifier]
    ) -> Tuple[SiteIndex, SiteIdentifier, Coordinates]:
        """Gets all information for a single site.

        This method provides a convenient way to retrieve all relevant data for a
        site (its index, identifier, and coordinates) by using either its
        integer index or its unique identifier. Python ``int`` values and NumPy
        integer scalars (e.g. values produced by ``np.where``) are both
        interpreted as site indices; anything else is looked up as an identifier.

        :param index_or_identifier: The integer
            index or the unique identifier of the site to look up.
        :type index_or_identifier: Union[SiteIndex, SiteIdentifier]
        :raises IndexError: If the given index is out of bounds.
        :raises ValueError: If the given identifier is not found in the lattice.
        :return: A tuple containing:
            - The site's integer index.
            - The site's unique identifier.
            - The site's coordinates as a NumPy array.
        :rtype: Tuple[SiteIndex, SiteIdentifier, Coordinates]
        """
        if isinstance(index_or_identifier, (int, np.integer)):
            # NumPy integer scalars are not ``int`` instances; without this they
            # would be misrouted to the identifier lookup and fail.
            idx = int(index_or_identifier)
            self._validate_index(idx)
            return idx, self._identifiers[idx], self._coordinates[idx]

        ident = index_or_identifier
        idx = self.get_index(ident)
        return idx, ident, self._coordinates[idx]

    def sites(self) -> Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]]:
        """Returns an iterator over all sites in the lattice.

        This provides a convenient way to loop through all sites, for example:
        `for idx, ident, coords in my_lattice.sites(): ...`

        :return: An iterator where each item is a tuple containing the site's
            index, identifier, and coordinates.
        :rtype: Iterator[Tuple[SiteIndex, SiteIdentifier, Coordinates]]
        """
        for i in range(self.num_sites):
            yield i, self._identifiers[i], self._coordinates[i]

    def get_neighbors(self, index: SiteIndex, k: int = 1) -> List[SiteIndex]:
        """Gets the list of k-th nearest neighbor indices for a given site.

        If neighbors of order ``k`` have not been computed yet, they are built
        on demand via ``_build_neighbors(max_k=k)``.

        :param index: The integer index of the center site.
        :type index: SiteIndex
        :param k: The order of the neighbors, where k=1 corresponds
            to nearest neighbors (NN), k=2 to next-nearest neighbors (NNN),
            and so on. Defaults to 1.
        :type k: int, optional
        :return: A new list of integer indices for the neighboring sites.
            Returns an empty list if ``k < 1``, if the lattice has no
            neighbor shell of order ``k``, or if the site has no such neighbors.
        :rtype: List[SiteIndex]
        """
        if k < 1:
            # Neighbor orders start at 1. Building for k < 1 could only
            # discard cached maps, since the builders in this module
            # rebuild ``_neighbor_maps`` from scratch.
            return []

        if k not in self._neighbor_maps:
            logger.info(
                f"Neighbors for k={k} not pre-computed. Building now up to max_k={k}."
            )
            self._build_neighbors(max_k=k)

        neighbor_map = self._neighbor_maps.get(k)
        if neighbor_map is None:
            return []

        # Return a copy so callers cannot mutate the cached neighbor lists.
        return list(neighbor_map.get(index, []))

    def get_neighbor_pairs(
        self, k: int = 1, unique: bool = True
    ) -> List[Tuple[SiteIndex, SiteIndex]]:
        """Gets all pairs of k-th nearest neighbors, representing bonds.

        If neighbors of order ``k`` have not been computed yet, they are built
        on demand via ``_build_neighbors(max_k=k)``.

        :param k: The order of the neighbors to consider.
            Defaults to 1.
        :type k: int, optional
        :param unique: If True, returns only one representation
            for each pair (i, j) such that i < j, avoiding duplicates
            like (j, i). If False, returns all directed pairs.
            Defaults to True.
        :type unique: bool, optional
        :return: A sorted list of tuples, where each
            tuple is a pair of neighbor indices. Empty if ``k < 1`` or if no
            shell of order ``k`` exists.
        :rtype: List[Tuple[SiteIndex, SiteIndex]]
        """
        if k < 1:
            # See get_neighbors: avoid a pointless rebuild for invalid orders.
            return []

        if k not in self._neighbor_maps:
            logger.info(
                f"Neighbor pairs for k={k} not pre-computed. Building now up to max_k={k}."
            )
            self._build_neighbors(max_k=k)

        # After attempting to build, check again. If still not found, return empty.
        neighbor_map = self._neighbor_maps.get(k)
        if neighbor_map is None:
            return []

        pairs = [
            (i, j)
            for i, neighbors in neighbor_map.items()
            for j in neighbors
            if not unique or i < j
        ]
        # Sorting provides a deterministic output order.
        return sorted(pairs)

    # --- Abstract Methods for Subclass Implementation ---

    @abc.abstractmethod
    def _build_lattice(self, *args: Any, **kwargs: Any) -> None:
        """
        Abstract method for subclasses to generate the lattice data.

        A concrete implementation of this method in a subclass is responsible
        for populating the following internal attributes:
        - self._indices
        - self._identifiers
        - self._coordinates
        - self._ident_to_idx
        """
        pass

    @abc.abstractmethod
    def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None:
        """
        Abstract method for subclasses to calculate neighbor relationships.

        A concrete implementation of this method should calculate the neighbor
        relationships up to `max_k` and populate the `self._neighbor_maps`
        dictionary. The keys of the dictionary should be the neighbor order (k),
        and the values should be a dictionary mapping site indices to their
        list of k-th neighbors.
        """
        pass

    @abc.abstractmethod
    def _compute_distance_matrix(self) -> Coordinates:
        """
        Abstract method for subclasses to implement the actual matrix calculation.
        This method is called by the `distance_matrix` property when the matrix
        needs to be computed for the first time.
        """
        pass

    def show(
        self,
        show_indices: bool = False,
        show_identifiers: bool = False,
        show_bonds_k: Optional[int] = None,
        ax: Optional["matplotlib.axes.Axes"] = None,
        bond_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Visualizes the lattice structure using Matplotlib.

        This method supports 1D, 2D, and 3D plotting. For 1D lattices, sites
        are plotted along the x-axis. If matplotlib is not installed, an error
        is logged and nothing is drawn.

        :param show_indices: If True, displays the integer index
            next to each site. Defaults to False.
        :type show_indices: bool, optional
        :param show_identifiers: If True, displays the unique
            identifier next to each site (takes precedence over
            ``show_indices``). Defaults to False.
        :type show_identifiers: bool, optional
        :param show_bonds_k: Specifies which order of
            neighbor bonds to draw (e.g., 1 for NN, 2 for NNN). If None,
            no bonds are drawn. If the specified neighbors have not been
            calculated, a warning is logged and no bonds are drawn.
            Defaults to None.
        :type show_bonds_k: Optional[int], optional
        :param ax: An existing Matplotlib Axes object to plot on.
            If None, a new Figure and Axes are created automatically and
            ``plt.show()`` is called at the end. Defaults to None.
        :type ax: Optional["matplotlib.axes.Axes"], optional
        :param bond_kwargs: A dictionary of keyword arguments for customizing bond appearance,
            passed directly to the Matplotlib plot function. Defaults to None.
        :type bond_kwargs: Optional[Dict[str, Any]], optional

        :param kwargs: Additional keyword arguments to be passed directly to the
            `matplotlib.pyplot.scatter` function for customizing site appearance.
            The default marker size ``s`` is 100 (50 in 3D); an explicit ``s``
            is used unchanged.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            logger.error(
                "Matplotlib is required for visualization. "
                "Please install it using 'pip install matplotlib'."
            )
            return

        if self.num_sites == 0:
            logger.info("Lattice is empty, nothing to show.")
            return
        if self.dimensionality not in [1, 2, 3]:
            logger.warning(
                f"show() is not implemented for {self.dimensionality}D lattices."
            )
            return

        # Only display the figure ourselves if we created it; a caller-supplied
        # ``ax`` is left for the caller to show/save.
        fig_created_internally = ax is None
        if ax is None:
            if self.dimensionality == 3:
                fig = plt.figure(figsize=(8, 8))
                ax = fig.add_subplot(111, projection="3d")
            else:
                _, ax = plt.subplots(figsize=(8, 8))

        coords = np.array(self._coordinates)

        # --- Sites ---
        default_marker_size = 50 if self.dimensionality == 3 else 100
        scatter_args: Dict[str, Any] = {"s": default_marker_size, "zorder": 2}
        scatter_args.update(kwargs)
        if self.dimensionality == 1:
            ax.scatter(coords[:, 0], np.zeros_like(coords[:, 0]), **scatter_args)  # type: ignore
        elif self.dimensionality == 2:
            ax.scatter(coords[:, 0], coords[:, 1], **scatter_args)  # type: ignore
        else:
            ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], **scatter_args)  # type: ignore

        # --- Labels ---
        if show_indices or show_identifiers:
            # The offset depends only on the overall extent of the lattice, so
            # compute it once. Use a fixed fallback when all sites coincide
            # (zero extent), e.g. a single-site lattice.
            span = float(np.max(np.ptp(coords, axis=0)))
            offset = 0.02 * span if span > 0 else 0.1

            for i in range(self.num_sites):
                label = str(self._identifiers[i]) if show_identifiers else str(i)
                if self.dimensionality == 1:
                    ax.text(coords[i, 0], offset, label, fontsize=9, ha="center")
                elif self.dimensionality == 2:
                    ax.text(
                        coords[i, 0] + offset,
                        coords[i, 1] + offset,
                        label,
                        fontsize=9,
                        zorder=3,
                    )
                else:
                    ax_3d = cast("Axes3D", ax)
                    ax_3d.text(
                        coords[i, 0],
                        coords[i, 1],
                        coords[i, 2] + offset,
                        label,
                        fontsize=9,
                        zorder=3,
                    )

        # --- Bonds ---
        if show_bonds_k is not None:
            if show_bonds_k not in self._neighbor_maps:
                logger.warning(
                    f"Cannot draw bonds. k={show_bonds_k} neighbors have not been calculated."
                )
            else:
                try:
                    bonds = self.get_neighbor_pairs(k=show_bonds_k, unique=True)
                    plot_bond_kwargs: Dict[str, Any] = {
                        "color": "k",
                        "linestyle": "-",
                        "alpha": 0.6,
                        "zorder": 1,
                    }
                    if bond_kwargs:
                        plot_bond_kwargs.update(bond_kwargs)

                    for i, j in bonds:
                        p1, p2 = self._coordinates[i], self._coordinates[j]
                        if self.dimensionality == 1:
                            ax.plot([p1[0], p2[0]], [0, 0], **plot_bond_kwargs)  # type: ignore
                        elif self.dimensionality == 2:
                            ax.plot([p1[0], p2[0]], [p1[1], p2[1]], **plot_bond_kwargs)  # type: ignore
                        else:
                            ax_3d = cast("Axes3D", ax)
                            ax_3d.plot(
                                [p1[0], p2[0]],
                                [p1[1], p2[1]],
                                [p1[2], p2[2]],
                                **plot_bond_kwargs,
                            )
                except ValueError as e:
                    logger.info(f"Could not draw bonds: {e}")

        # --- Decorations ---
        ax.set_title(f"{self.__class__.__name__} ({self.num_sites} sites)")
        if self.dimensionality == 2:
            ax.set_aspect("equal", adjustable="box")
        ax.set_xlabel("x")
        if self.dimensionality > 1:
            ax.set_ylabel("y")
        if self.dimensionality > 2 and hasattr(ax, "set_zlabel"):
            ax.set_zlabel("z")
        ax.grid(True)

        if fig_created_internally:
            plt.show()

    def _identify_distance_shells(
        self,
        all_distances_sq: Union[Coordinates, List[float]],
        max_k: int,
        tol: float = 1e-6,
    ) -> List[float]:
        """Identifies unique distance shells from a list of squared distances.

        This helper function takes a flat list of squared distances, drops
        (near-)zero entries (self-distances), sorts the rest, and identifies
        the first `max_k` unique distance shells. Two squared distances belong
        to different shells when they differ by more than ``tol**2``.

        :param all_distances_sq: A list or array
            of all squared distances between pairs of sites.
        :type all_distances_sq: Union[np.ndarray, List[float]]
        :param max_k: The maximum number of neighbor shells to identify.
            Values below 1 yield an empty list.
        :type max_k: int
        :param tol: The numerical tolerance to consider two distances equal.
        :type tol: float
        :return: A sorted list of at most ``max_k`` squared distances
            representing the shells.
        :rtype: List[float]
        """
        if max_k < 1:
            # Without this guard the first shell would always be returned.
            return []

        ZERO_THRESHOLD_SQ = 1e-12

        all_distances_sq = np.asarray(all_distances_sq)
        if all_distances_sq.size == 0:
            return []

        sorted_dist = np.sort(all_distances_sq[all_distances_sq > ZERO_THRESHOLD_SQ])
        if sorted_dist.size == 0:
            return []

        # Identify shells using the user-provided tolerance.
        dist_shells = [sorted_dist[0]]
        for d_sq in sorted_dist[1:]:
            if len(dist_shells) >= max_k:
                break
            # Start a new shell when the distance is notably larger than the
            # current shell's distance.
            if d_sq > dist_shells[-1] + tol**2:
                dist_shells.append(d_sq)

        return dist_shells

    def _build_neighbors_by_distance_matrix(
        self, max_k: int = 2, tol: float = 1e-6
    ) -> None:
        """A generic, distance-based neighbor finding method.

        This method calculates the full N x N distance matrix to find neighbor
        shells. It is computationally expensive for large N (O(N^2)) and is
        best suited for non-periodic or custom-defined lattices. It replaces
        ``self._neighbor_maps`` and caches the (non-periodic) distance matrix.
        Lattices with fewer than two sites are left unchanged.

        :param max_k: The maximum number of neighbor shells to
            calculate. Defaults to 2.
        :type max_k: int, optional
        :param tol: The numerical tolerance for distance
            comparisons. Defaults to 1e-6.
        :type tol: float, optional
        """
        if self.num_sites < 2:
            return

        all_coords = np.array(self._coordinates)
        # Pairwise squared Euclidean distances via broadcasting (N, 1, d) - (1, N, d).
        dist_matrix_sq = np.sum(
            (all_coords[:, np.newaxis, :] - all_coords[np.newaxis, :, :]) ** 2, axis=-1
        )

        dist_shells_sq = self._identify_distance_shells(
            dist_matrix_sq.flatten(), max_k, tol
        )

        new_maps: Dict[int, NeighborMap] = {}
        for k, target_d_sq in enumerate(dist_shells_sq, start=1):
            current_k_map: NeighborMap = {}
            for i in range(self.num_sites):
                neighbor_indices = np.where(
                    np.isclose(dist_matrix_sq[i], target_d_sq, rtol=0, atol=tol**2)
                )[0]
                if neighbor_indices.size > 0:
                    current_k_map[i] = sorted(neighbor_indices.tolist())
            new_maps[k] = current_k_map

        self._neighbor_maps = new_maps
        self._distance_matrix = np.sqrt(dist_matrix_sq)
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
class TILattice(AbstractLattice):
|
|
543
|
+
"""Describes a periodic lattice with translational invariance.
|
|
544
|
+
|
|
545
|
+
This class serves as a base for any lattice defined by a repeating unit
|
|
546
|
+
cell. The geometry is specified by lattice vectors, the coordinates of
|
|
547
|
+
basis sites within a unit cell, and the total size of the lattice in
|
|
548
|
+
terms of unit cells.
|
|
549
|
+
|
|
550
|
+
The site identifier for this class is a tuple in the format of
|
|
551
|
+
`(uc_coord_1, ..., uc_coord_d, basis_index)`, where `uc_coord` represents
|
|
552
|
+
the integer coordinate of the unit cell and `basis_index` is the index
|
|
553
|
+
of the site within that unit cell's basis.
|
|
554
|
+
|
|
555
|
+
:param dimensionality: The spatial dimension of the lattice.
|
|
556
|
+
:type dimensionality: int
|
|
557
|
+
:param lattice_vectors: The lattice vectors defining the unit
|
|
558
|
+
cell, given as row vectors. Shape: (dimensionality, dimensionality).
|
|
559
|
+
For example, in 2D: `np.array([[ax, ay], [bx, by]])`.
|
|
560
|
+
:type lattice_vectors: np.ndarray
|
|
561
|
+
:param basis_coords: The Cartesian coordinates of the basis sites
|
|
562
|
+
within the unit cell. Shape: (num_basis_sites, dimensionality).
|
|
563
|
+
For a simple Bravais lattice, this would be `np.array([[0, 0]])`.
|
|
564
|
+
:type basis_coords: np.ndarray
|
|
565
|
+
:param size: A tuple specifying the number of unit cells
|
|
566
|
+
to generate in each lattice vector direction (e.g., (Nx, Ny)).
|
|
567
|
+
:type size: Tuple[int, ...]
|
|
568
|
+
:param pbc: Specifies whether
|
|
569
|
+
periodic boundary conditions are applied. Can be a single boolean
|
|
570
|
+
for all dimensions or a tuple of booleans for each dimension
|
|
571
|
+
individually. Defaults to True.
|
|
572
|
+
:type pbc: Union[bool, Tuple[bool, ...]], optional
|
|
573
|
+
:param precompute_neighbors: If specified, pre-computes neighbor relationships
|
|
574
|
+
up to the given order `k` upon initialization. Defaults to None.
|
|
575
|
+
:type precompute_neighbors: Optional[int], optional
|
|
576
|
+
|
|
577
|
+
"""
|
|
578
|
+
|
|
579
|
+
def __init__(
    self,
    dimensionality: int,
    lattice_vectors: Coordinates,
    basis_coords: Coordinates,
    size: Tuple[int, ...],
    pbc: Union[bool, Tuple[bool, ...]] = True,
    precompute_neighbors: Optional[int] = None,
):
    """Initializes the Translationally Invariant Lattice.

    ``lattice_vectors`` and ``basis_coords`` may be any array-like; NumPy
    arrays are used as-is and nested lists/tuples are converted. ``size`` is
    stored as a tuple, so any sequence of ints is accepted. ``pbc`` may be a
    single Python/NumPy boolean or a per-dimension sequence.

    :raises AssertionError: If the shape of ``lattice_vectors``, the width of
        ``basis_coords``, or the length of ``size`` / ``pbc`` does not match
        ``dimensionality``.
    """
    super().__init__(dimensionality)

    # np.asarray returns ndarray inputs unchanged, so existing callers get the
    # very same objects; list inputs (which have no `.shape`) now work too.
    lattice_vectors = np.asarray(lattice_vectors)
    basis_coords = np.asarray(basis_coords)

    # NOTE: kept as `assert` to preserve the exception type raised by these
    # checks; they are skipped when Python runs with -O.
    assert lattice_vectors.shape == (
        dimensionality,
        dimensionality,
    ), "Lattice vectors shape mismatch"
    assert (
        basis_coords.shape[1] == dimensionality
    ), "Basis coordinates dimension mismatch"
    assert len(size) == dimensionality, "Size tuple length mismatch"

    self.lattice_vectors = lattice_vectors
    self.basis_coords = basis_coords
    self.num_basis = basis_coords.shape[0]
    # np.ndindex (used in _build_lattice) only unpacks a tuple shape, not a list.
    self.size = tuple(size)
    if isinstance(pbc, (bool, np.bool_)):
        self.pbc = tuple([bool(pbc)] * dimensionality)
    else:
        assert len(pbc) == dimensionality, "PBC tuple length mismatch"
        self.pbc = tuple(pbc)

    # Build the lattice sites and their neighbor relationships
    self._build_lattice()
    if precompute_neighbors is not None and precompute_neighbors > 0:
        logger.info(f"Pre-computing neighbors up to k={precompute_neighbors}...")
        self._build_neighbors(max_k=precompute_neighbors)
|
|
614
|
+
|
|
615
|
+
def _build_lattice(self) -> None:
    """Populate per-site data for every (unit cell, basis site) combination.

    Unit cells are visited in ``np.ndindex(self.size)`` order and, inside
    each cell, basis sites in order ``0 .. num_basis - 1``. Each site gets:

    * a consecutive integer index starting at 0,
    * the identifier ``(*cell_coord, basis_index)``,
    * real-space coordinates ``cell_coord @ lattice_vectors + basis_coords[b]``.

    The results are appended to ``_indices``, ``_identifiers`` and
    ``_coordinates``, and the identifier -> index lookup is recorded in
    ``_ident_to_idx``.
    """
    site_counter = 0

    for cell in np.ndindex(self.size):
        # Origin of this unit cell: R = n1*a1 + n2*a2 + ...
        origin = np.dot(np.array(cell), self.lattice_vectors)

        for b in range(self.num_basis):
            site_id = (*cell, b)
            position = origin + self.basis_coords[b]

            self._indices.append(site_counter)
            self._identifiers.append(site_id)
            self._coordinates.append(position)
            self._ident_to_idx[site_id] = site_counter
            site_counter += 1
|
|
646
|
+
|
|
647
|
+
def _get_distance_matrix_with_mic(self) -> Coordinates:
    """Return the N x N site distance matrix under the Minimum Image Convention.

    The full system ("supercell") spans ``size[d] * a_d`` along each lattice
    direction. For every periodic direction the supercell is shifted by -1, 0
    or +1 of its extent. Open directions are never shifted. The distance
    between two sites is the smallest Euclidean distance over all resulting
    images (the unshifted one included).

    NOTE(review): only the nearest periodic images (shifts of -1/0/+1) are
    considered; presumably sufficient for reasonably shaped cells, but
    strongly skewed cells may need larger shifts — confirm.
    """
    positions = np.array(self._coordinates)
    n_sites = self.num_sites
    supercell = self.lattice_vectors * np.array(self.size)[:, np.newaxis]

    periodic_axes = [axis for axis, periodic in enumerate(self.pbc) if periodic]

    # The zero shift always comes first, followed by every non-zero combination
    # of {-1, 0, +1} supercell shifts along the periodic axes.
    shift_rows = [np.zeros(self.dimensionality)]
    if periodic_axes:
        n_periodic = len(periodic_axes)
        grids = np.meshgrid(*([np.array([-1, 0, 1])] * n_periodic), indexing="ij")
        offsets = np.stack(grids, axis=-1).reshape(-1, n_periodic)
        nonzero_offsets = offsets[np.any(offsets != 0, axis=1)]
        shift_rows.extend(nonzero_offsets @ supercell[periodic_axes, :])
    shifts = np.array(shift_rows, dtype=float)

    d_sq = np.full((n_sites, n_sites), np.inf, dtype=float)
    for row in range(n_sites):
        rel = positions - positions[row]
        # Shape (n_sites, n_shifts, dim): every displacement against every shift.
        images = rel[:, np.newaxis, :] - shifts[np.newaxis, :, :]
        d_sq[row, :] = np.min(np.sum(images**2, axis=2), axis=1)

    return cast(Coordinates, np.sqrt(d_sq))
|
|
685
|
+
|
|
686
|
+
def _build_neighbors(self, max_k: int = 2, **kwargs: Any) -> None:
    """Calculates neighbor relationships for the periodic lattice.

    This method calculates neighbor relationships by computing the full N x N
    distance matrix. It handles all boundary conditions (fully periodic,
    open, or mixed) by applying the Minimum Image Convention (MIC) only to
    the periodic dimensions.

    From this distance matrix, it identifies unique neighbor shells up to
    the specified `max_k` and populates ``self._neighbor_maps`` as
    ``{k: {site_index: sorted list of neighbor indices}}``, using plain
    Python ``int`` indices. The computed distance matrix is cached in
    ``self._distance_matrix`` for future use.

    :param max_k: The maximum number of neighbor shells to
        calculate. Defaults to 2.
    :type max_k: int, optional
    :param tol: The numerical tolerance for distance
        comparisons. Defaults to 1e-6.
    :type tol: float, optional
    """
    tol = kwargs.get("tol", 1e-6)
    dist_matrix = self._get_distance_matrix_with_mic()
    dist_matrix_sq = dist_matrix**2
    # Cache the full matrix so later distance queries can reuse it.
    self._distance_matrix = dist_matrix
    dist_shells_sq = self._identify_distance_shells(
        dist_matrix_sq.flatten(), max_k, tol
    )

    self._neighbor_maps = {k: {} for k in range(1, len(dist_shells_sq) + 1)}
    for k, target_d_sq in enumerate(dist_shells_sq, start=1):
        # NOTE(review): matching is done on squared distances with an absolute
        # tolerance of tol**2; confirm this agrees with how
        # _identify_distance_shells groups distances.
        in_shell = np.isclose(dist_matrix_sq, target_d_sq, rtol=0, atol=tol**2)
        # A site is never its own neighbor.
        np.fill_diagonal(in_shell, False)

        current_k_map: Dict[int, List[int]] = {}
        for i, j in zip(*np.nonzero(in_shell)):
            # np.nonzero yields NumPy integer scalars; store plain ints so the
            # maps match CustomizeLattice and are directly serializable.
            current_k_map.setdefault(int(i), []).append(int(j))
        for neighbors in current_k_map.values():
            neighbors.sort()

        self._neighbor_maps[k] = current_k_map
|
|
730
|
+
|
|
731
|
+
def _compute_distance_matrix(self) -> Coordinates:
    """Return the site-to-site distance matrix, honoring periodic boundaries.

    This is a thin delegate to ``_get_distance_matrix_with_mic`` so that the
    MIC-aware computation is the single source of distances for this lattice.
    """
    mic_distances = self._get_distance_matrix_with_mic()
    return mic_distances
|
|
734
|
+
|
|
735
|
+
|
|
736
|
+
class SquareLattice(TILattice):
    """Simple two-dimensional square lattice.

    A translationally invariant Bravais lattice with one site per unit cell,
    arranged on an orthogonal grid with equal spacing along x and y.

    :param size: ``(Nx, Ny)``, the number of unit cells (equivalently,
        sites) along x and y.
    :type size: Tuple[int, int]
    :param lattice_constant: Spacing between adjacent sites. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Periodic boundary conditions, either one flag shared by both
        dimensions or a per-dimension tuple. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: When given, neighbor shells up to this
        order are built during construction. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: float = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Build the square lattice geometry and hand it to TILattice."""
        spacing = lattice_constant
        geometry = {
            "dimensionality": 2,
            # Two orthogonal primitive vectors of equal length.
            "lattice_vectors": np.array([[spacing, 0.0], [0.0, spacing]]),
            # A single site sitting at the origin of each cell.
            "basis_coords": np.array([[0.0, 0.0]]),
        }
        super().__init__(
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
            **geometry,
        )
|
|
783
|
+
|
|
784
|
+
|
|
785
|
+
class HoneycombLattice(TILattice):
    """Two-dimensional honeycomb lattice.

    The honeycomb is a composite (non-Bravais) lattice: an underlying
    triangular Bravais lattice decorated with a two-site basis, forming the
    A and B sublattices.

    :param size: ``(Nx, Ny)``, the number of unit cells along each of the
        two primitive vectors.
    :type size: Tuple[int, int]
    :param lattice_constant: Nearest-neighbor (A-B) bond length.
        Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: When given, neighbor shells up to this
        order are built during construction. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: float = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Build the honeycomb geometry and hand it to TILattice."""
        bond = lattice_constant
        half_root3 = np.sqrt(3) / 2
        # Primitive vectors of the triangular Bravais lattice; each has length
        # sqrt(3) times the bond length.
        primitive = np.array([[1.5, half_root3], [1.5, -half_root3]]) * bond
        # Sublattice A at the cell origin, sublattice B one bond along +x.
        two_site_basis = np.array([[0.0, 0.0], [1.0, 0.0]]) * bond

        super().__init__(
            dimensionality=2,
            lattice_vectors=primitive,
            basis_coords=two_site_basis,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
|
|
832
|
+
|
|
833
|
+
|
|
834
|
+
class TriangularLattice(TILattice):
    """Two-dimensional triangular lattice.

    A Bravais lattice (one site per cell) in which every site has six
    nearest neighbors.

    :param size: ``(Nx, Ny)``, the number of unit cells along each of the
        two primitive vectors.
    :type size: Tuple[int, int]
    :param lattice_constant: Nearest-neighbor bond length. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: When given, neighbor shells up to this
        order are built during construction. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: float = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Build the triangular geometry and hand it to TILattice."""
        spacing = lattice_constant
        # Two unit-length primitive vectors 60 degrees apart, scaled by the
        # bond length.
        primitive = np.array([[1.0, 0.0], [0.5, np.sqrt(3) / 2]]) * spacing
        # Bravais lattice: exactly one site per cell, at its origin.
        single_site = np.array([[0.0, 0.0]])

        super().__init__(
            dimensionality=2,
            lattice_vectors=primitive,
            basis_coords=single_site,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
|
|
879
|
+
|
|
880
|
+
|
|
881
|
+
class ChainLattice(TILattice):
    """One-dimensional chain, the simplest Bravais lattice.

    :param size: ``(N,)``, the number of sites in the chain.
    :type size: Tuple[int]
    :param lattice_constant: Spacing between adjacent sites. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Whether the chain is closed into a ring (periodic boundary
        conditions). Defaults to True.
    :type pbc: bool, optional
    :param precompute_neighbors: When given, neighbor shells up to this
        order are built during construction. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int],
        lattice_constant: float = 1.0,
        pbc: bool = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Build the chain geometry and hand it to TILattice."""
        # A single 1x1 primitive "vector" holding the site spacing.
        step = np.full((1, 1), lattice_constant)
        # One site per cell, located at the cell origin.
        origin_site = np.zeros((1, 1))
        super().__init__(
            dimensionality=1,
            lattice_vectors=step,
            basis_coords=origin_site,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
|
|
913
|
+
|
|
914
|
+
|
|
915
|
+
class DimerizedChainLattice(TILattice):
    """One-dimensional chain with a two-site (A/B) unit cell.

    Each unit cell holds an A site and a B site; all bonds have the same
    length, so the two sublattices differ only by their labelling.

    :param size: ``(N,)``, the number of **unit cells**. Because every cell
        contains two sites, the chain holds ``2 * N`` sites in total.
    :type size: Tuple[int]
    :param lattice_constant: Bond length between adjacent sites.
        Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Whether the chain is closed into a ring (periodic boundary
        conditions). Defaults to True.
    :type pbc: bool, optional
    :param precompute_neighbors: When given, neighbor shells up to this
        order are built during construction. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int],
        lattice_constant: float = 1.0,
        pbc: bool = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Build the two-site chain geometry and hand it to TILattice."""
        bond = lattice_constant
        # One cell spans two bonds: from an A site to the next A site.
        cell_vector = np.full((1, 1), 2 * bond)
        # A at the cell origin, B one bond further along the chain.
        ab_sites = np.array([0.0, bond]).reshape(2, 1)
        super().__init__(
            dimensionality=1,
            lattice_vectors=cell_vector,
            basis_coords=ab_sites,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
|
|
954
|
+
|
|
955
|
+
|
|
956
|
+
class RectangularLattice(TILattice):
    """Two-dimensional rectangular lattice.

    Generalizes SquareLattice by allowing independent site spacings along x
    and y. It is a Bravais lattice with one site per unit cell.

    :param size: ``(Nx, Ny)``, the number of sites along x and y.
    :type size: Tuple[int, int]
    :param lattice_constants: ``(ax, ay)``, the spacing between adjacent
        sites along x and along y. Defaults to (1.0, 1.0).
    :type lattice_constants: Tuple[float, float], optional
    :param pbc: Periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: When given, neighbor shells up to this
        order are built during construction. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constants: Tuple[float, float] = (1.0, 1.0),
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Build the rectangular geometry and hand it to TILattice."""
        step_x, step_y = lattice_constants
        geometry = {
            "dimensionality": 2,
            # Orthogonal primitive vectors with independent lengths.
            "lattice_vectors": np.array([[step_x, 0.0], [0.0, step_y]]),
            # One site per cell, at the cell origin.
            "basis_coords": np.array([[0.0, 0.0]]),
        }
        super().__init__(
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
            **geometry,
        )
|
|
994
|
+
|
|
995
|
+
|
|
996
|
+
class CheckerboardLattice(TILattice):
    """Two-dimensional checkerboard lattice (square lattice with A/B sublattices).

    The primitive cell is a square rotated by 45 degrees that contains two
    sites, one from each sublattice.

    :param size: ``(Nx, Ny)``, the number of unit cells; the lattice holds
        ``2 * Nx * Ny`` sites in total.
    :type size: Tuple[int, int]
    :param lattice_constant: Nearest-neighbor bond length. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: When given, neighbor shells up to this
        order are built during construction. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: float = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Build the checkerboard geometry and hand it to TILattice."""
        bond = lattice_constant
        # Diagonal primitive vectors: the square cell rotated by 45 degrees.
        rotated_square = np.array([[1.0, 1.0], [1.0, -1.0]]) * bond
        # A at the cell origin, B one bond along +x.
        ab_sites = np.array([[0.0, 0.0], [1.0, 0.0]]) * bond
        super().__init__(
            dimensionality=2,
            lattice_vectors=rotated_square,
            basis_coords=ab_sites,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
|
|
1033
|
+
|
|
1034
|
+
|
|
1035
|
+
class KagomeLattice(TILattice):
    """A 2D Kagome lattice.

    This is a lattice with a three-site basis on a triangular Bravais lattice.
    All three basis sites are mutually one bond length apart, forming the
    corner-sharing triangles characteristic of the Kagome geometry.

    :param size: A tuple (Nx, Ny) specifying the number of unit cells. Total sites will be 3*Nx*Ny.
    :type size: Tuple[int, int]
    :param lattice_constant: The bond length. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Specifies periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: float = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the KagomeLattice."""
        dimensionality = 2
        a = lattice_constant
        # Primitive vectors (2a, 0) and (a, sqrt(3) a): both have length 2a and
        # are 60 degrees apart, so they span a rhombic cell of the underlying
        # triangular Bravais lattice (the cell is NOT rectangular).
        lattice_vectors = a * np.array([[2.0, 0.0], [1.0, np.sqrt(3)]])
        # Three-site basis: the corners of an equilateral triangle of side a.
        basis_coords = a * np.array([[0.0, 0.0], [1.0, 0.0], [0.5, np.sqrt(3) / 2.0]])
        super().__init__(
            dimensionality=dimensionality,
            lattice_vectors=lattice_vectors,
            basis_coords=basis_coords,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
|
|
1072
|
+
|
|
1073
|
+
|
|
1074
|
+
class LiebLattice(TILattice):
    """Two-dimensional Lieb lattice.

    A square Bravais lattice with a three-site basis: a corner site of the
    square cell plus the midpoints of its horizontal and vertical edges.

    :param size: ``(Nx, Ny)``, the number of unit cells; the lattice holds
        ``3 * Nx * Ny`` sites in total.
    :type size: Tuple[int, int]
    :param lattice_constant: Nearest-neighbor bond length. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool]], optional
    :param precompute_neighbors: When given, neighbor shells up to this
        order are built during construction. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int],
        lattice_constant: float = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Build the Lieb geometry and hand it to TILattice."""
        bond = lattice_constant
        # Each side of the square cell spans two bonds, so the edge midpoints
        # lie exactly one bond away from the corner site.
        side = 2 * bond
        square_cell = np.array([[side, 0.0], [0.0, side]])
        corner_and_edges = np.array(
            [
                [0.0, 0.0],  # corner
                [bond, 0.0],  # midpoint of the horizontal edge
                [0.0, bond],  # midpoint of the vertical edge
            ]
        )
        super().__init__(
            dimensionality=2,
            lattice_vectors=square_cell,
            basis_coords=corner_and_edges,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
|
|
1128
|
+
|
|
1129
|
+
|
|
1130
|
+
class CubicLattice(TILattice):
    """Three-dimensional simple cubic lattice.

    The 3D counterpart of SquareLattice: a Bravais lattice with one site per
    cell and equal spacing along x, y and z.

    :param size: ``(Nx, Ny, Nz)``, the number of sites along each axis.
    :type size: Tuple[int, int, int]
    :param lattice_constant: Spacing between adjacent sites. Defaults to 1.0.
    :type lattice_constant: float, optional
    :param pbc: Periodic boundary conditions. Defaults to True.
    :type pbc: Union[bool, Tuple[bool, bool, bool]], optional
    :param precompute_neighbors: When given, neighbor shells up to this
        order are built during construction. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    """

    def __init__(
        self,
        size: Tuple[int, int, int],
        lattice_constant: float = 1.0,
        pbc: Union[bool, Tuple[bool, bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
    ):
        """Build the cubic geometry and hand it to TILattice."""
        spacing = lattice_constant
        # Three orthogonal primitive vectors of equal length.
        cube_edges = np.diag([spacing, spacing, spacing])
        # One site per cell, at the cell origin.
        origin_site = np.zeros((1, 3))
        super().__init__(
            dimensionality=3,
            lattice_vectors=cube_edges,
            basis_coords=origin_site,
            size=size,
            pbc=pbc,
            precompute_neighbors=precompute_neighbors,
        )
|
|
1165
|
+
|
|
1166
|
+
|
|
1167
|
+
class CustomizeLattice(AbstractLattice):
    """A general lattice built from an explicit list of sites and coordinates.

    This class is suitable for creating lattices with arbitrary geometries,
    such as finite clusters, disordered systems, or any custom structure
    that does not have translational symmetry. The lattice is defined simply
    by providing lists of identifiers and coordinates for each site.

    :param dimensionality: The spatial dimension of the lattice.
    :type dimensionality: int
    :param identifiers: A list of unique, hashable
        identifiers for the sites. The length must match `coordinates`.
    :type identifiers: List[SiteIdentifier]
    :param coordinates: A list of site
        coordinates. Each coordinate should be a list of floats or a
        NumPy array.
    :type coordinates: List[Union[List[float], Coordinates]]
    :param precompute_neighbors: If specified, pre-computes neighbor relationships
        up to the given order `k` upon initialization. Defaults to None.
    :type precompute_neighbors: Optional[int], optional
    :raises ValueError: If the lengths of `identifiers` and `coordinates` lists
        do not match, if `identifiers` contains duplicates, or if a
        coordinate's dimension is incorrect.
    """

    def __init__(
        self,
        dimensionality: int,
        identifiers: List[SiteIdentifier],
        coordinates: List[Union[List[float], Coordinates]],
        precompute_neighbors: Optional[int] = None,
    ):
        """Initializes the CustomizeLattice."""
        super().__init__(dimensionality)
        if len(identifiers) != len(coordinates):
            raise ValueError(
                "Identifiers and coordinates lists must have the same length."
            )
        # A repeated identifier would silently map to only its last index in
        # ``_ident_to_idx``, so reject it up front.
        repeated = self._find_duplicates(identifiers)
        if repeated:
            raise ValueError(f"Duplicate identifiers found: {repeated}")

        # The _build_lattice logic is simple enough to be in __init__
        self._identifiers = list(identifiers)
        # np.array copies, so later mutation of the caller's arrays cannot
        # silently change the lattice geometry.
        self._coordinates = [np.array(c) for c in coordinates]
        self._indices = list(range(len(identifiers)))
        self._ident_to_idx = {ident: idx for idx, ident in enumerate(identifiers)}

        # Validate coordinate dimensions
        for i, coord in enumerate(self._coordinates):
            if coord.shape != (dimensionality,):
                raise ValueError(
                    f"Coordinate at index {i} has shape {coord.shape}, "
                    f"expected ({dimensionality},)"
                )

        logger.info(f"CustomizeLattice with {self.num_sites} sites created.")

        if precompute_neighbors is not None and precompute_neighbors > 0:
            self._build_neighbors(max_k=precompute_neighbors)

    @staticmethod
    def _find_duplicates(identifiers: List[SiteIdentifier]) -> set:
        """Returns the set of identifiers that occur more than once in `identifiers`."""
        seen: set = set()
        repeated: set = set()
        for ident in identifiers:
            if ident in seen:
                repeated.add(ident)
            else:
                seen.add(ident)
        return repeated

    def _build_lattice(self, *args: Any, **kwargs: Any) -> None:
        """For CustomizeLattice, lattice data is built during __init__."""
        pass

    def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None:
        """Calculates neighbors using a KDTree for efficiency.

        Distance shells are identified from the condensed pairwise squared
        distances returned by `scipy.spatial.distance.pdist`. Neighbors for
        each shell are then found with `scipy.spatial.KDTree` radius queries,
        isolating shell k as "within radius k" minus "within radius k-1".
        Sites that coincide with site i (including i itself) are never
        reported as its neighbors.

        After the neighbors are identified, the full distance matrix is built
        from the pairwise distances and cached in ``self._distance_matrix``.

        :param max_k: The maximum number of neighbor shells to
            calculate. Defaults to 1.
        :type max_k: int, optional
        :param tol: The numerical tolerance for distance
            comparisons. Defaults to 1e-6.
        :type tol: float, optional
        """
        tol = kwargs.get("tol", 1e-6)
        logger.info(f"Building neighbors for CustomizeLattice up to k={max_k}...")
        if self.num_sites < 2:
            return

        all_coords = np.array(self._coordinates)

        # 1. Use pdist (condensed form) to compute pairwise distances and
        # identify the distance shells.
        all_distances_sq = pdist(all_coords, metric="sqeuclidean")
        dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol)

        if not dist_shells_sq:
            logger.info("No distinct neighbor shells found.")
            return

        # 2. Build the KDTree for efficient querying.
        tree = KDTree(all_coords)
        self._neighbor_maps = {k: {} for k in range(1, len(dist_shells_sq) + 1)}

        # 3. Find neighbors by isolating shells using inclusion-exclusion.
        # `found_indices[i]` holds every site already accounted for around i;
        # it starts with the sites co-located with i (one batched query).
        co_located = tree.query_ball_point(all_coords, r=1e-12)
        found_indices: List[set[int]] = [set(idx) for idx in co_located]
        for k, target_d_sq in enumerate(dist_shells_sq, start=1):
            radius = np.sqrt(target_d_sq) + tol
            # Query for all points within the new, larger radius.
            within_radius = tree.query_ball_point(
                all_coords, r=radius, return_sorted=True
            )

            current_k_map: Dict[int, List[int]] = {}
            for i in range(self.num_sites):
                # The new neighbors are those in the current radius,
                # excluding those already found in smaller shells.
                new_neighbors = set(within_radius[i]) - found_indices[i]
                if new_neighbors:
                    current_k_map[i] = sorted(new_neighbors)

            self._neighbor_maps[k] = current_k_map
            found_indices = [set(l) for l in within_radius]  # Update for next shell
        self._distance_matrix = np.sqrt(squareform(all_distances_sq))

        logger.info("Neighbor building complete using KDTree.")

    def _compute_distance_matrix(self) -> Coordinates:
        """Computes the distance matrix from the stored coordinates.

        This implementation uses scipy.pdist for a memory-efficient
        calculation of pairwise distances, which is then converted to a
        full square matrix. Lattices with fewer than two sites get an
        all-zero matrix (a site is at distance 0 from itself).
        """
        if self.num_sites < 2:
            # np.zeros, not np.empty: the single diagonal entry must be 0.
            return cast(Coordinates, np.zeros((self.num_sites, self.num_sites)))

        all_coords = np.array(self._coordinates)
        # Use pdist for memory-efficiency, then build the full matrix.
        all_distances_sq = pdist(all_coords, metric="sqeuclidean")
        dist_matrix_sq = squareform(all_distances_sq)
        return cast(Coordinates, np.sqrt(dist_matrix_sq))

    def _reset_computations(self) -> None:
        """Resets all cached data that depends on the lattice structure."""
        self._neighbor_maps = {}
        self._distance_matrix = None

    @classmethod
    def from_lattice(cls, lattice: "AbstractLattice") -> "CustomizeLattice":
        """Creates a CustomizeLattice instance from any existing lattice object.

        This is useful for 'detaching' a procedurally generated lattice (like
        a SquareLattice) into a customizable one for further modifications,
        such as adding defects or extra sites.

        :param lattice: An instance of any AbstractLattice subclass.
        :type lattice: AbstractLattice
        :return: A new CustomizeLattice instance with the same sites.
        :rtype: CustomizeLattice
        """
        all_sites_info = list(lattice.sites())

        if not all_sites_info:
            return cls(
                dimensionality=lattice.dimensionality, identifiers=[], coordinates=[]
            )

        # Unzip the list of tuples into separate lists of identifiers and coordinates
        _, identifiers, coordinates = zip(*all_sites_info)

        return cls(
            dimensionality=lattice.dimensionality,
            identifiers=list(identifiers),
            coordinates=list(coordinates),
        )

    def add_sites(
        self,
        identifiers: List[SiteIdentifier],
        coordinates: List[Union[List[float], Coordinates]],
    ) -> None:
        """Adds new sites to the lattice.

        This operation modifies the lattice in-place. All inputs are validated
        before any state is changed, so a failed call leaves the lattice
        untouched. After adding sites, any previously computed neighbor
        information is cleared and must be recalculated.

        :param identifiers: A list of unique, hashable identifiers for the new sites.
        :type identifiers: List[SiteIdentifier]
        :param coordinates: A list of coordinates for the new sites.
        :type coordinates: List[Union[List[float], np.ndarray]]
        :raises ValueError: If input lists have mismatched lengths, if any new
            identifier already exists in the lattice or is repeated within
            `identifiers`, or if a coordinate has the wrong shape.
        """
        if len(identifiers) != len(coordinates):
            raise ValueError(
                "Identifiers and coordinates lists must have the same length."
            )
        if not identifiers:
            return  # Nothing to add

        # Check for duplicate identifiers before making any changes
        existing_ids = set(self._identifiers)
        new_ids = set(identifiers)
        if not new_ids.isdisjoint(existing_ids):
            raise ValueError(
                f"Duplicate identifiers found: {new_ids.intersection(existing_ids)}"
            )
        repeated = self._find_duplicates(identifiers)
        if repeated:
            raise ValueError(f"Duplicate identifiers found: {repeated}")

        # Validate every coordinate before appending anything, so a bad entry
        # cannot leave the lattice half-updated with stale index mappings.
        new_coordinates: List[Coordinates] = []
        for i, coord in enumerate(coordinates):
            # Copy (like __init__) so the caller's array is not aliased.
            coord_arr = np.array(coord)
            if coord_arr.shape != (self.dimensionality,):
                raise ValueError(
                    f"New coordinate at index {i} has shape {coord_arr.shape}, "
                    f"expected ({self.dimensionality},)"
                )
            new_coordinates.append(coord_arr)

        self._coordinates.extend(new_coordinates)
        self._identifiers.extend(identifiers)

        # Rebuild index mappings from scratch
        self._indices = list(range(len(self._identifiers)))
        self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)}

        # Invalidate any previously computed neighbors or distance matrices
        self._reset_computations()
        logger.info(
            f"{len(identifiers)} sites added. Lattice now has {self.num_sites} sites."
        )

    def remove_sites(self, identifiers: List[SiteIdentifier]) -> None:
        """Removes specified sites from the lattice.

        This operation modifies the lattice in-place. After removing sites,
        all site indices are re-calculated, and any previously computed
        neighbor information is cleared.

        :param identifiers: A list of identifiers for the sites to be removed.
        :type identifiers: List[SiteIdentifier]
        :raises ValueError: If any of the specified identifiers do not exist.
        """
        if not identifiers:
            return  # Nothing to remove

        ids_to_remove = set(identifiers)
        current_ids = set(self._identifiers)
        if not ids_to_remove.issubset(current_ids):
            raise ValueError(
                f"Non-existent identifiers provided for removal: {ids_to_remove - current_ids}"
            )

        # Create new lists containing only the sites to keep
        new_identifiers: List[SiteIdentifier] = []
        new_coordinates: List[Coordinates] = []
        for ident, coord in zip(self._identifiers, self._coordinates):
            if ident not in ids_to_remove:
                new_identifiers.append(ident)
                new_coordinates.append(coord)

        # Replace old data with the new, filtered data
        self._identifiers = new_identifiers
        self._coordinates = new_coordinates

        # Rebuild index mappings
        self._indices = list(range(len(self._identifiers)))
        self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)}

        # Invalidate caches
        self._reset_computations()
        logger.info(
            f"{len(ids_to_remove)} sites removed. Lattice now has {self.num_sites} sites."
        )
|