tensorcircuit-nightly 1.3.0.dev20250815__py3-none-any.whl → 1.3.0.dev20250817__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tensorcircuit-nightly might be problematic. Click here for more details.
- tensorcircuit/__init__.py +1 -1
- tensorcircuit/backends/abstract_backend.py +113 -0
- tensorcircuit/backends/cupy_backend.py +3 -1
- tensorcircuit/backends/jax_backend.py +34 -2
- tensorcircuit/backends/numpy_backend.py +28 -1
- tensorcircuit/backends/pytorch_backend.py +36 -0
- tensorcircuit/backends/tensorflow_backend.py +38 -0
- tensorcircuit/templates/hamiltonians.py +16 -3
- tensorcircuit/templates/lattice.py +624 -335
- {tensorcircuit_nightly-1.3.0.dev20250815.dist-info → tensorcircuit_nightly-1.3.0.dev20250817.dist-info}/METADATA +1 -1
- {tensorcircuit_nightly-1.3.0.dev20250815.dist-info → tensorcircuit_nightly-1.3.0.dev20250817.dist-info}/RECORD +14 -14
- {tensorcircuit_nightly-1.3.0.dev20250815.dist-info → tensorcircuit_nightly-1.3.0.dev20250817.dist-info}/WHEEL +0 -0
- {tensorcircuit_nightly-1.3.0.dev20250815.dist-info → tensorcircuit_nightly-1.3.0.dev20250817.dist-info}/licenses/LICENSE +0 -0
- {tensorcircuit_nightly-1.3.0.dev20250815.dist-info → tensorcircuit_nightly-1.3.0.dev20250817.dist-info}/top_level.txt +0 -0
|
@@ -18,11 +18,12 @@ from typing import (
|
|
|
18
18
|
Set,
|
|
19
19
|
)
|
|
20
20
|
|
|
21
|
-
|
|
21
|
+
import itertools
|
|
22
|
+
import math
|
|
22
23
|
import numpy as np
|
|
23
|
-
|
|
24
24
|
from scipy.spatial import KDTree
|
|
25
|
-
|
|
25
|
+
|
|
26
|
+
from .. import backend
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
# This block resolves a name resolution issue for the static type checker (mypy).
|
|
@@ -41,9 +42,13 @@ if TYPE_CHECKING:
|
|
|
41
42
|
import matplotlib.axes
|
|
42
43
|
from mpl_toolkits.mplot3d import Axes3D
|
|
43
44
|
|
|
45
|
+
logger = logging.getLogger(__name__)
|
|
46
|
+
|
|
47
|
+
Tensor = Any
|
|
44
48
|
SiteIndex = int
|
|
45
49
|
SiteIdentifier = Hashable
|
|
46
|
-
Coordinates =
|
|
50
|
+
Coordinates = Tensor
|
|
51
|
+
|
|
47
52
|
NeighborMap = Dict[SiteIndex, List[SiteIndex]]
|
|
48
53
|
|
|
49
54
|
|
|
@@ -64,13 +69,27 @@ class AbstractLattice(abc.ABC):
|
|
|
64
69
|
"""Initializes the base lattice class."""
|
|
65
70
|
self._dimensionality = dimensionality
|
|
66
71
|
|
|
67
|
-
#
|
|
68
|
-
self._indices: List[SiteIndex] = []
|
|
69
|
-
self._identifiers: List[SiteIdentifier] =
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
72
|
+
# Core data structures for storing site information.
|
|
73
|
+
self._indices: List[SiteIndex] = [] # List of integer indices [0, 1, ..., N-1]
|
|
74
|
+
self._identifiers: List[SiteIdentifier] = (
|
|
75
|
+
[]
|
|
76
|
+
) # List of unique, hashable site identifiers
|
|
77
|
+
# Always initialize to an empty coordinate tensor with correct dimensionality
|
|
78
|
+
# so that type checkers know this is indexable and not Optional.
|
|
79
|
+
self._coordinates: Coordinates = backend.zeros((0, dimensionality))
|
|
80
|
+
|
|
81
|
+
# Mappings for efficient lookups.
|
|
82
|
+
self._ident_to_idx: Dict[SiteIdentifier, SiteIndex] = (
|
|
83
|
+
{}
|
|
84
|
+
) # Maps identifiers to indices
|
|
85
|
+
|
|
86
|
+
# Cached properties, computed on demand.
|
|
87
|
+
self._neighbor_maps: Dict[int, NeighborMap] = (
|
|
88
|
+
{}
|
|
89
|
+
) # Caches neighbor info for different k
|
|
90
|
+
self._distance_matrix: Optional[Coordinates] = (
|
|
91
|
+
None # Caches the full N x N distance matrix
|
|
92
|
+
)
|
|
74
93
|
|
|
75
94
|
@property
|
|
76
95
|
def num_sites(self) -> int:
|
|
@@ -95,7 +114,6 @@ class AbstractLattice(abc.ABC):
|
|
|
95
114
|
subsequent calls. This computation can be expensive for large lattices.
|
|
96
115
|
"""
|
|
97
116
|
if self._distance_matrix is None:
|
|
98
|
-
logger.info("Distance matrix not cached. Computing now...")
|
|
99
117
|
self._distance_matrix = self._compute_distance_matrix()
|
|
100
118
|
return self._distance_matrix
|
|
101
119
|
|
|
@@ -116,7 +134,8 @@ class AbstractLattice(abc.ABC):
|
|
|
116
134
|
:rtype: Coordinates
|
|
117
135
|
"""
|
|
118
136
|
self._validate_index(index)
|
|
119
|
-
|
|
137
|
+
coords = self._coordinates[index]
|
|
138
|
+
return coords
|
|
120
139
|
|
|
121
140
|
def get_identifier(self, index: SiteIndex) -> SiteIdentifier:
|
|
122
141
|
"""Gets the abstract identifier of a site by its integer index.
|
|
@@ -140,7 +159,8 @@ class AbstractLattice(abc.ABC):
|
|
|
140
159
|
:rtype: SiteIndex
|
|
141
160
|
"""
|
|
142
161
|
try:
|
|
143
|
-
|
|
162
|
+
index = self._ident_to_idx[identifier]
|
|
163
|
+
return index
|
|
144
164
|
except KeyError as e:
|
|
145
165
|
raise ValueError(
|
|
146
166
|
f"Identifier {identifier} not found in the lattice."
|
|
@@ -170,7 +190,7 @@ class AbstractLattice(abc.ABC):
|
|
|
170
190
|
idx = index_or_identifier
|
|
171
191
|
self._validate_index(idx)
|
|
172
192
|
return idx, self._identifiers[idx], self._coordinates[idx]
|
|
173
|
-
else:
|
|
193
|
+
else:
|
|
174
194
|
ident = index_or_identifier
|
|
175
195
|
idx = self.get_index(ident)
|
|
176
196
|
return idx, ident, self._coordinates[idx]
|
|
@@ -237,7 +257,6 @@ class AbstractLattice(abc.ABC):
|
|
|
237
257
|
)
|
|
238
258
|
self._build_neighbors(max_k=k)
|
|
239
259
|
|
|
240
|
-
# After attempting to build, check again. If still not found, return empty.
|
|
241
260
|
if k not in self._neighbor_maps:
|
|
242
261
|
return []
|
|
243
262
|
|
|
@@ -251,8 +270,28 @@ class AbstractLattice(abc.ABC):
|
|
|
251
270
|
pairs.append((i, j))
|
|
252
271
|
return sorted(pairs)
|
|
253
272
|
|
|
254
|
-
|
|
255
|
-
|
|
273
|
+
def get_all_pairs(self) -> List[Tuple[SiteIndex, SiteIndex]]:
|
|
274
|
+
"""
|
|
275
|
+
Returns a list of all unique pairs of site indices (i, j) where i < j.
|
|
276
|
+
|
|
277
|
+
This method provides all-to-all connectivity, useful for Hamiltonians
|
|
278
|
+
where every site interacts with every other site.
|
|
279
|
+
|
|
280
|
+
Note on Differentiability:
|
|
281
|
+
This method provides a static list of index pairs and is not differentiable
|
|
282
|
+
itself. However, it is designed to be used in combination with the fully
|
|
283
|
+
differentiable ``distance_matrix`` property. By using the pairs from this
|
|
284
|
+
method to index into the ``distance_matrix``, one can construct differentiable
|
|
285
|
+
objective functions based on all-pair interactions, effectively bypassing the
|
|
286
|
+
non-differentiable ``get_neighbor_pairs`` method for geometry optimization tasks.
|
|
287
|
+
|
|
288
|
+
:return: A list of tuples, where each tuple is a unique pair of site indices.
|
|
289
|
+
:rtype: List[Tuple[SiteIndex, SiteIndex]]
|
|
290
|
+
"""
|
|
291
|
+
if self.num_sites < 2:
|
|
292
|
+
return []
|
|
293
|
+
# Use itertools.combinations to efficiently generate all unique pairs (i, j) with i < j.
|
|
294
|
+
return sorted(list(itertools.combinations(range(self.num_sites), 2)))
|
|
256
295
|
|
|
257
296
|
@abc.abstractmethod
|
|
258
297
|
def _build_lattice(self, *args: Any, **kwargs: Any) -> None:
|
|
@@ -281,14 +320,24 @@ class AbstractLattice(abc.ABC):
|
|
|
281
320
|
"""
|
|
282
321
|
pass
|
|
283
322
|
|
|
284
|
-
@abc.abstractmethod
|
|
285
323
|
def _compute_distance_matrix(self) -> Coordinates:
|
|
286
324
|
"""
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
325
|
+
Default generic distance matrix computation (no periodic images).
|
|
326
|
+
|
|
327
|
+
Subclasses can override this when a specialized rule is required
|
|
328
|
+
(e.g., applying Minimum Image Convention for PBC in TILattice).
|
|
290
329
|
"""
|
|
291
|
-
|
|
330
|
+
# Handle empty lattices and trivial 1-site lattices
|
|
331
|
+
if self.num_sites == 0:
|
|
332
|
+
return backend.zeros((0, 0))
|
|
333
|
+
|
|
334
|
+
# Vectorized pairwise Euclidean distances
|
|
335
|
+
all_coords = self._coordinates
|
|
336
|
+
displacements = backend.expand_dims(all_coords, 1) - backend.expand_dims(
|
|
337
|
+
all_coords, 0
|
|
338
|
+
)
|
|
339
|
+
dist_matrix_sq = backend.sum(displacements**2, axis=-1)
|
|
340
|
+
return backend.sqrt(dist_matrix_sq)
|
|
292
341
|
|
|
293
342
|
def show(
|
|
294
343
|
self,
|
|
@@ -328,13 +377,14 @@ class AbstractLattice(abc.ABC):
|
|
|
328
377
|
try:
|
|
329
378
|
import matplotlib.pyplot as plt
|
|
330
379
|
except ImportError:
|
|
331
|
-
logger.
|
|
380
|
+
logger.warning(
|
|
332
381
|
"Matplotlib is required for visualization. "
|
|
333
382
|
"Please install it using 'pip install matplotlib'."
|
|
334
383
|
)
|
|
335
384
|
return
|
|
336
385
|
|
|
337
|
-
#
|
|
386
|
+
# Flag to track if the Matplotlib figure was created by this method.
|
|
387
|
+
# This prevents calling plt.show() if the user provided their own Axes.
|
|
338
388
|
fig_created_internally = False
|
|
339
389
|
|
|
340
390
|
if self.num_sites == 0:
|
|
@@ -347,7 +397,7 @@ class AbstractLattice(abc.ABC):
|
|
|
347
397
|
return
|
|
348
398
|
|
|
349
399
|
if ax is None:
|
|
350
|
-
#
|
|
400
|
+
# If no Axes object is provided, create a new figure and axes.
|
|
351
401
|
fig_created_internally = True
|
|
352
402
|
if self.dimensionality == 3:
|
|
353
403
|
fig = plt.figure(figsize=(8, 8))
|
|
@@ -358,6 +408,7 @@ class AbstractLattice(abc.ABC):
|
|
|
358
408
|
fig = ax.figure # type: ignore
|
|
359
409
|
|
|
360
410
|
coords = np.array(self._coordinates)
|
|
411
|
+
# Prepare arguments for the scatter plot, allowing user overrides.
|
|
361
412
|
scatter_args = {"s": 100, "zorder": 2}
|
|
362
413
|
scatter_args.update(kwargs)
|
|
363
414
|
if self.dimensionality == 1:
|
|
@@ -371,12 +422,11 @@ class AbstractLattice(abc.ABC):
|
|
|
371
422
|
if show_indices or show_identifiers:
|
|
372
423
|
for i in range(self.num_sites):
|
|
373
424
|
label = str(self._identifiers[i]) if show_identifiers else str(i)
|
|
425
|
+
# Calculate a small offset for placing text labels to avoid overlap with sites.
|
|
374
426
|
offset = (
|
|
375
427
|
0.02 * np.max(np.ptp(coords, axis=0)) if coords.size > 0 else 0.1
|
|
376
428
|
)
|
|
377
429
|
|
|
378
|
-
# Robust Logic: Decide plotting strategy based on known dimensionality.
|
|
379
|
-
|
|
380
430
|
if self.dimensionality == 1:
|
|
381
431
|
ax.text(coords[i, 0], offset, label, fontsize=9, ha="center")
|
|
382
432
|
elif self.dimensionality == 2:
|
|
@@ -398,8 +448,6 @@ class AbstractLattice(abc.ABC):
|
|
|
398
448
|
zorder=3,
|
|
399
449
|
)
|
|
400
450
|
|
|
401
|
-
# Note: No 'else' needed as we already check dimensionality at the start.
|
|
402
|
-
|
|
403
451
|
if show_bonds_k is not None:
|
|
404
452
|
if show_bonds_k not in self._neighbor_maps:
|
|
405
453
|
logger.warning(
|
|
@@ -433,7 +481,7 @@ class AbstractLattice(abc.ABC):
|
|
|
433
481
|
if self.dimensionality == 1: # type: ignore
|
|
434
482
|
|
|
435
483
|
ax.plot([p1[0], p2[0]], [0, 0], **plot_bond_kwargs) # type: ignore
|
|
436
|
-
else:
|
|
484
|
+
else:
|
|
437
485
|
ax.plot([p1[0], p2[0]], [p1[1], p2[1]], **plot_bond_kwargs) # type: ignore
|
|
438
486
|
|
|
439
487
|
except ValueError as e:
|
|
@@ -449,7 +497,7 @@ class AbstractLattice(abc.ABC):
|
|
|
449
497
|
ax.set_zlabel("z")
|
|
450
498
|
ax.grid(True)
|
|
451
499
|
|
|
452
|
-
#
|
|
500
|
+
# Display the plot only if the figure was created within this function.
|
|
453
501
|
if fig_created_internally:
|
|
454
502
|
plt.show()
|
|
455
503
|
|
|
@@ -475,26 +523,28 @@ class AbstractLattice(abc.ABC):
|
|
|
475
523
|
:return: A sorted list of squared distances representing the shells.
|
|
476
524
|
:rtype: List[float]
|
|
477
525
|
"""
|
|
526
|
+
# A small threshold to filter out zero distances (site to itself).
|
|
478
527
|
ZERO_THRESHOLD_SQ = 1e-12
|
|
479
528
|
|
|
480
|
-
all_distances_sq =
|
|
529
|
+
all_distances_sq = backend.convert_to_tensor(all_distances_sq)
|
|
481
530
|
# Now, the .size call below is guaranteed to be safe.
|
|
482
|
-
if all_distances_sq
|
|
531
|
+
if backend.sizen(all_distances_sq) == 0:
|
|
483
532
|
return []
|
|
484
533
|
|
|
485
|
-
|
|
534
|
+
# Filter out self-distances and sort the remaining squared distances.
|
|
535
|
+
sorted_dist = backend.sort(
|
|
536
|
+
all_distances_sq[all_distances_sq > ZERO_THRESHOLD_SQ]
|
|
537
|
+
)
|
|
486
538
|
|
|
487
|
-
if sorted_dist
|
|
539
|
+
if backend.sizen(sorted_dist) == 0:
|
|
488
540
|
return []
|
|
489
541
|
|
|
490
|
-
# Identify shells using the user-provided tolerance.
|
|
491
542
|
dist_shells = [sorted_dist[0]]
|
|
492
543
|
|
|
493
544
|
for d_sq in sorted_dist[1:]:
|
|
494
545
|
if len(dist_shells) >= max_k:
|
|
495
546
|
break
|
|
496
|
-
|
|
497
|
-
if d_sq > dist_shells[-1] + tol**2:
|
|
547
|
+
if backend.sqrt(d_sq) - backend.sqrt(dist_shells[-1]) > tol:
|
|
498
548
|
dist_shells.append(d_sq)
|
|
499
549
|
|
|
500
550
|
return dist_shells
|
|
@@ -503,11 +553,9 @@ class AbstractLattice(abc.ABC):
|
|
|
503
553
|
self, max_k: int = 2, tol: float = 1e-6
|
|
504
554
|
) -> None:
|
|
505
555
|
"""A generic, distance-based neighbor finding method.
|
|
506
|
-
|
|
507
556
|
This method calculates the full N x N distance matrix to find neighbor
|
|
508
557
|
shells. It is computationally expensive for large N (O(N^2)) and is
|
|
509
558
|
best suited for non-periodic or custom-defined lattices.
|
|
510
|
-
|
|
511
559
|
:param max_k: The maximum number of neighbor shells to
|
|
512
560
|
calculate. Defaults to 2.
|
|
513
561
|
:type max_k: int, optional
|
|
@@ -518,26 +566,55 @@ class AbstractLattice(abc.ABC):
|
|
|
518
566
|
if self.num_sites < 2:
|
|
519
567
|
return
|
|
520
568
|
|
|
521
|
-
all_coords =
|
|
522
|
-
|
|
523
|
-
|
|
569
|
+
all_coords = self._coordinates
|
|
570
|
+
# Vectorized computation of the squared distance matrix:
|
|
571
|
+
# (N, 1, D) - (1, N, D) -> (N, N, D) -> (N, N)
|
|
572
|
+
displacements = backend.expand_dims(all_coords, 1) - backend.expand_dims(
|
|
573
|
+
all_coords, 0
|
|
524
574
|
)
|
|
575
|
+
dist_matrix_sq = backend.sum(displacements**2, axis=-1)
|
|
525
576
|
|
|
526
|
-
|
|
577
|
+
# Flatten the matrix to a list of all squared distances to identify shells.
|
|
578
|
+
all_distances_sq = backend.reshape(dist_matrix_sq, [-1])
|
|
527
579
|
dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol)
|
|
528
580
|
|
|
529
|
-
self._neighbor_maps =
|
|
581
|
+
self._neighbor_maps = self._build_neighbor_map_from_distances(
|
|
582
|
+
dist_matrix_sq, dist_shells_sq, tol
|
|
583
|
+
)
|
|
584
|
+
self._distance_matrix = backend.sqrt(dist_matrix_sq)
|
|
585
|
+
|
|
586
|
+
def _build_neighbor_map_from_distances(
|
|
587
|
+
self,
|
|
588
|
+
dist_matrix_sq: Coordinates,
|
|
589
|
+
dist_shells_sq: List[float],
|
|
590
|
+
tol: float = 1e-6,
|
|
591
|
+
) -> Dict[int, NeighborMap]:
|
|
592
|
+
"""
|
|
593
|
+
Builds a neighbor map from a squared distance matrix and identified shells.
|
|
594
|
+
This is a generic helper function to reduce code duplication.
|
|
595
|
+
"""
|
|
596
|
+
neighbor_maps: Dict[int, NeighborMap] = {
|
|
597
|
+
k: {} for k in range(1, len(dist_shells_sq) + 1)
|
|
598
|
+
}
|
|
530
599
|
for k_idx, target_d_sq in enumerate(dist_shells_sq):
|
|
531
600
|
k = k_idx + 1
|
|
532
601
|
current_k_map: Dict[int, List[int]] = {}
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
602
|
+
# For each shell, find all pairs of sites (i, j) with that distance.
|
|
603
|
+
is_close_matrix = backend.abs(dist_matrix_sq - target_d_sq) < tol
|
|
604
|
+
rows, cols = backend.where(is_close_matrix)
|
|
605
|
+
|
|
606
|
+
for i, j in zip(backend.numpy(rows), backend.numpy(cols)):
|
|
607
|
+
if i == j:
|
|
608
|
+
continue
|
|
609
|
+
if i not in current_k_map:
|
|
610
|
+
current_k_map[i] = []
|
|
611
|
+
current_k_map[i].append(j)
|
|
612
|
+
|
|
613
|
+
for i in current_k_map:
|
|
614
|
+
current_k_map[i].sort()
|
|
615
|
+
|
|
616
|
+
neighbor_maps[k] = current_k_map
|
|
617
|
+
return neighbor_maps
|
|
541
618
|
|
|
542
619
|
|
|
543
620
|
class TILattice(AbstractLattice):
|
|
@@ -588,150 +665,197 @@ class TILattice(AbstractLattice):
|
|
|
588
665
|
):
|
|
589
666
|
"""Initializes the Translationally Invariant Lattice."""
|
|
590
667
|
super().__init__(dimensionality)
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
self.
|
|
601
|
-
|
|
602
|
-
|
|
668
|
+
|
|
669
|
+
self.lattice_vectors = backend.convert_to_tensor(lattice_vectors)
|
|
670
|
+
self.basis_coords = backend.convert_to_tensor(basis_coords)
|
|
671
|
+
|
|
672
|
+
if self.lattice_vectors.shape != (dimensionality, dimensionality):
|
|
673
|
+
raise ValueError(
|
|
674
|
+
f"Lattice vectors shape {self.lattice_vectors.shape} does not match "
|
|
675
|
+
f"expected ({dimensionality}, {dimensionality})"
|
|
676
|
+
)
|
|
677
|
+
if self.basis_coords.shape[1] != dimensionality:
|
|
678
|
+
raise ValueError(
|
|
679
|
+
f"Basis coordinates dimension {self.basis_coords.shape[1]} does not "
|
|
680
|
+
f"match lattice dimensionality {dimensionality}"
|
|
681
|
+
)
|
|
682
|
+
if len(size) != dimensionality:
|
|
683
|
+
raise ValueError(
|
|
684
|
+
f"Size tuple length {len(size)} does not match dimensionality {dimensionality}"
|
|
685
|
+
)
|
|
686
|
+
|
|
687
|
+
self.num_basis = self.basis_coords.shape[0]
|
|
603
688
|
self.size = size
|
|
604
689
|
if isinstance(pbc, bool):
|
|
605
690
|
self.pbc = tuple([pbc] * dimensionality)
|
|
606
691
|
else:
|
|
607
|
-
|
|
692
|
+
if len(pbc) != dimensionality:
|
|
693
|
+
raise ValueError(
|
|
694
|
+
f"PBC tuple length {len(pbc)} does not match dimensionality {dimensionality}"
|
|
695
|
+
)
|
|
608
696
|
self.pbc = tuple(pbc)
|
|
609
697
|
|
|
610
|
-
# Build the lattice sites and their neighbor relationships
|
|
611
698
|
self._build_lattice()
|
|
612
699
|
if precompute_neighbors is not None and precompute_neighbors > 0:
|
|
613
700
|
logger.info(f"Pre-computing neighbors up to k={precompute_neighbors}...")
|
|
614
701
|
self._build_neighbors(max_k=precompute_neighbors)
|
|
615
702
|
|
|
616
703
|
def _build_lattice(self) -> None:
|
|
617
|
-
"""Generates all site information for the periodic lattice.
|
|
618
|
-
|
|
619
|
-
This method iterates through each unit cell defined by `self.size`,
|
|
620
|
-
and for each unit cell, it iterates through all basis sites. It then
|
|
621
|
-
calculates the real-space coordinates and creates a unique identifier
|
|
622
|
-
for each site, populating the internal lattice data structures.
|
|
623
704
|
"""
|
|
624
|
-
|
|
705
|
+
Generates all site information for the periodic lattice in a vectorized manner.
|
|
706
|
+
"""
|
|
707
|
+
ranges = [backend.arange(s) for s in self.size]
|
|
625
708
|
|
|
626
|
-
#
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
709
|
+
# Generate a grid of all integer unit cell coordinates.
|
|
710
|
+
grid = backend.meshgrid(*ranges, indexing="ij")
|
|
711
|
+
all_cell_coords = backend.reshape(
|
|
712
|
+
backend.stack(grid, axis=-1), (-1, self.dimensionality)
|
|
713
|
+
)
|
|
631
714
|
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
715
|
+
all_cell_coords = backend.cast(all_cell_coords, self.lattice_vectors.dtype)
|
|
716
|
+
|
|
717
|
+
cell_vectors = backend.tensordot(
|
|
718
|
+
all_cell_coords, self.lattice_vectors, axes=[[1], [0]]
|
|
719
|
+
)
|
|
720
|
+
|
|
721
|
+
cell_vectors = backend.cast(cell_vectors, self.basis_coords.dtype)
|
|
635
722
|
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
723
|
+
# Combine cell vectors with basis coordinates to get all site positions
|
|
724
|
+
# via broadcasting: (num_cells, 1, D) + (1, num_basis, D) -> (num_cells, num_basis, D)
|
|
725
|
+
all_coords = backend.expand_dims(cell_vectors, 1) + backend.expand_dims(
|
|
726
|
+
self.basis_coords, 0
|
|
727
|
+
)
|
|
728
|
+
|
|
729
|
+
self._coordinates = backend.reshape(all_coords, (-1, self.dimensionality))
|
|
730
|
+
|
|
731
|
+
self._indices = []
|
|
732
|
+
self._identifiers = []
|
|
733
|
+
self._ident_to_idx = {}
|
|
734
|
+
current_index = 0
|
|
640
735
|
|
|
641
|
-
|
|
736
|
+
# Generate integer indices and tuple-based identifiers for all sites.
|
|
737
|
+
# e.g., identifier = (uc_x, uc_y, basis_idx)
|
|
738
|
+
size_ranges = [range(s) for s in self.size]
|
|
739
|
+
for cell_coord_tuple in itertools.product(*size_ranges):
|
|
740
|
+
for basis_index in range(self.num_basis):
|
|
741
|
+
identifier = cell_coord_tuple + (basis_index,)
|
|
642
742
|
self._indices.append(current_index)
|
|
643
743
|
self._identifiers.append(identifier)
|
|
644
|
-
self._coordinates.append(coord)
|
|
645
744
|
self._ident_to_idx[identifier] = current_index
|
|
646
745
|
current_index += 1
|
|
647
746
|
|
|
648
|
-
def
|
|
747
|
+
def _get_distance_matrix_with_mic_vectorized(self) -> Coordinates:
|
|
649
748
|
"""
|
|
650
|
-
Computes the full N x N distance matrix
|
|
651
|
-
Minimum Image Convention (MIC) for
|
|
749
|
+
Computes the full N x N distance matrix using a fully vectorized approach
|
|
750
|
+
that correctly applies the Minimum Image Convention (MIC) for periodic
|
|
751
|
+
boundary conditions.
|
|
752
|
+
|
|
753
|
+
This method uses full vectorization for optimal performance and compatibility
|
|
754
|
+
with JIT compilation frameworks like JAX. The implementation processes all
|
|
755
|
+
site pairs simultaneously rather than iterating row-by-row, which provides:
|
|
756
|
+
|
|
757
|
+
- Better performance through vectorized operations
|
|
758
|
+
- Full compatibility with automatic differentiation
|
|
759
|
+
- JIT compilation support (e.g., JAX, TensorFlow)
|
|
760
|
+
- Consistent tensor operations throughout
|
|
761
|
+
|
|
762
|
+
The trade-off is higher memory usage compared to iterative approaches,
|
|
763
|
+
as it computes all pairwise distances simultaneously. For very large
|
|
764
|
+
lattices (N > 10^4 sites), memory usage scales as O(N^2).
|
|
765
|
+
|
|
766
|
+
:return: Distance matrix with shape (N, N) where entry (i,j) is the
|
|
767
|
+
minimum distance between sites i and j under periodic boundary conditions.
|
|
768
|
+
:rtype: Coordinates
|
|
652
769
|
"""
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
#
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
if
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
pbc_translations = all_shifts @ pbc_system_vectors
|
|
671
|
-
translations.extend(pbc_translations)
|
|
672
|
-
|
|
673
|
-
translations_arr = np.array(translations, dtype=float)
|
|
674
|
-
|
|
675
|
-
# Calculate the distance matrix applying MIC
|
|
676
|
-
dist_matrix_sq = np.full((self.num_sites, self.num_sites), np.inf, dtype=float)
|
|
677
|
-
for i in range(self.num_sites):
|
|
678
|
-
displacements = all_coords - all_coords[i]
|
|
679
|
-
image_displacements = (
|
|
680
|
-
displacements[:, np.newaxis, :] - translations_arr[np.newaxis, :, :]
|
|
681
|
-
)
|
|
682
|
-
image_d_sq = np.sum(image_displacements**2, axis=2)
|
|
683
|
-
dist_matrix_sq[i, :] = np.min(image_d_sq, axis=1)
|
|
770
|
+
# Ensure dtype consistency across backends (especially torch) by explicitly
|
|
771
|
+
# casting size and lattice_vectors to the same floating dtype used internally.
|
|
772
|
+
# Strategy: prefer existing lattice_vectors dtype; if it's an unusual dtype,
|
|
773
|
+
# fall back to float32 to avoid mixed-precision issues in vectorized ops.
|
|
774
|
+
# Note: `self.lattice_vectors` is always created via `backend.convert_to_tensor`
|
|
775
|
+
# in __init__, so `backend.dtype(...)` is reliable here and doesn't need try/except.
|
|
776
|
+
target_dt = str(backend.dtype(self.lattice_vectors)) # type: ignore
|
|
777
|
+
if target_dt not in ("float32", "float64"):
|
|
778
|
+
# fallback for unusual dtypes
|
|
779
|
+
target_dt = "float32"
|
|
780
|
+
|
|
781
|
+
size_arr = backend.cast(backend.convert_to_tensor(self.size), target_dt)
|
|
782
|
+
lattice_vecs = backend.cast(
|
|
783
|
+
backend.convert_to_tensor(self.lattice_vectors), target_dt
|
|
784
|
+
)
|
|
785
|
+
system_vectors = lattice_vecs * backend.expand_dims(size_arr, axis=1)
|
|
684
786
|
|
|
685
|
-
|
|
787
|
+
pbc_mask = backend.convert_to_tensor(self.pbc)
|
|
788
|
+
|
|
789
|
+
# Generate all 3^d possible image shifts (-1, 0, 1) for all dimensions
|
|
790
|
+
shift_options = [
|
|
791
|
+
backend.convert_to_tensor([-1.0, 0.0, 1.0])
|
|
792
|
+
] * self.dimensionality
|
|
793
|
+
shifts_grid = backend.meshgrid(*shift_options, indexing="ij")
|
|
794
|
+
all_shifts = backend.reshape(
|
|
795
|
+
backend.stack(shifts_grid, axis=-1), (-1, self.dimensionality)
|
|
796
|
+
)
|
|
797
|
+
|
|
798
|
+
# Only apply shifts to periodic dimensions
|
|
799
|
+
masked_shifts = all_shifts * backend.cast(pbc_mask, all_shifts.dtype)
|
|
800
|
+
|
|
801
|
+
# Calculate all translation vectors due to PBC
|
|
802
|
+
translations_arr = backend.tensordot(
|
|
803
|
+
masked_shifts, system_vectors, axes=[[1], [0]]
|
|
804
|
+
)
|
|
805
|
+
|
|
806
|
+
# Vectorized computation of all displacements between any two sites
|
|
807
|
+
# Shape: (N, 1, D) - (1, N, D) -> (N, N, D)
|
|
808
|
+
displacements = backend.expand_dims(self._coordinates, 1) - backend.expand_dims(
|
|
809
|
+
self._coordinates, 0
|
|
810
|
+
)
|
|
811
|
+
|
|
812
|
+
# Consider all periodic images for each displacement
|
|
813
|
+
# Shape: (N, N, 1, D) - (1, 1, num_translations, D) -> (N, N, num_translations, D)
|
|
814
|
+
image_displacements = backend.expand_dims(
|
|
815
|
+
displacements, 2
|
|
816
|
+
) - backend.expand_dims(backend.expand_dims(translations_arr, 0), 0)
|
|
817
|
+
|
|
818
|
+
# Sum of squares for distances
|
|
819
|
+
image_d_sq = backend.sum(image_displacements**2, axis=3)
|
|
820
|
+
|
|
821
|
+
# Find the minimum distance among all images (Minimum Image Convention)
|
|
822
|
+
min_dist_sq = backend.min(image_d_sq, axis=2)
|
|
823
|
+
|
|
824
|
+
safe_dist_matrix_sq = backend.where(min_dist_sq > 0, min_dist_sq, 0.0)
|
|
825
|
+
return backend.sqrt(safe_dist_matrix_sq)
|
|
686
826
|
|
|
687
827
|
def _build_neighbors(self, max_k: int = 2, **kwargs: Any) -> None:
|
|
688
828
|
"""Calculates neighbor relationships for the periodic lattice.
|
|
689
829
|
|
|
690
|
-
This method
|
|
691
|
-
distance matrix
|
|
692
|
-
periodic
|
|
693
|
-
(
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
the specified `max_k` and populates the neighbor maps. The computed
|
|
697
|
-
distance matrix is then cached for future use.
|
|
830
|
+
This method computes neighbor information by first calculating the full
|
|
831
|
+
distance matrix using the Minimum Image Convention (MIC) to correctly
|
|
832
|
+
handle periodic boundary conditions. It then identifies unique distance
|
|
833
|
+
shells (e.g., nearest, next-nearest) and populates the neighbor maps
|
|
834
|
+
accordingly. This approach is general and works for any periodic lattice
|
|
835
|
+
geometry defined by the TILattice class.
|
|
698
836
|
|
|
699
|
-
:param max_k: The maximum
|
|
700
|
-
|
|
837
|
+
:param max_k: The maximum order of neighbors to compute (e.g., k=1 for
|
|
838
|
+
nearest neighbors, k=2 for next-nearest, etc.). Defaults to 2.
|
|
701
839
|
:type max_k: int, optional
|
|
702
|
-
:param
|
|
703
|
-
|
|
704
|
-
|
|
840
|
+
:param kwargs: Additional keyword arguments. May include:
|
|
841
|
+
- ``tol`` (float): The numerical tolerance used to determine if two
|
|
842
|
+
distances are equal when identifying shells. Defaults to 1e-6.
|
|
705
843
|
"""
|
|
706
844
|
tol = kwargs.get("tol", 1e-6)
|
|
707
|
-
dist_matrix = self.
|
|
845
|
+
dist_matrix = self._get_distance_matrix_with_mic_vectorized()
|
|
708
846
|
dist_matrix_sq = dist_matrix**2
|
|
709
847
|
self._distance_matrix = dist_matrix
|
|
710
|
-
all_distances_sq =
|
|
848
|
+
all_distances_sq = backend.reshape(dist_matrix_sq, [-1])
|
|
711
849
|
dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol)
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
k = k_idx + 1
|
|
716
|
-
current_k_map: Dict[int, List[int]] = {}
|
|
717
|
-
match_indices = np.where(
|
|
718
|
-
np.isclose(dist_matrix_sq, target_d_sq, rtol=0, atol=tol**2)
|
|
719
|
-
)
|
|
720
|
-
for i, j in zip(*match_indices):
|
|
721
|
-
if i == j:
|
|
722
|
-
continue
|
|
723
|
-
if i not in current_k_map:
|
|
724
|
-
current_k_map[i] = []
|
|
725
|
-
current_k_map[i].append(j)
|
|
726
|
-
|
|
727
|
-
for i in current_k_map:
|
|
728
|
-
current_k_map[i].sort()
|
|
729
|
-
|
|
730
|
-
self._neighbor_maps[k] = current_k_map
|
|
850
|
+
self._neighbor_maps = self._build_neighbor_map_from_distances(
|
|
851
|
+
dist_matrix_sq, dist_shells_sq, tol
|
|
852
|
+
)
|
|
731
853
|
|
|
732
854
|
def _compute_distance_matrix(self) -> Coordinates:
|
|
733
855
|
"""Computes the distance matrix using the Minimum Image Convention."""
|
|
734
|
-
|
|
856
|
+
if self.num_sites == 0:
|
|
857
|
+
return backend.zeros((0, 0))
|
|
858
|
+
return self._get_distance_matrix_with_mic_vectorized()
|
|
735
859
|
|
|
736
860
|
|
|
737
861
|
class SquareLattice(TILattice):
|
|
@@ -759,20 +883,24 @@ class SquareLattice(TILattice):
|
|
|
759
883
|
def __init__(
|
|
760
884
|
self,
|
|
761
885
|
size: Tuple[int, int],
|
|
762
|
-
lattice_constant: float = 1.0,
|
|
886
|
+
lattice_constant: Union[float, Any] = 1.0,
|
|
763
887
|
pbc: Union[bool, Tuple[bool, bool]] = True,
|
|
764
888
|
precompute_neighbors: Optional[int] = None,
|
|
765
889
|
):
|
|
766
890
|
"""Initializes the SquareLattice."""
|
|
767
891
|
dimensionality = 2
|
|
892
|
+
# Define orthogonal lattice vectors for a square.
|
|
893
|
+
# Avoid mixing Python floats with backend Tensors (TF would error),
|
|
894
|
+
# so first convert inputs to tensors of a unified dtype, then stack.
|
|
895
|
+
lc = backend.convert_to_tensor(lattice_constant)
|
|
896
|
+
dt = backend.dtype(lc)
|
|
897
|
+
z = backend.cast(backend.convert_to_tensor(0.0), dt)
|
|
898
|
+
row1 = backend.stack([lc, z])
|
|
899
|
+
row2 = backend.stack([z, lc])
|
|
900
|
+
lattice_vectors = backend.stack([row1, row2])
|
|
901
|
+
# A square lattice is a Bravais lattice, so it has a single-site basis.
|
|
902
|
+
basis_coords = backend.stack([backend.stack([z, z])])
|
|
768
903
|
|
|
769
|
-
# Define lattice vectors for a square lattice
|
|
770
|
-
lattice_vectors = np.array([[lattice_constant, 0.0], [0.0, lattice_constant]])
|
|
771
|
-
|
|
772
|
-
# A square lattice has a single site in its basis
|
|
773
|
-
basis_coords = np.array([[0.0, 0.0]])
|
|
774
|
-
|
|
775
|
-
# Call the parent TILattice constructor with these parameters
|
|
776
904
|
super().__init__(
|
|
777
905
|
dimensionality=dimensionality,
|
|
778
906
|
lattice_vectors=lattice_vectors,
|
|
@@ -808,19 +936,28 @@ class HoneycombLattice(TILattice):
|
|
|
808
936
|
def __init__(
|
|
809
937
|
self,
|
|
810
938
|
size: Tuple[int, int],
|
|
811
|
-
lattice_constant: float = 1.0,
|
|
939
|
+
lattice_constant: Union[float, Any] = 1.0,
|
|
812
940
|
pbc: Union[bool, Tuple[bool, bool]] = True,
|
|
813
941
|
precompute_neighbors: Optional[int] = None,
|
|
814
942
|
):
|
|
815
943
|
"""Initializes the HoneycombLattice."""
|
|
816
944
|
dimensionality = 2
|
|
817
945
|
a = lattice_constant
|
|
946
|
+
a_t = backend.convert_to_tensor(a)
|
|
947
|
+
zero = a_t * 0.0
|
|
818
948
|
|
|
819
|
-
# Define the primitive lattice vectors for the underlying triangular lattice
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
949
|
+
# Define the two primitive lattice vectors for the underlying triangular Bravais lattice.
|
|
950
|
+
rt3_over_2 = math.sqrt(3.0) / 2.0
|
|
951
|
+
lattice_vectors = backend.stack(
|
|
952
|
+
[
|
|
953
|
+
backend.stack([a_t * 1.5, a_t * rt3_over_2]),
|
|
954
|
+
backend.stack([a_t * 1.5, -a_t * rt3_over_2]),
|
|
955
|
+
]
|
|
956
|
+
)
|
|
957
|
+
# Define the two basis sites (A and B) within the unit cell.
|
|
958
|
+
basis_coords = backend.stack(
|
|
959
|
+
[backend.stack([zero, zero]), backend.stack([a_t * 1.0, zero])]
|
|
960
|
+
)
|
|
824
961
|
|
|
825
962
|
super().__init__(
|
|
826
963
|
dimensionality=dimensionality,
|
|
@@ -855,19 +992,30 @@ class TriangularLattice(TILattice):
|
|
|
855
992
|
def __init__(
|
|
856
993
|
self,
|
|
857
994
|
size: Tuple[int, int],
|
|
858
|
-
lattice_constant: float = 1.0,
|
|
995
|
+
lattice_constant: Union[float, Any] = 1.0,
|
|
859
996
|
pbc: Union[bool, Tuple[bool, bool]] = True,
|
|
860
997
|
precompute_neighbors: Optional[int] = None,
|
|
861
998
|
):
|
|
862
999
|
"""Initializes the TriangularLattice."""
|
|
863
1000
|
dimensionality = 2
|
|
864
1001
|
a = lattice_constant
|
|
1002
|
+
a_t = backend.convert_to_tensor(a)
|
|
1003
|
+
zero = a_t * 0.0
|
|
865
1004
|
|
|
866
|
-
# Define the primitive lattice vectors for a triangular lattice
|
|
867
|
-
lattice_vectors =
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
1005
|
+
# Define the primitive lattice vectors for a triangular lattice.
|
|
1006
|
+
lattice_vectors = backend.stack(
|
|
1007
|
+
[
|
|
1008
|
+
backend.stack([a_t * 1.0, zero]),
|
|
1009
|
+
backend.stack(
|
|
1010
|
+
[
|
|
1011
|
+
a_t * 0.5,
|
|
1012
|
+
a_t * backend.sqrt(backend.convert_to_tensor(3.0)) / 2.0,
|
|
1013
|
+
]
|
|
1014
|
+
),
|
|
1015
|
+
]
|
|
1016
|
+
)
|
|
1017
|
+
# A triangular lattice is a Bravais lattice with a single-site basis.
|
|
1018
|
+
basis_coords = backend.stack([backend.stack([zero, zero])])
|
|
871
1019
|
|
|
872
1020
|
super().__init__(
|
|
873
1021
|
dimensionality=dimensionality,
|
|
@@ -896,13 +1044,18 @@ class ChainLattice(TILattice):
|
|
|
896
1044
|
def __init__(
|
|
897
1045
|
self,
|
|
898
1046
|
size: Tuple[int],
|
|
899
|
-
lattice_constant: float = 1.0,
|
|
1047
|
+
lattice_constant: Union[float, Any] = 1.0,
|
|
900
1048
|
pbc: bool = True,
|
|
901
1049
|
precompute_neighbors: Optional[int] = None,
|
|
902
1050
|
):
|
|
903
1051
|
dimensionality = 1
|
|
904
|
-
|
|
905
|
-
|
|
1052
|
+
# The lattice vector is just the lattice constant along one dimension.
|
|
1053
|
+
lc = backend.convert_to_tensor(lattice_constant)
|
|
1054
|
+
lattice_vectors = backend.stack([backend.stack([lc])])
|
|
1055
|
+
# A simple chain is a Bravais lattice with a single-site basis.
|
|
1056
|
+
zero = lc * 0.0
|
|
1057
|
+
basis_coords = backend.stack([backend.stack([zero])])
|
|
1058
|
+
|
|
906
1059
|
super().__init__(
|
|
907
1060
|
dimensionality=dimensionality,
|
|
908
1061
|
lattice_vectors=lattice_vectors,
|
|
@@ -934,15 +1087,17 @@ class DimerizedChainLattice(TILattice):
|
|
|
934
1087
|
def __init__(
|
|
935
1088
|
self,
|
|
936
1089
|
size: Tuple[int],
|
|
937
|
-
lattice_constant: float = 1.0,
|
|
1090
|
+
lattice_constant: Union[float, Any] = 1.0,
|
|
938
1091
|
pbc: bool = True,
|
|
939
1092
|
precompute_neighbors: Optional[int] = None,
|
|
940
1093
|
):
|
|
941
1094
|
dimensionality = 1
|
|
942
|
-
# The unit cell
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
1095
|
+
# The unit cell is twice the bond length, as it contains two sites.
|
|
1096
|
+
lc = backend.convert_to_tensor(lattice_constant)
|
|
1097
|
+
lattice_vectors = backend.stack([backend.stack([2 * lc])])
|
|
1098
|
+
# Two basis sites (A and B) separated by the bond length.
|
|
1099
|
+
zero = lc * 0.0
|
|
1100
|
+
basis_coords = backend.stack([backend.stack([zero]), backend.stack([lc])])
|
|
946
1101
|
|
|
947
1102
|
super().__init__(
|
|
948
1103
|
dimensionality=dimensionality,
|
|
@@ -975,14 +1130,22 @@ class RectangularLattice(TILattice):
|
|
|
975
1130
|
def __init__(
|
|
976
1131
|
self,
|
|
977
1132
|
size: Tuple[int, int],
|
|
978
|
-
lattice_constants: Tuple[float, float] = (1.0, 1.0),
|
|
1133
|
+
lattice_constants: Union[Tuple[float, float], Any] = (1.0, 1.0),
|
|
979
1134
|
pbc: Union[bool, Tuple[bool, bool]] = True,
|
|
980
1135
|
precompute_neighbors: Optional[int] = None,
|
|
981
1136
|
):
|
|
982
1137
|
dimensionality = 2
|
|
983
1138
|
ax, ay = lattice_constants
|
|
984
|
-
|
|
985
|
-
|
|
1139
|
+
ax_t = backend.convert_to_tensor(ax)
|
|
1140
|
+
dt = backend.dtype(ax_t)
|
|
1141
|
+
ay_t = backend.cast(backend.convert_to_tensor(ay), dt)
|
|
1142
|
+
z = backend.cast(backend.convert_to_tensor(0.0), dt)
|
|
1143
|
+
# Orthogonal lattice vectors with potentially different lengths.
|
|
1144
|
+
row1 = backend.stack([ax_t, z])
|
|
1145
|
+
row2 = backend.stack([z, ay_t])
|
|
1146
|
+
lattice_vectors = backend.stack([row1, row2])
|
|
1147
|
+
# A rectangular lattice is a Bravais lattice with a single-site basis.
|
|
1148
|
+
basis_coords = backend.stack([backend.stack([z, z])])
|
|
986
1149
|
|
|
987
1150
|
super().__init__(
|
|
988
1151
|
dimensionality=dimensionality,
|
|
@@ -1013,16 +1176,26 @@ class CheckerboardLattice(TILattice):
|
|
|
1013
1176
|
def __init__(
|
|
1014
1177
|
self,
|
|
1015
1178
|
size: Tuple[int, int],
|
|
1016
|
-
lattice_constant: float = 1.0,
|
|
1179
|
+
lattice_constant: Union[float, Any] = 1.0,
|
|
1017
1180
|
pbc: Union[bool, Tuple[bool, bool]] = True,
|
|
1018
1181
|
precompute_neighbors: Optional[int] = None,
|
|
1019
1182
|
):
|
|
1020
1183
|
dimensionality = 2
|
|
1021
1184
|
a = lattice_constant
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1185
|
+
a_t = backend.convert_to_tensor(a)
|
|
1186
|
+
# The unit cell is a square rotated by 45 degrees.
|
|
1187
|
+
lattice_vectors = backend.stack(
|
|
1188
|
+
[
|
|
1189
|
+
backend.stack([a_t * 1.0, a_t * 1.0]),
|
|
1190
|
+
backend.stack([a_t * 1.0, a_t * -1.0]),
|
|
1191
|
+
]
|
|
1192
|
+
)
|
|
1193
|
+
# Two basis sites (A and B) within the unit cell.
|
|
1194
|
+
zero = a_t * 0.0
|
|
1195
|
+
basis_coords = backend.stack(
|
|
1196
|
+
[backend.stack([zero, zero]), backend.stack([a_t * 1.0, zero])]
|
|
1197
|
+
)
|
|
1198
|
+
|
|
1026
1199
|
super().__init__(
|
|
1027
1200
|
dimensionality=dimensionality,
|
|
1028
1201
|
lattice_vectors=lattice_vectors,
|
|
@@ -1052,16 +1225,30 @@ class KagomeLattice(TILattice):
|
|
|
1052
1225
|
def __init__(
|
|
1053
1226
|
self,
|
|
1054
1227
|
size: Tuple[int, int],
|
|
1055
|
-
lattice_constant: float = 1.0,
|
|
1228
|
+
lattice_constant: Union[float, Any] = 1.0,
|
|
1056
1229
|
pbc: Union[bool, Tuple[bool, bool]] = True,
|
|
1057
1230
|
precompute_neighbors: Optional[int] = None,
|
|
1058
1231
|
):
|
|
1059
1232
|
dimensionality = 2
|
|
1060
1233
|
a = lattice_constant
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1234
|
+
a_t = backend.convert_to_tensor(a)
|
|
1235
|
+
# The Kagome lattice is based on a triangular Bravais lattice.
|
|
1236
|
+
lattice_vectors = backend.stack(
|
|
1237
|
+
[
|
|
1238
|
+
backend.stack([a_t * 2.0, a_t * 0.0]),
|
|
1239
|
+
backend.stack([a_t * 1.0, a_t * backend.sqrt(3.0)]),
|
|
1240
|
+
]
|
|
1241
|
+
)
|
|
1242
|
+
# It has a three-site basis, forming the corners of the triangles.
|
|
1243
|
+
zero = a_t * 0.0
|
|
1244
|
+
basis_coords = backend.stack(
|
|
1245
|
+
[
|
|
1246
|
+
backend.stack([zero, zero]),
|
|
1247
|
+
backend.stack([a_t * 1.0, zero]),
|
|
1248
|
+
backend.stack([a_t * 0.5, a_t * backend.sqrt(3.0) / 2.0]),
|
|
1249
|
+
]
|
|
1250
|
+
)
|
|
1251
|
+
|
|
1065
1252
|
super().__init__(
|
|
1066
1253
|
dimensionality=dimensionality,
|
|
1067
1254
|
lattice_vectors=lattice_vectors,
|
|
@@ -1092,29 +1279,26 @@ class LiebLattice(TILattice):
|
|
|
1092
1279
|
def __init__(
|
|
1093
1280
|
self,
|
|
1094
1281
|
size: Tuple[int, int],
|
|
1095
|
-
lattice_constant: float = 1.0,
|
|
1282
|
+
lattice_constant: Union[float, Any] = 1.0,
|
|
1096
1283
|
pbc: Union[bool, Tuple[bool, bool]] = True,
|
|
1097
1284
|
precompute_neighbors: Optional[int] = None,
|
|
1098
1285
|
):
|
|
1099
1286
|
"""Initializes the LiebLattice."""
|
|
1100
1287
|
dimensionality = 2
|
|
1101
|
-
# Use a more descriptive name for clarity. In a Lieb lattice,
|
|
1102
|
-
# the lattice_constant is the bond length between nearest neighbors.
|
|
1103
1288
|
bond_length = lattice_constant
|
|
1104
|
-
|
|
1105
|
-
|
|
1106
|
-
#
|
|
1107
|
-
|
|
1108
|
-
lattice_vectors =
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
#
|
|
1112
|
-
|
|
1113
|
-
basis_coords = np.array(
|
|
1289
|
+
bl_t = backend.convert_to_tensor(bond_length)
|
|
1290
|
+
unit_cell_side_t = 2 * bl_t
|
|
1291
|
+
# The Lieb lattice is based on a square Bravais lattice.
|
|
1292
|
+
z = bl_t * 0.0
|
|
1293
|
+
lattice_vectors = backend.stack(
|
|
1294
|
+
[backend.stack([unit_cell_side_t, z]), backend.stack([z, unit_cell_side_t])]
|
|
1295
|
+
)
|
|
1296
|
+
# It has a three-site basis: one corner and two edge-centers.
|
|
1297
|
+
basis_coords = backend.stack(
|
|
1114
1298
|
[
|
|
1115
|
-
[
|
|
1116
|
-
[
|
|
1117
|
-
[
|
|
1299
|
+
backend.stack([z, z]), # Corner site
|
|
1300
|
+
backend.stack([bl_t, z]), # x-edge center
|
|
1301
|
+
backend.stack([z, bl_t]), # y-edge center
|
|
1118
1302
|
]
|
|
1119
1303
|
)
|
|
1120
1304
|
|
|
@@ -1147,14 +1331,24 @@ class CubicLattice(TILattice):
|
|
|
1147
1331
|
def __init__(
|
|
1148
1332
|
self,
|
|
1149
1333
|
size: Tuple[int, int, int],
|
|
1150
|
-
lattice_constant: float = 1.0,
|
|
1334
|
+
lattice_constant: Union[float, Any] = 1.0,
|
|
1151
1335
|
pbc: Union[bool, Tuple[bool, bool, bool]] = True,
|
|
1152
1336
|
precompute_neighbors: Optional[int] = None,
|
|
1153
1337
|
):
|
|
1154
1338
|
dimensionality = 3
|
|
1155
1339
|
a = lattice_constant
|
|
1156
|
-
|
|
1157
|
-
|
|
1340
|
+
a_t = backend.convert_to_tensor(a)
|
|
1341
|
+
# Orthogonal lattice vectors of equal length in 3D.
|
|
1342
|
+
z = a_t * 0.0
|
|
1343
|
+
lattice_vectors = backend.stack(
|
|
1344
|
+
[
|
|
1345
|
+
backend.stack([a_t, z, z]),
|
|
1346
|
+
backend.stack([z, a_t, z]),
|
|
1347
|
+
backend.stack([z, z, a_t]),
|
|
1348
|
+
]
|
|
1349
|
+
)
|
|
1350
|
+
# A simple cubic lattice is a Bravais lattice with a single-site basis.
|
|
1351
|
+
basis_coords = backend.stack([backend.stack([z, z, z])])
|
|
1158
1352
|
super().__init__(
|
|
1159
1353
|
dimensionality=dimensionality,
|
|
1160
1354
|
lattice_vectors=lattice_vectors,
|
|
@@ -1194,29 +1388,37 @@ class CustomizeLattice(AbstractLattice):
|
|
|
1194
1388
|
self,
|
|
1195
1389
|
dimensionality: int,
|
|
1196
1390
|
identifiers: List[SiteIdentifier],
|
|
1197
|
-
coordinates:
|
|
1391
|
+
coordinates: Any,
|
|
1198
1392
|
precompute_neighbors: Optional[int] = None,
|
|
1199
1393
|
):
|
|
1200
1394
|
"""Initializes the CustomizeLattice."""
|
|
1201
1395
|
super().__init__(dimensionality)
|
|
1202
|
-
|
|
1396
|
+
|
|
1397
|
+
self._coordinates = backend.convert_to_tensor(coordinates)
|
|
1398
|
+
if len(identifiers) == 0:
|
|
1399
|
+
self._coordinates = backend.reshape(
|
|
1400
|
+
self._coordinates, (0, self.dimensionality)
|
|
1401
|
+
)
|
|
1402
|
+
|
|
1403
|
+
if len(identifiers) != backend.shape_tuple(self._coordinates)[0]:
|
|
1203
1404
|
raise ValueError(
|
|
1204
|
-
"
|
|
1405
|
+
"The number of identifiers must match the number of coordinates. "
|
|
1406
|
+
f"Got {len(identifiers)} identifiers and "
|
|
1407
|
+
f"{backend.shape_tuple(self._coordinates)[0]} coordinates."
|
|
1205
1408
|
)
|
|
1206
1409
|
|
|
1207
|
-
# The _build_lattice logic is simple enough to be in __init__
|
|
1208
1410
|
self._identifiers = list(identifiers)
|
|
1209
|
-
self._coordinates = [np.array(c) for c in coordinates]
|
|
1210
1411
|
self._indices = list(range(len(identifiers)))
|
|
1211
1412
|
self._ident_to_idx = {ident: idx for idx, ident in enumerate(identifiers)}
|
|
1212
1413
|
|
|
1213
|
-
|
|
1214
|
-
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
1219
|
-
|
|
1414
|
+
if (
|
|
1415
|
+
self.num_sites > 0
|
|
1416
|
+
and backend.shape_tuple(self._coordinates)[1] != dimensionality
|
|
1417
|
+
):
|
|
1418
|
+
raise ValueError(
|
|
1419
|
+
f"Coordinates tensor has dimension {backend.shape_tuple(self._coordinates)[1]}, "
|
|
1420
|
+
f"but expected dimensionality is {dimensionality}."
|
|
1421
|
+
)
|
|
1220
1422
|
|
|
1221
1423
|
logger.info(f"CustomizeLattice with {self.num_sites} sites created.")
|
|
1222
1424
|
|
|
@@ -1228,95 +1430,170 @@ class CustomizeLattice(AbstractLattice):
|
|
|
1228
1430
|
pass
|
|
1229
1431
|
|
|
1230
1432
|
def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None:
|
|
1231
|
-
"""
|
|
1232
|
-
|
|
1233
|
-
This method uses a memory-efficient approach to identify neighbors without
|
|
1234
|
-
initially computing the full N x N distance matrix. It leverages
|
|
1235
|
-
`scipy.spatial.distance.pdist` to find unique distance shells and then
|
|
1236
|
-
a `scipy.spatial.KDTree` for fast radius queries. This approach is
|
|
1237
|
-
significantly more memory-efficient during the neighbor identification phase.
|
|
1433
|
+
"""
|
|
1434
|
+
Calculates neighbor relationships using either KDTree or distance matrix methods.
|
|
1238
1435
|
|
|
1239
|
-
|
|
1240
|
-
|
|
1436
|
+
This method supports two modes:
|
|
1437
|
+
1. KDTree mode (use_kdtree=True): Fast, O(N log N) performance for large lattices
|
|
1438
|
+
but breaks differentiability due to scipy dependency
|
|
1439
|
+
2. Distance matrix mode (use_kdtree=False): Slower O(N²) but fully differentiable
|
|
1440
|
+
and backend-agnostic
|
|
1241
1441
|
|
|
1242
|
-
:param max_k:
|
|
1243
|
-
|
|
1244
|
-
:
|
|
1245
|
-
|
|
1246
|
-
|
|
1247
|
-
:type tol: float, optional
|
|
1442
|
+
:param max_k: Maximum number of neighbor shells to compute
|
|
1443
|
+
:type max_k: int
|
|
1444
|
+
:param kwargs: Additional arguments including:
|
|
1445
|
+
- use_kdtree (bool): Whether to use KDTree optimization. Defaults to False.
|
|
1446
|
+
- tol (float): Distance tolerance for neighbor identification. Defaults to 1e-6.
|
|
1248
1447
|
"""
|
|
1249
1448
|
tol = kwargs.get("tol", 1e-6)
|
|
1250
|
-
|
|
1449
|
+
# Reviewer suggestion: prefer differentiable method by default
|
|
1450
|
+
use_kdtree = kwargs.get("use_kdtree", False)
|
|
1451
|
+
|
|
1251
1452
|
if self.num_sites < 2:
|
|
1252
1453
|
return
|
|
1253
1454
|
|
|
1254
|
-
|
|
1455
|
+
# Choose algorithm based on user preference
|
|
1456
|
+
if use_kdtree:
|
|
1457
|
+
logger.info(
|
|
1458
|
+
f"Using KDTree method for {self.num_sites} sites up to k={max_k}"
|
|
1459
|
+
)
|
|
1460
|
+
self._build_neighbors_kdtree(max_k, tol)
|
|
1461
|
+
else:
|
|
1462
|
+
logger.info(
|
|
1463
|
+
f"Using differentiable distance matrix method for {self.num_sites} sites up to k={max_k}"
|
|
1464
|
+
)
|
|
1255
1465
|
|
|
1256
|
-
|
|
1257
|
-
|
|
1258
|
-
all_distances_sq = pdist(all_coords, metric="sqeuclidean")
|
|
1259
|
-
dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol)
|
|
1466
|
+
# Use the existing distance matrix method
|
|
1467
|
+
self._build_neighbors_by_distance_matrix(max_k, tol)
|
|
1260
1468
|
|
|
1261
|
-
|
|
1262
|
-
|
|
1263
|
-
|
|
1469
|
+
def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None:
|
|
1470
|
+
"""
|
|
1471
|
+
Build neighbors using KDTree for optimal performance.
|
|
1264
1472
|
|
|
1265
|
-
|
|
1266
|
-
|
|
1267
|
-
|
|
1473
|
+
This method provides O(N log N) performance for neighbor finding but breaks
|
|
1474
|
+
differentiability due to scipy dependency. Use this method when:
|
|
1475
|
+
- Performance is critical
|
|
1476
|
+
- Differentiability is not required
|
|
1477
|
+
- Large lattices (N > 1000)
|
|
1268
1478
|
|
|
1269
|
-
|
|
1270
|
-
|
|
1271
|
-
|
|
1272
|
-
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
|
|
1276
|
-
|
|
1479
|
+
Note: This method uses numpy arrays directly and may not be compatible
|
|
1480
|
+
with all backend types (JAX, TensorFlow, etc.).
|
|
1481
|
+
"""
|
|
1482
|
+
|
|
1483
|
+
# For small lattices or cases with potential duplicate coordinates,
|
|
1484
|
+
# fall back to distance matrix method for robustness
|
|
1485
|
+
if self.num_sites < 200:
|
|
1486
|
+
logger.info(
|
|
1487
|
+
"Small lattice detected, falling back to distance matrix method for robustness"
|
|
1277
1488
|
)
|
|
1489
|
+
self._build_neighbors_by_distance_matrix(max_k, tol)
|
|
1490
|
+
return
|
|
1278
1491
|
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
current_k_map: Dict[int, List[int]] = {}
|
|
1282
|
-
for i in range(self.num_sites):
|
|
1492
|
+
# Convert coordinates to numpy for KDTree
|
|
1493
|
+
coords_np = backend.numpy(self._coordinates)
|
|
1283
1494
|
|
|
1284
|
-
|
|
1285
|
-
|
|
1286
|
-
|
|
1287
|
-
|
|
1288
|
-
|
|
1495
|
+
# Build KDTree
|
|
1496
|
+
logger.info("Building KDTree...")
|
|
1497
|
+
tree = KDTree(coords_np)
|
|
1498
|
+
# Find all distances for shell identification - use comprehensive sampling
|
|
1499
|
+
logger.info("Identifying distance shells...")
|
|
1500
|
+
distances_for_shells: List[float] = []
|
|
1289
1501
|
|
|
1290
|
-
|
|
1291
|
-
|
|
1292
|
-
|
|
1502
|
+
# For robust shell identification, query all pairwise distances for smaller lattices
|
|
1503
|
+
# or use dense sampling for larger ones
|
|
1504
|
+
if self.num_sites <= 100:
|
|
1505
|
+
# For small lattices, compute all pairwise distances for accuracy
|
|
1506
|
+
for i in range(self.num_sites):
|
|
1507
|
+
query_k = min(self.num_sites - 1, max_k * 20)
|
|
1508
|
+
if query_k > 0:
|
|
1509
|
+
dists, _ = tree.query(
|
|
1510
|
+
coords_np[i], k=query_k + 1
|
|
1511
|
+
) # +1 to exclude self
|
|
1512
|
+
if isinstance(dists, np.ndarray):
|
|
1513
|
+
distances_for_shells.extend(dists[1:]) # Skip distance to self
|
|
1514
|
+
else:
|
|
1515
|
+
distances_for_shells.append(dists) # Single distance
|
|
1516
|
+
else:
|
|
1517
|
+
# For larger lattices, use adaptive sampling but ensure we capture all shells
|
|
1518
|
+
sample_size = min(1000, self.num_sites // 2) # More conservative sampling
|
|
1519
|
+
for i in range(0, self.num_sites, max(1, self.num_sites // sample_size)):
|
|
1520
|
+
query_k = min(max_k * 20 + 50, self.num_sites - 1)
|
|
1521
|
+
if query_k > 0:
|
|
1522
|
+
dists, _ = tree.query(
|
|
1523
|
+
coords_np[i], k=query_k + 1
|
|
1524
|
+
) # +1 to exclude self
|
|
1525
|
+
if isinstance(dists, np.ndarray):
|
|
1526
|
+
distances_for_shells.extend(dists[1:]) # Skip distance to self
|
|
1527
|
+
else:
|
|
1528
|
+
distances_for_shells.append(dists) # Single distance
|
|
1293
1529
|
|
|
1294
|
-
|
|
1295
|
-
|
|
1530
|
+
# Filter out zero distances (duplicate coordinates) before shell identification
|
|
1531
|
+
ZERO_THRESHOLD = 1e-12
|
|
1532
|
+
distances_for_shells = [d for d in distances_for_shells if d > ZERO_THRESHOLD]
|
|
1296
1533
|
|
|
1297
|
-
|
|
1298
|
-
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
self._distance_matrix = np.sqrt(squareform(all_distances_sq))
|
|
1534
|
+
if not distances_for_shells:
|
|
1535
|
+
logger.warning("No valid distances found for shell identification")
|
|
1536
|
+
self._neighbor_maps = {}
|
|
1537
|
+
return
|
|
1302
1538
|
|
|
1303
|
-
|
|
1539
|
+
# Use the same shell identification logic as distance matrix method
|
|
1540
|
+
distances_for_shells_sq = [d * d for d in distances_for_shells]
|
|
1541
|
+
dist_shells_sq = self._identify_distance_shells(
|
|
1542
|
+
distances_for_shells_sq, max_k, tol
|
|
1543
|
+
)
|
|
1544
|
+
dist_shells = [np.sqrt(d_sq) for d_sq in dist_shells_sq]
|
|
1304
1545
|
|
|
1305
|
-
|
|
1306
|
-
"""Computes the distance matrix from the stored coordinates.
|
|
1546
|
+
logger.info(f"Found {len(dist_shells)} distance shells: {dist_shells[:5]}...")
|
|
1307
1547
|
|
|
1308
|
-
|
|
1309
|
-
|
|
1310
|
-
full square matrix.
|
|
1311
|
-
"""
|
|
1312
|
-
if self.num_sites < 2:
|
|
1313
|
-
return cast(Coordinates, np.empty((self.num_sites, self.num_sites)))
|
|
1548
|
+
# Initialize neighbor maps
|
|
1549
|
+
self._neighbor_maps = {k: {} for k in range(1, len(dist_shells) + 1)}
|
|
1314
1550
|
|
|
1315
|
-
|
|
1316
|
-
|
|
1317
|
-
|
|
1318
|
-
|
|
1319
|
-
|
|
1551
|
+
# Build neighbor lists for each site
|
|
1552
|
+
for i in range(self.num_sites):
|
|
1553
|
+
# Query enough neighbors to capture all shells
|
|
1554
|
+
query_k = min(max_k * 20 + 50, self.num_sites - 1)
|
|
1555
|
+
if query_k > 0:
|
|
1556
|
+
distances, indices = tree.query(
|
|
1557
|
+
coords_np[i], k=query_k + 1
|
|
1558
|
+
) # +1 for self
|
|
1559
|
+
|
|
1560
|
+
# Skip the first entry (distance to self)
|
|
1561
|
+
# Handle both single value and array cases
|
|
1562
|
+
if isinstance(distances, np.ndarray) and len(distances) > 1:
|
|
1563
|
+
distances_slice = distances[1:]
|
|
1564
|
+
indices_slice = (
|
|
1565
|
+
indices[1:]
|
|
1566
|
+
if isinstance(indices, np.ndarray)
|
|
1567
|
+
else np.array([], dtype=int)
|
|
1568
|
+
)
|
|
1569
|
+
else:
|
|
1570
|
+
# Single value or empty case - no neighbors to process
|
|
1571
|
+
distances_slice = np.array([])
|
|
1572
|
+
indices_slice = np.array([], dtype=int)
|
|
1573
|
+
|
|
1574
|
+
# Filter out zero distances (duplicate coordinates)
|
|
1575
|
+
valid_pairs = [
|
|
1576
|
+
(d, idx)
|
|
1577
|
+
for d, idx in zip(distances_slice, indices_slice)
|
|
1578
|
+
if d > ZERO_THRESHOLD
|
|
1579
|
+
]
|
|
1580
|
+
|
|
1581
|
+
# Assign neighbors to shells
|
|
1582
|
+
for shell_idx, shell_dist in enumerate(dist_shells):
|
|
1583
|
+
k = shell_idx + 1
|
|
1584
|
+
shell_neighbors = []
|
|
1585
|
+
|
|
1586
|
+
for dist, neighbor_idx in valid_pairs:
|
|
1587
|
+
if abs(dist - shell_dist) <= tol:
|
|
1588
|
+
shell_neighbors.append(int(neighbor_idx))
|
|
1589
|
+
elif dist > shell_dist + tol:
|
|
1590
|
+
break # Distances are sorted, no more matches
|
|
1591
|
+
|
|
1592
|
+
if shell_neighbors:
|
|
1593
|
+
self._neighbor_maps[k][i] = sorted(shell_neighbors)
|
|
1594
|
+
|
|
1595
|
+
# Set distance matrix to None - will compute on demand
|
|
1596
|
+
self._distance_matrix = None
|
|
1320
1597
|
|
|
1321
1598
|
def _reset_computations(self) -> None:
|
|
1322
1599
|
"""Resets all cached data that depends on the lattice structure."""
|
|
@@ -1344,18 +1621,28 @@ class CustomizeLattice(AbstractLattice):
|
|
|
1344
1621
|
)
|
|
1345
1622
|
|
|
1346
1623
|
# Unzip the list of tuples into separate lists of identifiers and coordinates
|
|
1347
|
-
_, identifiers,
|
|
1624
|
+
_, identifiers, _ = zip(*all_sites_info)
|
|
1625
|
+
|
|
1626
|
+
# Detach-and-copy coordinates while remaining in tensor form to avoid
|
|
1627
|
+
# host roundtrips and device/dtype changes; this keeps CustomizeLattice
|
|
1628
|
+
# decoupled from the original graph but backend-friendly.
|
|
1629
|
+
# Some backends (e.g., NumPy) don't implement stop_gradient; fall back.
|
|
1630
|
+
try:
|
|
1631
|
+
coords_detached = backend.stop_gradient(lattice._coordinates)
|
|
1632
|
+
except NotImplementedError:
|
|
1633
|
+
coords_detached = lattice._coordinates
|
|
1634
|
+
coords_tensor = backend.copy(coords_detached)
|
|
1348
1635
|
|
|
1349
1636
|
return cls(
|
|
1350
1637
|
dimensionality=lattice.dimensionality,
|
|
1351
1638
|
identifiers=list(identifiers),
|
|
1352
|
-
coordinates=
|
|
1639
|
+
coordinates=coords_tensor,
|
|
1353
1640
|
)
|
|
1354
1641
|
|
|
1355
1642
|
def add_sites(
|
|
1356
1643
|
self,
|
|
1357
1644
|
identifiers: List[SiteIdentifier],
|
|
1358
|
-
coordinates:
|
|
1645
|
+
coordinates: Any,
|
|
1359
1646
|
) -> None:
|
|
1360
1647
|
"""Adds new sites to the lattice.
|
|
1361
1648
|
|
|
@@ -1363,21 +1650,29 @@ class CustomizeLattice(AbstractLattice):
|
|
|
1363
1650
|
previously computed neighbor information is cleared and must be
|
|
1364
1651
|
recalculated.
|
|
1365
1652
|
|
|
1366
|
-
:param identifiers: A list of unique
|
|
1653
|
+
:param identifiers: A list of unique identifiers for the new sites.
|
|
1367
1654
|
:type identifiers: List[SiteIdentifier]
|
|
1368
|
-
:param coordinates:
|
|
1369
|
-
|
|
1370
|
-
:
|
|
1371
|
-
identifier already exists in the lattice.
|
|
1655
|
+
:param coordinates: The coordinates for the new sites. Can be a list of lists,
|
|
1656
|
+
a NumPy array, or a backend-compatible tensor (e.g., jax.numpy.ndarray).
|
|
1657
|
+
:type coordinates: Any
|
|
1372
1658
|
"""
|
|
1373
|
-
if
|
|
1659
|
+
if not identifiers:
|
|
1660
|
+
return
|
|
1661
|
+
|
|
1662
|
+
new_coords_tensor = backend.convert_to_tensor(coordinates)
|
|
1663
|
+
|
|
1664
|
+
if len(identifiers) != backend.shape_tuple(new_coords_tensor)[0]:
|
|
1374
1665
|
raise ValueError(
|
|
1375
1666
|
"Identifiers and coordinates lists must have the same length."
|
|
1376
1667
|
)
|
|
1377
|
-
if not identifiers:
|
|
1378
|
-
return # Nothing to add
|
|
1379
1668
|
|
|
1380
|
-
|
|
1669
|
+
if backend.shape_tuple(new_coords_tensor)[1] != self.dimensionality:
|
|
1670
|
+
raise ValueError(
|
|
1671
|
+
f"New coordinate tensor has dimension {backend.shape_tuple(new_coords_tensor)[1]}, "
|
|
1672
|
+
f"but expected dimensionality is {self.dimensionality}."
|
|
1673
|
+
)
|
|
1674
|
+
|
|
1675
|
+
# Ensure that the new identifiers are unique and do not already exist.
|
|
1381
1676
|
existing_ids = set(self._identifiers)
|
|
1382
1677
|
new_ids = set(identifiers)
|
|
1383
1678
|
if not new_ids.isdisjoint(existing_ids):
|
|
@@ -1385,21 +1680,14 @@ class CustomizeLattice(AbstractLattice):
|
|
|
1385
1680
|
f"Duplicate identifiers found: {new_ids.intersection(existing_ids)}"
|
|
1386
1681
|
)
|
|
1387
1682
|
|
|
1388
|
-
|
|
1389
|
-
|
|
1390
|
-
|
|
1391
|
-
|
|
1392
|
-
f"New coordinate at index {i} has shape {coord_arr.shape}, "
|
|
1393
|
-
f"expected ({self.dimensionality},)"
|
|
1394
|
-
)
|
|
1395
|
-
self._coordinates.append(coord_arr)
|
|
1396
|
-
self._identifiers.append(identifiers[i])
|
|
1683
|
+
self._coordinates = backend.concat(
|
|
1684
|
+
[self._coordinates, new_coords_tensor], axis=0
|
|
1685
|
+
)
|
|
1686
|
+
self._identifiers.extend(identifiers)
|
|
1397
1687
|
|
|
1398
|
-
# Rebuild index mappings from scratch
|
|
1399
1688
|
self._indices = list(range(len(self._identifiers)))
|
|
1400
1689
|
self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)}
|
|
1401
1690
|
|
|
1402
|
-
# Invalidate any previously computed neighbors or distance matrices
|
|
1403
1691
|
self._reset_computations()
|
|
1404
1692
|
logger.info(
|
|
1405
1693
|
f"{len(identifiers)} sites added. Lattice now has {self.num_sites} sites."
|
|
@@ -1414,10 +1702,9 @@ class CustomizeLattice(AbstractLattice):
|
|
|
1414
1702
|
|
|
1415
1703
|
:param identifiers: A list of identifiers for the sites to be removed.
|
|
1416
1704
|
:type identifiers: List[SiteIdentifier]
|
|
1417
|
-
:raises ValueError: If any of the specified identifiers do not exist.
|
|
1418
1705
|
"""
|
|
1419
1706
|
if not identifiers:
|
|
1420
|
-
return
|
|
1707
|
+
return
|
|
1421
1708
|
|
|
1422
1709
|
ids_to_remove = set(identifiers)
|
|
1423
1710
|
current_ids = set(self._identifiers)
|
|
@@ -1426,23 +1713,25 @@ class CustomizeLattice(AbstractLattice):
|
|
|
1426
1713
|
f"Non-existent identifiers provided for removal: {ids_to_remove - current_ids}"
|
|
1427
1714
|
)
|
|
1428
1715
|
|
|
1429
|
-
#
|
|
1430
|
-
|
|
1431
|
-
|
|
1432
|
-
|
|
1433
|
-
if ident not in ids_to_remove
|
|
1434
|
-
|
|
1435
|
-
|
|
1716
|
+
# Find the indices of the sites that we want to keep.
|
|
1717
|
+
indices_to_keep = [
|
|
1718
|
+
idx
|
|
1719
|
+
for idx, ident in enumerate(self._identifiers)
|
|
1720
|
+
if ident not in ids_to_remove
|
|
1721
|
+
]
|
|
1722
|
+
|
|
1723
|
+
new_identifiers = [self._identifiers[i] for i in indices_to_keep]
|
|
1724
|
+
|
|
1725
|
+
self._coordinates = backend.gather1d(
|
|
1726
|
+
self._coordinates,
|
|
1727
|
+
backend.cast(backend.convert_to_tensor(indices_to_keep), "int32"),
|
|
1728
|
+
)
|
|
1436
1729
|
|
|
1437
|
-
# Replace old data with the new, filtered data
|
|
1438
1730
|
self._identifiers = new_identifiers
|
|
1439
|
-
self._coordinates = new_coordinates
|
|
1440
1731
|
|
|
1441
|
-
# Rebuild index mappings
|
|
1442
1732
|
self._indices = list(range(len(self._identifiers)))
|
|
1443
1733
|
self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)}
|
|
1444
1734
|
|
|
1445
|
-
# Invalidate caches
|
|
1446
1735
|
self._reset_computations()
|
|
1447
1736
|
logger.info(
|
|
1448
1737
|
f"{len(ids_to_remove)} sites removed. Lattice now has {self.num_sites} sites."
|