tensorcircuit-nightly 1.3.0.dev20250728__py3-none-any.whl → 1.4.0.dev20251103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tensorcircuit-nightly might be problematic. Click here for more details.
- tensorcircuit/__init__.py +5 -1
- tensorcircuit/abstractcircuit.py +4 -0
- tensorcircuit/analogcircuit.py +413 -0
- tensorcircuit/applications/layers.py +1 -1
- tensorcircuit/applications/van.py +1 -1
- tensorcircuit/backends/abstract_backend.py +312 -5
- tensorcircuit/backends/cupy_backend.py +3 -1
- tensorcircuit/backends/jax_backend.py +92 -3
- tensorcircuit/backends/jax_ops.py +108 -0
- tensorcircuit/backends/numpy_backend.py +49 -3
- tensorcircuit/backends/pytorch_backend.py +92 -3
- tensorcircuit/backends/tensorflow_backend.py +102 -3
- tensorcircuit/basecircuit.py +123 -82
- tensorcircuit/circuit.py +67 -57
- tensorcircuit/cloud/local.py +1 -1
- tensorcircuit/cloud/quafu_provider.py +1 -1
- tensorcircuit/cloud/tencent.py +1 -1
- tensorcircuit/compiler/simple_compiler.py +2 -2
- tensorcircuit/cons.py +1 -0
- tensorcircuit/densitymatrix.py +16 -11
- tensorcircuit/experimental.py +7 -152
- tensorcircuit/fgs.py +5 -6
- tensorcircuit/gates.py +66 -22
- tensorcircuit/keras.py +3 -3
- tensorcircuit/mpscircuit.py +109 -61
- tensorcircuit/quantum.py +697 -133
- tensorcircuit/quditcircuit.py +733 -0
- tensorcircuit/quditgates.py +618 -0
- tensorcircuit/results/counts.py +45 -31
- tensorcircuit/shadows.py +1 -1
- tensorcircuit/simplify.py +3 -1
- tensorcircuit/stabilizercircuit.py +4 -2
- tensorcircuit/templates/blocks.py +2 -2
- tensorcircuit/templates/hamiltonians.py +29 -8
- tensorcircuit/templates/lattice.py +676 -335
- tensorcircuit/timeevol.py +896 -0
- {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +50 -25
- tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
- {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
- tensorcircuit_nightly-1.3.0.dev20250728.dist-info/RECORD +0 -122
- tests/__init__.py +0 -0
- tests/conftest.py +0 -67
- tests/test_backends.py +0 -1035
- tests/test_calibrating.py +0 -149
- tests/test_channels.py +0 -409
- tests/test_circuit.py +0 -1713
- tests/test_cloud.py +0 -219
- tests/test_compiler.py +0 -147
- tests/test_dmcircuit.py +0 -555
- tests/test_ensemble.py +0 -72
- tests/test_fgs.py +0 -318
- tests/test_gates.py +0 -156
- tests/test_hamiltonians.py +0 -159
- tests/test_interfaces.py +0 -557
- tests/test_keras.py +0 -160
- tests/test_lattice.py +0 -1666
- tests/test_miscs.py +0 -334
- tests/test_mpscircuit.py +0 -341
- tests/test_noisemodel.py +0 -156
- tests/test_qaoa.py +0 -86
- tests/test_qem.py +0 -152
- tests/test_quantum.py +0 -549
- tests/test_quantum_attr.py +0 -42
- tests/test_results.py +0 -379
- tests/test_shadows.py +0 -160
- tests/test_simplify.py +0 -46
- tests/test_stabilizer.py +0 -226
- tests/test_templates.py +0 -218
- tests/test_torchnn.py +0 -99
- tests/test_van.py +0 -102
- {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +0 -0
- {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/licenses/LICENSE +0 -0
|
@@ -15,13 +15,15 @@ from typing import (
|
|
|
15
15
|
Union,
|
|
16
16
|
TYPE_CHECKING,
|
|
17
17
|
cast,
|
|
18
|
+
Set,
|
|
18
19
|
)
|
|
19
20
|
|
|
20
|
-
|
|
21
|
+
import itertools
|
|
22
|
+
import math
|
|
21
23
|
import numpy as np
|
|
22
|
-
|
|
23
24
|
from scipy.spatial import KDTree
|
|
24
|
-
|
|
25
|
+
|
|
26
|
+
from .. import backend
|
|
25
27
|
|
|
26
28
|
|
|
27
29
|
# This block resolves a name resolution issue for the static type checker (mypy).
|
|
@@ -40,9 +42,13 @@ if TYPE_CHECKING:
|
|
|
40
42
|
import matplotlib.axes
|
|
41
43
|
from mpl_toolkits.mplot3d import Axes3D
|
|
42
44
|
|
|
45
|
+
logger = logging.getLogger(__name__)
|
|
46
|
+
|
|
47
|
+
Tensor = Any
|
|
43
48
|
SiteIndex = int
|
|
44
49
|
SiteIdentifier = Hashable
|
|
45
|
-
Coordinates =
|
|
50
|
+
Coordinates = Tensor
|
|
51
|
+
|
|
46
52
|
NeighborMap = Dict[SiteIndex, List[SiteIndex]]
|
|
47
53
|
|
|
48
54
|
|
|
@@ -63,13 +69,27 @@ class AbstractLattice(abc.ABC):
|
|
|
63
69
|
"""Initializes the base lattice class."""
|
|
64
70
|
self._dimensionality = dimensionality
|
|
65
71
|
|
|
66
|
-
#
|
|
67
|
-
self._indices: List[SiteIndex] = []
|
|
68
|
-
self._identifiers: List[SiteIdentifier] =
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
72
|
+
# Core data structures for storing site information.
|
|
73
|
+
self._indices: List[SiteIndex] = [] # List of integer indices [0, 1, ..., N-1]
|
|
74
|
+
self._identifiers: List[SiteIdentifier] = (
|
|
75
|
+
[]
|
|
76
|
+
) # List of unique, hashable site identifiers
|
|
77
|
+
# Always initialize to an empty coordinate tensor with correct dimensionality
|
|
78
|
+
# so that type checkers know this is indexable and not Optional.
|
|
79
|
+
self._coordinates: Coordinates = backend.zeros((0, dimensionality))
|
|
80
|
+
|
|
81
|
+
# Mappings for efficient lookups.
|
|
82
|
+
self._ident_to_idx: Dict[SiteIdentifier, SiteIndex] = (
|
|
83
|
+
{}
|
|
84
|
+
) # Maps identifiers to indices
|
|
85
|
+
|
|
86
|
+
# Cached properties, computed on demand.
|
|
87
|
+
self._neighbor_maps: Dict[int, NeighborMap] = (
|
|
88
|
+
{}
|
|
89
|
+
) # Caches neighbor info for different k
|
|
90
|
+
self._distance_matrix: Optional[Coordinates] = (
|
|
91
|
+
None # Caches the full N x N distance matrix
|
|
92
|
+
)
|
|
73
93
|
|
|
74
94
|
@property
|
|
75
95
|
def num_sites(self) -> int:
|
|
@@ -94,7 +114,6 @@ class AbstractLattice(abc.ABC):
|
|
|
94
114
|
subsequent calls. This computation can be expensive for large lattices.
|
|
95
115
|
"""
|
|
96
116
|
if self._distance_matrix is None:
|
|
97
|
-
logger.info("Distance matrix not cached. Computing now...")
|
|
98
117
|
self._distance_matrix = self._compute_distance_matrix()
|
|
99
118
|
return self._distance_matrix
|
|
100
119
|
|
|
@@ -115,7 +134,8 @@ class AbstractLattice(abc.ABC):
|
|
|
115
134
|
:rtype: Coordinates
|
|
116
135
|
"""
|
|
117
136
|
self._validate_index(index)
|
|
118
|
-
|
|
137
|
+
coords = self._coordinates[index]
|
|
138
|
+
return coords
|
|
119
139
|
|
|
120
140
|
def get_identifier(self, index: SiteIndex) -> SiteIdentifier:
|
|
121
141
|
"""Gets the abstract identifier of a site by its integer index.
|
|
@@ -139,7 +159,8 @@ class AbstractLattice(abc.ABC):
|
|
|
139
159
|
:rtype: SiteIndex
|
|
140
160
|
"""
|
|
141
161
|
try:
|
|
142
|
-
|
|
162
|
+
index = self._ident_to_idx[identifier]
|
|
163
|
+
return index
|
|
143
164
|
except KeyError as e:
|
|
144
165
|
raise ValueError(
|
|
145
166
|
f"Identifier {identifier} not found in the lattice."
|
|
@@ -169,7 +190,7 @@ class AbstractLattice(abc.ABC):
|
|
|
169
190
|
idx = index_or_identifier
|
|
170
191
|
self._validate_index(idx)
|
|
171
192
|
return idx, self._identifiers[idx], self._coordinates[idx]
|
|
172
|
-
else:
|
|
193
|
+
else:
|
|
173
194
|
ident = index_or_identifier
|
|
174
195
|
idx = self.get_index(ident)
|
|
175
196
|
return idx, ident, self._coordinates[idx]
|
|
@@ -236,7 +257,6 @@ class AbstractLattice(abc.ABC):
|
|
|
236
257
|
)
|
|
237
258
|
self._build_neighbors(max_k=k)
|
|
238
259
|
|
|
239
|
-
# After attempting to build, check again. If still not found, return empty.
|
|
240
260
|
if k not in self._neighbor_maps:
|
|
241
261
|
return []
|
|
242
262
|
|
|
@@ -250,8 +270,28 @@ class AbstractLattice(abc.ABC):
|
|
|
250
270
|
pairs.append((i, j))
|
|
251
271
|
return sorted(pairs)
|
|
252
272
|
|
|
253
|
-
|
|
254
|
-
|
|
273
|
+
def get_all_pairs(self) -> List[Tuple[SiteIndex, SiteIndex]]:
|
|
274
|
+
"""
|
|
275
|
+
Returns a list of all unique pairs of site indices (i, j) where i < j.
|
|
276
|
+
|
|
277
|
+
This method provides all-to-all connectivity, useful for Hamiltonians
|
|
278
|
+
where every site interacts with every other site.
|
|
279
|
+
|
|
280
|
+
Note on Differentiability:
|
|
281
|
+
This method provides a static list of index pairs and is not differentiable
|
|
282
|
+
itself. However, it is designed to be used in combination with the fully
|
|
283
|
+
differentiable ``distance_matrix`` property. By using the pairs from this
|
|
284
|
+
method to index into the ``distance_matrix``, one can construct differentiable
|
|
285
|
+
objective functions based on all-pair interactions, effectively bypassing the
|
|
286
|
+
non-differentiable ``get_neighbor_pairs`` method for geometry optimization tasks.
|
|
287
|
+
|
|
288
|
+
:return: A list of tuples, where each tuple is a unique pair of site indices.
|
|
289
|
+
:rtype: List[Tuple[SiteIndex, SiteIndex]]
|
|
290
|
+
"""
|
|
291
|
+
if self.num_sites < 2:
|
|
292
|
+
return []
|
|
293
|
+
# Use itertools.combinations to efficiently generate all unique pairs (i, j) with i < j.
|
|
294
|
+
return sorted(list(itertools.combinations(range(self.num_sites), 2)))
|
|
255
295
|
|
|
256
296
|
@abc.abstractmethod
|
|
257
297
|
def _build_lattice(self, *args: Any, **kwargs: Any) -> None:
|
|
@@ -280,14 +320,24 @@ class AbstractLattice(abc.ABC):
|
|
|
280
320
|
"""
|
|
281
321
|
pass
|
|
282
322
|
|
|
283
|
-
@abc.abstractmethod
|
|
284
323
|
def _compute_distance_matrix(self) -> Coordinates:
|
|
285
324
|
"""
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
325
|
+
Default generic distance matrix computation (no periodic images).
|
|
326
|
+
|
|
327
|
+
Subclasses can override this when a specialized rule is required
|
|
328
|
+
(e.g., applying Minimum Image Convention for PBC in TILattice).
|
|
289
329
|
"""
|
|
290
|
-
|
|
330
|
+
# Handle empty lattices and trivial 1-site lattices
|
|
331
|
+
if self.num_sites == 0:
|
|
332
|
+
return backend.zeros((0, 0))
|
|
333
|
+
|
|
334
|
+
# Vectorized pairwise Euclidean distances
|
|
335
|
+
all_coords = self._coordinates
|
|
336
|
+
displacements = backend.expand_dims(all_coords, 1) - backend.expand_dims(
|
|
337
|
+
all_coords, 0
|
|
338
|
+
)
|
|
339
|
+
dist_matrix_sq = backend.sum(displacements**2, axis=-1)
|
|
340
|
+
return backend.sqrt(dist_matrix_sq)
|
|
291
341
|
|
|
292
342
|
def show(
|
|
293
343
|
self,
|
|
@@ -327,13 +377,14 @@ class AbstractLattice(abc.ABC):
|
|
|
327
377
|
try:
|
|
328
378
|
import matplotlib.pyplot as plt
|
|
329
379
|
except ImportError:
|
|
330
|
-
logger.
|
|
380
|
+
logger.warning(
|
|
331
381
|
"Matplotlib is required for visualization. "
|
|
332
382
|
"Please install it using 'pip install matplotlib'."
|
|
333
383
|
)
|
|
334
384
|
return
|
|
335
385
|
|
|
336
|
-
#
|
|
386
|
+
# Flag to track if the Matplotlib figure was created by this method.
|
|
387
|
+
# This prevents calling plt.show() if the user provided their own Axes.
|
|
337
388
|
fig_created_internally = False
|
|
338
389
|
|
|
339
390
|
if self.num_sites == 0:
|
|
@@ -346,7 +397,7 @@ class AbstractLattice(abc.ABC):
|
|
|
346
397
|
return
|
|
347
398
|
|
|
348
399
|
if ax is None:
|
|
349
|
-
#
|
|
400
|
+
# If no Axes object is provided, create a new figure and axes.
|
|
350
401
|
fig_created_internally = True
|
|
351
402
|
if self.dimensionality == 3:
|
|
352
403
|
fig = plt.figure(figsize=(8, 8))
|
|
@@ -357,6 +408,7 @@ class AbstractLattice(abc.ABC):
|
|
|
357
408
|
fig = ax.figure # type: ignore
|
|
358
409
|
|
|
359
410
|
coords = np.array(self._coordinates)
|
|
411
|
+
# Prepare arguments for the scatter plot, allowing user overrides.
|
|
360
412
|
scatter_args = {"s": 100, "zorder": 2}
|
|
361
413
|
scatter_args.update(kwargs)
|
|
362
414
|
if self.dimensionality == 1:
|
|
@@ -370,12 +422,11 @@ class AbstractLattice(abc.ABC):
|
|
|
370
422
|
if show_indices or show_identifiers:
|
|
371
423
|
for i in range(self.num_sites):
|
|
372
424
|
label = str(self._identifiers[i]) if show_identifiers else str(i)
|
|
425
|
+
# Calculate a small offset for placing text labels to avoid overlap with sites.
|
|
373
426
|
offset = (
|
|
374
427
|
0.02 * np.max(np.ptp(coords, axis=0)) if coords.size > 0 else 0.1
|
|
375
428
|
)
|
|
376
429
|
|
|
377
|
-
# Robust Logic: Decide plotting strategy based on known dimensionality.
|
|
378
|
-
|
|
379
430
|
if self.dimensionality == 1:
|
|
380
431
|
ax.text(coords[i, 0], offset, label, fontsize=9, ha="center")
|
|
381
432
|
elif self.dimensionality == 2:
|
|
@@ -397,8 +448,6 @@ class AbstractLattice(abc.ABC):
|
|
|
397
448
|
zorder=3,
|
|
398
449
|
)
|
|
399
450
|
|
|
400
|
-
# Note: No 'else' needed as we already check dimensionality at the start.
|
|
401
|
-
|
|
402
451
|
if show_bonds_k is not None:
|
|
403
452
|
if show_bonds_k not in self._neighbor_maps:
|
|
404
453
|
logger.warning(
|
|
@@ -432,7 +481,7 @@ class AbstractLattice(abc.ABC):
|
|
|
432
481
|
if self.dimensionality == 1: # type: ignore
|
|
433
482
|
|
|
434
483
|
ax.plot([p1[0], p2[0]], [0, 0], **plot_bond_kwargs) # type: ignore
|
|
435
|
-
else:
|
|
484
|
+
else:
|
|
436
485
|
ax.plot([p1[0], p2[0]], [p1[1], p2[1]], **plot_bond_kwargs) # type: ignore
|
|
437
486
|
|
|
438
487
|
except ValueError as e:
|
|
@@ -448,7 +497,7 @@ class AbstractLattice(abc.ABC):
|
|
|
448
497
|
ax.set_zlabel("z")
|
|
449
498
|
ax.grid(True)
|
|
450
499
|
|
|
451
|
-
#
|
|
500
|
+
# Display the plot only if the figure was created within this function.
|
|
452
501
|
if fig_created_internally:
|
|
453
502
|
plt.show()
|
|
454
503
|
|
|
@@ -474,26 +523,28 @@ class AbstractLattice(abc.ABC):
|
|
|
474
523
|
:return: A sorted list of squared distances representing the shells.
|
|
475
524
|
:rtype: List[float]
|
|
476
525
|
"""
|
|
526
|
+
# A small threshold to filter out zero distances (site to itself).
|
|
477
527
|
ZERO_THRESHOLD_SQ = 1e-12
|
|
478
528
|
|
|
479
|
-
all_distances_sq =
|
|
529
|
+
all_distances_sq = backend.convert_to_tensor(all_distances_sq)
|
|
480
530
|
# Now, the .size call below is guaranteed to be safe.
|
|
481
|
-
if all_distances_sq
|
|
531
|
+
if backend.sizen(all_distances_sq) == 0:
|
|
482
532
|
return []
|
|
483
533
|
|
|
484
|
-
|
|
534
|
+
# Filter out self-distances and sort the remaining squared distances.
|
|
535
|
+
sorted_dist = backend.sort(
|
|
536
|
+
all_distances_sq[all_distances_sq > ZERO_THRESHOLD_SQ]
|
|
537
|
+
)
|
|
485
538
|
|
|
486
|
-
if sorted_dist
|
|
539
|
+
if backend.sizen(sorted_dist) == 0:
|
|
487
540
|
return []
|
|
488
541
|
|
|
489
|
-
# Identify shells using the user-provided tolerance.
|
|
490
542
|
dist_shells = [sorted_dist[0]]
|
|
491
543
|
|
|
492
544
|
for d_sq in sorted_dist[1:]:
|
|
493
545
|
if len(dist_shells) >= max_k:
|
|
494
546
|
break
|
|
495
|
-
|
|
496
|
-
if d_sq > dist_shells[-1] + tol**2:
|
|
547
|
+
if backend.sqrt(d_sq) - backend.sqrt(dist_shells[-1]) > tol:
|
|
497
548
|
dist_shells.append(d_sq)
|
|
498
549
|
|
|
499
550
|
return dist_shells
|
|
@@ -502,11 +553,9 @@ class AbstractLattice(abc.ABC):
|
|
|
502
553
|
self, max_k: int = 2, tol: float = 1e-6
|
|
503
554
|
) -> None:
|
|
504
555
|
"""A generic, distance-based neighbor finding method.
|
|
505
|
-
|
|
506
556
|
This method calculates the full N x N distance matrix to find neighbor
|
|
507
557
|
shells. It is computationally expensive for large N (O(N^2)) and is
|
|
508
558
|
best suited for non-periodic or custom-defined lattices.
|
|
509
|
-
|
|
510
559
|
:param max_k: The maximum number of neighbor shells to
|
|
511
560
|
calculate. Defaults to 2.
|
|
512
561
|
:type max_k: int, optional
|
|
@@ -517,26 +566,55 @@ class AbstractLattice(abc.ABC):
|
|
|
517
566
|
if self.num_sites < 2:
|
|
518
567
|
return
|
|
519
568
|
|
|
520
|
-
all_coords =
|
|
521
|
-
|
|
522
|
-
|
|
569
|
+
all_coords = self._coordinates
|
|
570
|
+
# Vectorized computation of the squared distance matrix:
|
|
571
|
+
# (N, 1, D) - (1, N, D) -> (N, N, D) -> (N, N)
|
|
572
|
+
displacements = backend.expand_dims(all_coords, 1) - backend.expand_dims(
|
|
573
|
+
all_coords, 0
|
|
523
574
|
)
|
|
575
|
+
dist_matrix_sq = backend.sum(displacements**2, axis=-1)
|
|
524
576
|
|
|
525
|
-
|
|
577
|
+
# Flatten the matrix to a list of all squared distances to identify shells.
|
|
578
|
+
all_distances_sq = backend.reshape(dist_matrix_sq, [-1])
|
|
526
579
|
dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol)
|
|
527
580
|
|
|
528
|
-
self._neighbor_maps =
|
|
581
|
+
self._neighbor_maps = self._build_neighbor_map_from_distances(
|
|
582
|
+
dist_matrix_sq, dist_shells_sq, tol
|
|
583
|
+
)
|
|
584
|
+
self._distance_matrix = backend.sqrt(dist_matrix_sq)
|
|
585
|
+
|
|
586
|
+
def _build_neighbor_map_from_distances(
|
|
587
|
+
self,
|
|
588
|
+
dist_matrix_sq: Coordinates,
|
|
589
|
+
dist_shells_sq: List[float],
|
|
590
|
+
tol: float = 1e-6,
|
|
591
|
+
) -> Dict[int, NeighborMap]:
|
|
592
|
+
"""
|
|
593
|
+
Builds a neighbor map from a squared distance matrix and identified shells.
|
|
594
|
+
This is a generic helper function to reduce code duplication.
|
|
595
|
+
"""
|
|
596
|
+
neighbor_maps: Dict[int, NeighborMap] = {
|
|
597
|
+
k: {} for k in range(1, len(dist_shells_sq) + 1)
|
|
598
|
+
}
|
|
529
599
|
for k_idx, target_d_sq in enumerate(dist_shells_sq):
|
|
530
600
|
k = k_idx + 1
|
|
531
601
|
current_k_map: Dict[int, List[int]] = {}
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
602
|
+
# For each shell, find all pairs of sites (i, j) with that distance.
|
|
603
|
+
is_close_matrix = backend.abs(dist_matrix_sq - target_d_sq) < tol
|
|
604
|
+
rows, cols = backend.where(is_close_matrix)
|
|
605
|
+
|
|
606
|
+
for i, j in zip(backend.numpy(rows), backend.numpy(cols)):
|
|
607
|
+
if i == j:
|
|
608
|
+
continue
|
|
609
|
+
if i not in current_k_map:
|
|
610
|
+
current_k_map[i] = []
|
|
611
|
+
current_k_map[i].append(j)
|
|
612
|
+
|
|
613
|
+
for i in current_k_map:
|
|
614
|
+
current_k_map[i].sort()
|
|
615
|
+
|
|
616
|
+
neighbor_maps[k] = current_k_map
|
|
617
|
+
return neighbor_maps
|
|
540
618
|
|
|
541
619
|
|
|
542
620
|
class TILattice(AbstractLattice):
|
|
@@ -587,150 +665,197 @@ class TILattice(AbstractLattice):
|
|
|
587
665
|
):
|
|
588
666
|
"""Initializes the Translationally Invariant Lattice."""
|
|
589
667
|
super().__init__(dimensionality)
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
self.
|
|
600
|
-
|
|
601
|
-
|
|
668
|
+
|
|
669
|
+
self.lattice_vectors = backend.convert_to_tensor(lattice_vectors)
|
|
670
|
+
self.basis_coords = backend.convert_to_tensor(basis_coords)
|
|
671
|
+
|
|
672
|
+
if self.lattice_vectors.shape != (dimensionality, dimensionality):
|
|
673
|
+
raise ValueError(
|
|
674
|
+
f"Lattice vectors shape {self.lattice_vectors.shape} does not match "
|
|
675
|
+
f"expected ({dimensionality}, {dimensionality})"
|
|
676
|
+
)
|
|
677
|
+
if self.basis_coords.shape[1] != dimensionality:
|
|
678
|
+
raise ValueError(
|
|
679
|
+
f"Basis coordinates dimension {self.basis_coords.shape[1]} does not "
|
|
680
|
+
f"match lattice dimensionality {dimensionality}"
|
|
681
|
+
)
|
|
682
|
+
if len(size) != dimensionality:
|
|
683
|
+
raise ValueError(
|
|
684
|
+
f"Size tuple length {len(size)} does not match dimensionality {dimensionality}"
|
|
685
|
+
)
|
|
686
|
+
|
|
687
|
+
self.num_basis = self.basis_coords.shape[0]
|
|
602
688
|
self.size = size
|
|
603
689
|
if isinstance(pbc, bool):
|
|
604
690
|
self.pbc = tuple([pbc] * dimensionality)
|
|
605
691
|
else:
|
|
606
|
-
|
|
692
|
+
if len(pbc) != dimensionality:
|
|
693
|
+
raise ValueError(
|
|
694
|
+
f"PBC tuple length {len(pbc)} does not match dimensionality {dimensionality}"
|
|
695
|
+
)
|
|
607
696
|
self.pbc = tuple(pbc)
|
|
608
697
|
|
|
609
|
-
# Build the lattice sites and their neighbor relationships
|
|
610
698
|
self._build_lattice()
|
|
611
699
|
if precompute_neighbors is not None and precompute_neighbors > 0:
|
|
612
700
|
logger.info(f"Pre-computing neighbors up to k={precompute_neighbors}...")
|
|
613
701
|
self._build_neighbors(max_k=precompute_neighbors)
|
|
614
702
|
|
|
615
703
|
def _build_lattice(self) -> None:
|
|
616
|
-
"""Generates all site information for the periodic lattice.
|
|
617
|
-
|
|
618
|
-
This method iterates through each unit cell defined by `self.size`,
|
|
619
|
-
and for each unit cell, it iterates through all basis sites. It then
|
|
620
|
-
calculates the real-space coordinates and creates a unique identifier
|
|
621
|
-
for each site, populating the internal lattice data structures.
|
|
622
704
|
"""
|
|
623
|
-
|
|
705
|
+
Generates all site information for the periodic lattice in a vectorized manner.
|
|
706
|
+
"""
|
|
707
|
+
ranges = [backend.arange(s) for s in self.size]
|
|
624
708
|
|
|
625
|
-
#
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
709
|
+
# Generate a grid of all integer unit cell coordinates.
|
|
710
|
+
grid = backend.meshgrid(*ranges, indexing="ij")
|
|
711
|
+
all_cell_coords = backend.reshape(
|
|
712
|
+
backend.stack(grid, axis=-1), (-1, self.dimensionality)
|
|
713
|
+
)
|
|
630
714
|
|
|
631
|
-
|
|
632
|
-
for basis_index in range(self.num_basis):
|
|
633
|
-
basis_vec = self.basis_coords[basis_index]
|
|
715
|
+
all_cell_coords = backend.cast(all_cell_coords, self.lattice_vectors.dtype)
|
|
634
716
|
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
717
|
+
cell_vectors = backend.tensordot(
|
|
718
|
+
all_cell_coords, self.lattice_vectors, axes=[[1], [0]]
|
|
719
|
+
)
|
|
720
|
+
|
|
721
|
+
cell_vectors = backend.cast(cell_vectors, self.basis_coords.dtype)
|
|
722
|
+
|
|
723
|
+
# Combine cell vectors with basis coordinates to get all site positions
|
|
724
|
+
# via broadcasting: (num_cells, 1, D) + (1, num_basis, D) -> (num_cells, num_basis, D)
|
|
725
|
+
all_coords = backend.expand_dims(cell_vectors, 1) + backend.expand_dims(
|
|
726
|
+
self.basis_coords, 0
|
|
727
|
+
)
|
|
728
|
+
|
|
729
|
+
self._coordinates = backend.reshape(all_coords, (-1, self.dimensionality))
|
|
730
|
+
|
|
731
|
+
self._indices = []
|
|
732
|
+
self._identifiers = []
|
|
733
|
+
self._ident_to_idx = {}
|
|
734
|
+
current_index = 0
|
|
639
735
|
|
|
640
|
-
|
|
736
|
+
# Generate integer indices and tuple-based identifiers for all sites.
|
|
737
|
+
# e.g., identifier = (uc_x, uc_y, basis_idx)
|
|
738
|
+
size_ranges = [range(s) for s in self.size]
|
|
739
|
+
for cell_coord_tuple in itertools.product(*size_ranges):
|
|
740
|
+
for basis_index in range(self.num_basis):
|
|
741
|
+
identifier = cell_coord_tuple + (basis_index,)
|
|
641
742
|
self._indices.append(current_index)
|
|
642
743
|
self._identifiers.append(identifier)
|
|
643
|
-
self._coordinates.append(coord)
|
|
644
744
|
self._ident_to_idx[identifier] = current_index
|
|
645
745
|
current_index += 1
|
|
646
746
|
|
|
647
|
-
def
|
|
747
|
+
def _get_distance_matrix_with_mic_vectorized(self) -> Coordinates:
|
|
648
748
|
"""
|
|
649
|
-
Computes the full N x N distance matrix
|
|
650
|
-
Minimum Image Convention (MIC) for
|
|
749
|
+
Computes the full N x N distance matrix using a fully vectorized approach
|
|
750
|
+
that correctly applies the Minimum Image Convention (MIC) for periodic
|
|
751
|
+
boundary conditions.
|
|
752
|
+
|
|
753
|
+
This method uses full vectorization for optimal performance and compatibility
|
|
754
|
+
with JIT compilation frameworks like JAX. The implementation processes all
|
|
755
|
+
site pairs simultaneously rather than iterating row-by-row, which provides:
|
|
756
|
+
|
|
757
|
+
- Better performance through vectorized operations
|
|
758
|
+
- Full compatibility with automatic differentiation
|
|
759
|
+
- JIT compilation support (e.g., JAX, TensorFlow)
|
|
760
|
+
- Consistent tensor operations throughout
|
|
761
|
+
|
|
762
|
+
The trade-off is higher memory usage compared to iterative approaches,
|
|
763
|
+
as it computes all pairwise distances simultaneously. For very large
|
|
764
|
+
lattices (N > 10^4 sites), memory usage scales as O(N^2).
|
|
765
|
+
|
|
766
|
+
:return: Distance matrix with shape (N, N) where entry (i,j) is the
|
|
767
|
+
minimum distance between sites i and j under periodic boundary conditions.
|
|
768
|
+
:rtype: Coordinates
|
|
651
769
|
"""
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
#
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
if
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
translations.extend(pbc_translations)
|
|
671
|
-
|
|
672
|
-
translations_arr = np.array(translations, dtype=float)
|
|
673
|
-
|
|
674
|
-
# Calculate the distance matrix applying MIC
|
|
675
|
-
dist_matrix_sq = np.full((self.num_sites, self.num_sites), np.inf, dtype=float)
|
|
676
|
-
for i in range(self.num_sites):
|
|
677
|
-
displacements = all_coords - all_coords[i]
|
|
678
|
-
image_displacements = (
|
|
679
|
-
displacements[:, np.newaxis, :] - translations_arr[np.newaxis, :, :]
|
|
680
|
-
)
|
|
681
|
-
image_d_sq = np.sum(image_displacements**2, axis=2)
|
|
682
|
-
dist_matrix_sq[i, :] = np.min(image_d_sq, axis=1)
|
|
770
|
+
# Ensure dtype consistency across backends (especially torch) by explicitly
|
|
771
|
+
# casting size and lattice_vectors to the same floating dtype used internally.
|
|
772
|
+
# Strategy: prefer existing lattice_vectors dtype; if it's an unusual dtype,
|
|
773
|
+
# fall back to float32 to avoid mixed-precision issues in vectorized ops.
|
|
774
|
+
# Note: `self.lattice_vectors` is always created via `backend.convert_to_tensor`
|
|
775
|
+
# in __init__, so `backend.dtype(...)` is reliable here and doesn't need try/except.
|
|
776
|
+
target_dt = str(backend.dtype(self.lattice_vectors)) # type: ignore
|
|
777
|
+
if target_dt not in ("float32", "float64"):
|
|
778
|
+
# fallback for unusual dtypes
|
|
779
|
+
target_dt = "float32"
|
|
780
|
+
|
|
781
|
+
size_arr = backend.cast(backend.convert_to_tensor(self.size), target_dt)
|
|
782
|
+
lattice_vecs = backend.cast(
|
|
783
|
+
backend.convert_to_tensor(self.lattice_vectors), target_dt
|
|
784
|
+
)
|
|
785
|
+
system_vectors = lattice_vecs * backend.expand_dims(size_arr, axis=1)
|
|
786
|
+
|
|
787
|
+
pbc_mask = backend.convert_to_tensor(self.pbc)
|
|
683
788
|
|
|
684
|
-
|
|
789
|
+
# Generate all 3^d possible image shifts (-1, 0, 1) for all dimensions
|
|
790
|
+
shift_options = [
|
|
791
|
+
backend.convert_to_tensor([-1.0, 0.0, 1.0])
|
|
792
|
+
] * self.dimensionality
|
|
793
|
+
shifts_grid = backend.meshgrid(*shift_options, indexing="ij")
|
|
794
|
+
all_shifts = backend.reshape(
|
|
795
|
+
backend.stack(shifts_grid, axis=-1), (-1, self.dimensionality)
|
|
796
|
+
)
|
|
797
|
+
|
|
798
|
+
# Only apply shifts to periodic dimensions
|
|
799
|
+
masked_shifts = all_shifts * backend.cast(pbc_mask, all_shifts.dtype)
|
|
800
|
+
|
|
801
|
+
# Calculate all translation vectors due to PBC
|
|
802
|
+
translations_arr = backend.tensordot(
|
|
803
|
+
masked_shifts, system_vectors, axes=[[1], [0]]
|
|
804
|
+
)
|
|
805
|
+
|
|
806
|
+
# Vectorized computation of all displacements between any two sites
|
|
807
|
+
# Shape: (N, 1, D) - (1, N, D) -> (N, N, D)
|
|
808
|
+
displacements = backend.expand_dims(self._coordinates, 1) - backend.expand_dims(
|
|
809
|
+
self._coordinates, 0
|
|
810
|
+
)
|
|
811
|
+
|
|
812
|
+
# Consider all periodic images for each displacement
|
|
813
|
+
# Shape: (N, N, 1, D) - (1, 1, num_translations, D) -> (N, N, num_translations, D)
|
|
814
|
+
image_displacements = backend.expand_dims(
|
|
815
|
+
displacements, 2
|
|
816
|
+
) - backend.expand_dims(backend.expand_dims(translations_arr, 0), 0)
|
|
817
|
+
|
|
818
|
+
# Sum of squares for distances
|
|
819
|
+
image_d_sq = backend.sum(image_displacements**2, axis=3)
|
|
820
|
+
|
|
821
|
+
# Find the minimum distance among all images (Minimum Image Convention)
|
|
822
|
+
min_dist_sq = backend.min(image_d_sq, axis=2)
|
|
823
|
+
|
|
824
|
+
safe_dist_matrix_sq = backend.where(min_dist_sq > 0, min_dist_sq, 0.0)
|
|
825
|
+
return backend.sqrt(safe_dist_matrix_sq)
|
|
685
826
|
|
|
686
827
|
def _build_neighbors(self, max_k: int = 2, **kwargs: Any) -> None:
|
|
687
828
|
"""Calculates neighbor relationships for the periodic lattice.
|
|
688
829
|
|
|
689
|
-
This method
|
|
690
|
-
distance matrix
|
|
691
|
-
periodic
|
|
692
|
-
(
|
|
830
|
+
This method computes neighbor information by first calculating the full
|
|
831
|
+
distance matrix using the Minimum Image Convention (MIC) to correctly
|
|
832
|
+
handle periodic boundary conditions. It then identifies unique distance
|
|
833
|
+
shells (e.g., nearest, next-nearest) and populates the neighbor maps
|
|
834
|
+
accordingly. This approach is general and works for any periodic lattice
|
|
835
|
+
geometry defined by the TILattice class.
|
|
693
836
|
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
distance matrix is then cached for future use.
|
|
697
|
-
|
|
698
|
-
:param max_k: The maximum number of neighbor shells to
|
|
699
|
-
calculate. Defaults to 2.
|
|
837
|
+
:param max_k: The maximum order of neighbors to compute (e.g., k=1 for
|
|
838
|
+
nearest neighbors, k=2 for next-nearest, etc.). Defaults to 2.
|
|
700
839
|
:type max_k: int, optional
|
|
701
|
-
:param
|
|
702
|
-
|
|
703
|
-
|
|
840
|
+
:param kwargs: Additional keyword arguments. May include:
|
|
841
|
+
- ``tol`` (float): The numerical tolerance used to determine if two
|
|
842
|
+
distances are equal when identifying shells. Defaults to 1e-6.
|
|
704
843
|
"""
|
|
705
844
|
tol = kwargs.get("tol", 1e-6)
|
|
706
|
-
dist_matrix = self.
|
|
845
|
+
dist_matrix = self._get_distance_matrix_with_mic_vectorized()
|
|
707
846
|
dist_matrix_sq = dist_matrix**2
|
|
708
847
|
self._distance_matrix = dist_matrix
|
|
709
|
-
all_distances_sq =
|
|
848
|
+
all_distances_sq = backend.reshape(dist_matrix_sq, [-1])
|
|
710
849
|
dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol)
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
k = k_idx + 1
|
|
715
|
-
current_k_map: Dict[int, List[int]] = {}
|
|
716
|
-
match_indices = np.where(
|
|
717
|
-
np.isclose(dist_matrix_sq, target_d_sq, rtol=0, atol=tol**2)
|
|
718
|
-
)
|
|
719
|
-
for i, j in zip(*match_indices):
|
|
720
|
-
if i == j:
|
|
721
|
-
continue
|
|
722
|
-
if i not in current_k_map:
|
|
723
|
-
current_k_map[i] = []
|
|
724
|
-
current_k_map[i].append(j)
|
|
725
|
-
|
|
726
|
-
for i in current_k_map:
|
|
727
|
-
current_k_map[i].sort()
|
|
728
|
-
|
|
729
|
-
self._neighbor_maps[k] = current_k_map
|
|
850
|
+
self._neighbor_maps = self._build_neighbor_map_from_distances(
|
|
851
|
+
dist_matrix_sq, dist_shells_sq, tol
|
|
852
|
+
)
|
|
730
853
|
|
|
731
854
|
def _compute_distance_matrix(self) -> Coordinates:
|
|
732
855
|
"""Computes the distance matrix using the Minimum Image Convention."""
|
|
733
|
-
|
|
856
|
+
if self.num_sites == 0:
|
|
857
|
+
return backend.zeros((0, 0))
|
|
858
|
+
return self._get_distance_matrix_with_mic_vectorized()
|
|
734
859
|
|
|
735
860
|
|
|
736
861
|
class SquareLattice(TILattice):
|
|
@@ -758,20 +883,24 @@ class SquareLattice(TILattice):
|
|
|
758
883
|
def __init__(
|
|
759
884
|
self,
|
|
760
885
|
size: Tuple[int, int],
|
|
761
|
-
lattice_constant: float = 1.0,
|
|
886
|
+
lattice_constant: Union[float, Any] = 1.0,
|
|
762
887
|
pbc: Union[bool, Tuple[bool, bool]] = True,
|
|
763
888
|
precompute_neighbors: Optional[int] = None,
|
|
764
889
|
):
|
|
765
890
|
"""Initializes the SquareLattice."""
|
|
766
891
|
dimensionality = 2
|
|
892
|
+
# Define orthogonal lattice vectors for a square.
|
|
893
|
+
# Avoid mixing Python floats with backend Tensors (TF would error),
|
|
894
|
+
# so first convert inputs to tensors of a unified dtype, then stack.
|
|
895
|
+
lc = backend.convert_to_tensor(lattice_constant)
|
|
896
|
+
dt = backend.dtype(lc)
|
|
897
|
+
z = backend.cast(backend.convert_to_tensor(0.0), dt)
|
|
898
|
+
row1 = backend.stack([lc, z])
|
|
899
|
+
row2 = backend.stack([z, lc])
|
|
900
|
+
lattice_vectors = backend.stack([row1, row2])
|
|
901
|
+
# A square lattice is a Bravais lattice, so it has a single-site basis.
|
|
902
|
+
basis_coords = backend.stack([backend.stack([z, z])])
|
|
767
903
|
|
|
768
|
-
# Define lattice vectors for a square lattice
|
|
769
|
-
lattice_vectors = np.array([[lattice_constant, 0.0], [0.0, lattice_constant]])
|
|
770
|
-
|
|
771
|
-
# A square lattice has a single site in its basis
|
|
772
|
-
basis_coords = np.array([[0.0, 0.0]])
|
|
773
|
-
|
|
774
|
-
# Call the parent TILattice constructor with these parameters
|
|
775
904
|
super().__init__(
|
|
776
905
|
dimensionality=dimensionality,
|
|
777
906
|
lattice_vectors=lattice_vectors,
|
|
@@ -807,19 +936,28 @@ class HoneycombLattice(TILattice):
|
|
|
807
936
|
def __init__(
|
|
808
937
|
self,
|
|
809
938
|
size: Tuple[int, int],
|
|
810
|
-
lattice_constant: float = 1.0,
|
|
939
|
+
lattice_constant: Union[float, Any] = 1.0,
|
|
811
940
|
pbc: Union[bool, Tuple[bool, bool]] = True,
|
|
812
941
|
precompute_neighbors: Optional[int] = None,
|
|
813
942
|
):
|
|
814
943
|
"""Initializes the HoneycombLattice."""
|
|
815
944
|
dimensionality = 2
|
|
816
945
|
a = lattice_constant
|
|
946
|
+
a_t = backend.convert_to_tensor(a)
|
|
947
|
+
zero = a_t * 0.0
|
|
817
948
|
|
|
818
|
-
# Define the primitive lattice vectors for the underlying triangular lattice
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
949
|
+
# Define the two primitive lattice vectors for the underlying triangular Bravais lattice.
|
|
950
|
+
rt3_over_2 = math.sqrt(3.0) / 2.0
|
|
951
|
+
lattice_vectors = backend.stack(
|
|
952
|
+
[
|
|
953
|
+
backend.stack([a_t * 1.5, a_t * rt3_over_2]),
|
|
954
|
+
backend.stack([a_t * 1.5, -a_t * rt3_over_2]),
|
|
955
|
+
]
|
|
956
|
+
)
|
|
957
|
+
# Define the two basis sites (A and B) within the unit cell.
|
|
958
|
+
basis_coords = backend.stack(
|
|
959
|
+
[backend.stack([zero, zero]), backend.stack([a_t * 1.0, zero])]
|
|
960
|
+
)
|
|
823
961
|
|
|
824
962
|
super().__init__(
|
|
825
963
|
dimensionality=dimensionality,
|
|
@@ -854,19 +992,30 @@ class TriangularLattice(TILattice):
|
|
|
854
992
|
def __init__(
|
|
855
993
|
self,
|
|
856
994
|
size: Tuple[int, int],
|
|
857
|
-
lattice_constant: float = 1.0,
|
|
995
|
+
lattice_constant: Union[float, Any] = 1.0,
|
|
858
996
|
pbc: Union[bool, Tuple[bool, bool]] = True,
|
|
859
997
|
precompute_neighbors: Optional[int] = None,
|
|
860
998
|
):
|
|
861
999
|
"""Initializes the TriangularLattice."""
|
|
862
1000
|
dimensionality = 2
|
|
863
1001
|
a = lattice_constant
|
|
1002
|
+
a_t = backend.convert_to_tensor(a)
|
|
1003
|
+
zero = a_t * 0.0
|
|
864
1004
|
|
|
865
|
-
# Define the primitive lattice vectors for a triangular lattice
|
|
866
|
-
lattice_vectors =
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
1005
|
+
# Define the primitive lattice vectors for a triangular lattice.
|
|
1006
|
+
lattice_vectors = backend.stack(
|
|
1007
|
+
[
|
|
1008
|
+
backend.stack([a_t * 1.0, zero]),
|
|
1009
|
+
backend.stack(
|
|
1010
|
+
[
|
|
1011
|
+
a_t * 0.5,
|
|
1012
|
+
a_t * backend.sqrt(backend.convert_to_tensor(3.0)) / 2.0,
|
|
1013
|
+
]
|
|
1014
|
+
),
|
|
1015
|
+
]
|
|
1016
|
+
)
|
|
1017
|
+
# A triangular lattice is a Bravais lattice with a single-site basis.
|
|
1018
|
+
basis_coords = backend.stack([backend.stack([zero, zero])])
|
|
870
1019
|
|
|
871
1020
|
super().__init__(
|
|
872
1021
|
dimensionality=dimensionality,
|
|
@@ -895,13 +1044,18 @@ class ChainLattice(TILattice):
|
|
|
895
1044
|
def __init__(
|
|
896
1045
|
self,
|
|
897
1046
|
size: Tuple[int],
|
|
898
|
-
lattice_constant: float = 1.0,
|
|
1047
|
+
lattice_constant: Union[float, Any] = 1.0,
|
|
899
1048
|
pbc: bool = True,
|
|
900
1049
|
precompute_neighbors: Optional[int] = None,
|
|
901
1050
|
):
|
|
902
1051
|
dimensionality = 1
|
|
903
|
-
|
|
904
|
-
|
|
1052
|
+
# The lattice vector is just the lattice constant along one dimension.
|
|
1053
|
+
lc = backend.convert_to_tensor(lattice_constant)
|
|
1054
|
+
lattice_vectors = backend.stack([backend.stack([lc])])
|
|
1055
|
+
# A simple chain is a Bravais lattice with a single-site basis.
|
|
1056
|
+
zero = lc * 0.0
|
|
1057
|
+
basis_coords = backend.stack([backend.stack([zero])])
|
|
1058
|
+
|
|
905
1059
|
super().__init__(
|
|
906
1060
|
dimensionality=dimensionality,
|
|
907
1061
|
lattice_vectors=lattice_vectors,
|
|
@@ -933,15 +1087,17 @@ class DimerizedChainLattice(TILattice):
|
|
|
933
1087
|
def __init__(
|
|
934
1088
|
self,
|
|
935
1089
|
size: Tuple[int],
|
|
936
|
-
lattice_constant: float = 1.0,
|
|
1090
|
+
lattice_constant: Union[float, Any] = 1.0,
|
|
937
1091
|
pbc: bool = True,
|
|
938
1092
|
precompute_neighbors: Optional[int] = None,
|
|
939
1093
|
):
|
|
940
1094
|
dimensionality = 1
|
|
941
|
-
# The unit cell
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
1095
|
+
# The unit cell is twice the bond length, as it contains two sites.
|
|
1096
|
+
lc = backend.convert_to_tensor(lattice_constant)
|
|
1097
|
+
lattice_vectors = backend.stack([backend.stack([2 * lc])])
|
|
1098
|
+
# Two basis sites (A and B) separated by the bond length.
|
|
1099
|
+
zero = lc * 0.0
|
|
1100
|
+
basis_coords = backend.stack([backend.stack([zero]), backend.stack([lc])])
|
|
945
1101
|
|
|
946
1102
|
super().__init__(
|
|
947
1103
|
dimensionality=dimensionality,
|
|
@@ -974,14 +1130,22 @@ class RectangularLattice(TILattice):
|
|
|
974
1130
|
def __init__(
|
|
975
1131
|
self,
|
|
976
1132
|
size: Tuple[int, int],
|
|
977
|
-
lattice_constants: Tuple[float, float] = (1.0, 1.0),
|
|
1133
|
+
lattice_constants: Union[Tuple[float, float], Any] = (1.0, 1.0),
|
|
978
1134
|
pbc: Union[bool, Tuple[bool, bool]] = True,
|
|
979
1135
|
precompute_neighbors: Optional[int] = None,
|
|
980
1136
|
):
|
|
981
1137
|
dimensionality = 2
|
|
982
1138
|
ax, ay = lattice_constants
|
|
983
|
-
|
|
984
|
-
|
|
1139
|
+
ax_t = backend.convert_to_tensor(ax)
|
|
1140
|
+
dt = backend.dtype(ax_t)
|
|
1141
|
+
ay_t = backend.cast(backend.convert_to_tensor(ay), dt)
|
|
1142
|
+
z = backend.cast(backend.convert_to_tensor(0.0), dt)
|
|
1143
|
+
# Orthogonal lattice vectors with potentially different lengths.
|
|
1144
|
+
row1 = backend.stack([ax_t, z])
|
|
1145
|
+
row2 = backend.stack([z, ay_t])
|
|
1146
|
+
lattice_vectors = backend.stack([row1, row2])
|
|
1147
|
+
# A rectangular lattice is a Bravais lattice with a single-site basis.
|
|
1148
|
+
basis_coords = backend.stack([backend.stack([z, z])])
|
|
985
1149
|
|
|
986
1150
|
super().__init__(
|
|
987
1151
|
dimensionality=dimensionality,
|
|
@@ -1012,16 +1176,26 @@ class CheckerboardLattice(TILattice):
|
|
|
1012
1176
|
def __init__(
|
|
1013
1177
|
self,
|
|
1014
1178
|
size: Tuple[int, int],
|
|
1015
|
-
lattice_constant: float = 1.0,
|
|
1179
|
+
lattice_constant: Union[float, Any] = 1.0,
|
|
1016
1180
|
pbc: Union[bool, Tuple[bool, bool]] = True,
|
|
1017
1181
|
precompute_neighbors: Optional[int] = None,
|
|
1018
1182
|
):
|
|
1019
1183
|
dimensionality = 2
|
|
1020
1184
|
a = lattice_constant
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1185
|
+
a_t = backend.convert_to_tensor(a)
|
|
1186
|
+
# The unit cell is a square rotated by 45 degrees.
|
|
1187
|
+
lattice_vectors = backend.stack(
|
|
1188
|
+
[
|
|
1189
|
+
backend.stack([a_t * 1.0, a_t * 1.0]),
|
|
1190
|
+
backend.stack([a_t * 1.0, a_t * -1.0]),
|
|
1191
|
+
]
|
|
1192
|
+
)
|
|
1193
|
+
# Two basis sites (A and B) within the unit cell.
|
|
1194
|
+
zero = a_t * 0.0
|
|
1195
|
+
basis_coords = backend.stack(
|
|
1196
|
+
[backend.stack([zero, zero]), backend.stack([a_t * 1.0, zero])]
|
|
1197
|
+
)
|
|
1198
|
+
|
|
1025
1199
|
super().__init__(
|
|
1026
1200
|
dimensionality=dimensionality,
|
|
1027
1201
|
lattice_vectors=lattice_vectors,
|
|
@@ -1051,16 +1225,30 @@ class KagomeLattice(TILattice):
|
|
|
1051
1225
|
def __init__(
|
|
1052
1226
|
self,
|
|
1053
1227
|
size: Tuple[int, int],
|
|
1054
|
-
lattice_constant: float = 1.0,
|
|
1228
|
+
lattice_constant: Union[float, Any] = 1.0,
|
|
1055
1229
|
pbc: Union[bool, Tuple[bool, bool]] = True,
|
|
1056
1230
|
precompute_neighbors: Optional[int] = None,
|
|
1057
1231
|
):
|
|
1058
1232
|
dimensionality = 2
|
|
1059
1233
|
a = lattice_constant
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1234
|
+
a_t = backend.convert_to_tensor(a)
|
|
1235
|
+
# The Kagome lattice is based on a triangular Bravais lattice.
|
|
1236
|
+
lattice_vectors = backend.stack(
|
|
1237
|
+
[
|
|
1238
|
+
backend.stack([a_t * 2.0, a_t * 0.0]),
|
|
1239
|
+
backend.stack([a_t * 1.0, a_t * backend.sqrt(3.0)]),
|
|
1240
|
+
]
|
|
1241
|
+
)
|
|
1242
|
+
# It has a three-site basis, forming the corners of the triangles.
|
|
1243
|
+
zero = a_t * 0.0
|
|
1244
|
+
basis_coords = backend.stack(
|
|
1245
|
+
[
|
|
1246
|
+
backend.stack([zero, zero]),
|
|
1247
|
+
backend.stack([a_t * 1.0, zero]),
|
|
1248
|
+
backend.stack([a_t * 0.5, a_t * backend.sqrt(3.0) / 2.0]),
|
|
1249
|
+
]
|
|
1250
|
+
)
|
|
1251
|
+
|
|
1064
1252
|
super().__init__(
|
|
1065
1253
|
dimensionality=dimensionality,
|
|
1066
1254
|
lattice_vectors=lattice_vectors,
|
|
@@ -1091,29 +1279,26 @@ class LiebLattice(TILattice):
|
|
|
1091
1279
|
def __init__(
|
|
1092
1280
|
self,
|
|
1093
1281
|
size: Tuple[int, int],
|
|
1094
|
-
lattice_constant: float = 1.0,
|
|
1282
|
+
lattice_constant: Union[float, Any] = 1.0,
|
|
1095
1283
|
pbc: Union[bool, Tuple[bool, bool]] = True,
|
|
1096
1284
|
precompute_neighbors: Optional[int] = None,
|
|
1097
1285
|
):
|
|
1098
1286
|
"""Initializes the LiebLattice."""
|
|
1099
1287
|
dimensionality = 2
|
|
1100
|
-
# Use a more descriptive name for clarity. In a Lieb lattice,
|
|
1101
|
-
# the lattice_constant is the bond length between nearest neighbors.
|
|
1102
1288
|
bond_length = lattice_constant
|
|
1103
|
-
|
|
1104
|
-
|
|
1105
|
-
#
|
|
1106
|
-
|
|
1107
|
-
lattice_vectors =
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
-
#
|
|
1111
|
-
|
|
1112
|
-
basis_coords = np.array(
|
|
1289
|
+
bl_t = backend.convert_to_tensor(bond_length)
|
|
1290
|
+
unit_cell_side_t = 2 * bl_t
|
|
1291
|
+
# The Lieb lattice is based on a square Bravais lattice.
|
|
1292
|
+
z = bl_t * 0.0
|
|
1293
|
+
lattice_vectors = backend.stack(
|
|
1294
|
+
[backend.stack([unit_cell_side_t, z]), backend.stack([z, unit_cell_side_t])]
|
|
1295
|
+
)
|
|
1296
|
+
# It has a three-site basis: one corner and two edge-centers.
|
|
1297
|
+
basis_coords = backend.stack(
|
|
1113
1298
|
[
|
|
1114
|
-
[
|
|
1115
|
-
[
|
|
1116
|
-
[
|
|
1299
|
+
backend.stack([z, z]), # Corner site
|
|
1300
|
+
backend.stack([bl_t, z]), # x-edge center
|
|
1301
|
+
backend.stack([z, bl_t]), # y-edge center
|
|
1117
1302
|
]
|
|
1118
1303
|
)
|
|
1119
1304
|
|
|
@@ -1146,14 +1331,24 @@ class CubicLattice(TILattice):
|
|
|
1146
1331
|
def __init__(
|
|
1147
1332
|
self,
|
|
1148
1333
|
size: Tuple[int, int, int],
|
|
1149
|
-
lattice_constant: float = 1.0,
|
|
1334
|
+
lattice_constant: Union[float, Any] = 1.0,
|
|
1150
1335
|
pbc: Union[bool, Tuple[bool, bool, bool]] = True,
|
|
1151
1336
|
precompute_neighbors: Optional[int] = None,
|
|
1152
1337
|
):
|
|
1153
1338
|
dimensionality = 3
|
|
1154
1339
|
a = lattice_constant
|
|
1155
|
-
|
|
1156
|
-
|
|
1340
|
+
a_t = backend.convert_to_tensor(a)
|
|
1341
|
+
# Orthogonal lattice vectors of equal length in 3D.
|
|
1342
|
+
z = a_t * 0.0
|
|
1343
|
+
lattice_vectors = backend.stack(
|
|
1344
|
+
[
|
|
1345
|
+
backend.stack([a_t, z, z]),
|
|
1346
|
+
backend.stack([z, a_t, z]),
|
|
1347
|
+
backend.stack([z, z, a_t]),
|
|
1348
|
+
]
|
|
1349
|
+
)
|
|
1350
|
+
# A simple cubic lattice is a Bravais lattice with a single-site basis.
|
|
1351
|
+
basis_coords = backend.stack([backend.stack([z, z, z])])
|
|
1157
1352
|
super().__init__(
|
|
1158
1353
|
dimensionality=dimensionality,
|
|
1159
1354
|
lattice_vectors=lattice_vectors,
|
|
@@ -1193,29 +1388,37 @@ class CustomizeLattice(AbstractLattice):
|
|
|
1193
1388
|
self,
|
|
1194
1389
|
dimensionality: int,
|
|
1195
1390
|
identifiers: List[SiteIdentifier],
|
|
1196
|
-
coordinates:
|
|
1391
|
+
coordinates: Any,
|
|
1197
1392
|
precompute_neighbors: Optional[int] = None,
|
|
1198
1393
|
):
|
|
1199
1394
|
"""Initializes the CustomizeLattice."""
|
|
1200
1395
|
super().__init__(dimensionality)
|
|
1201
|
-
|
|
1396
|
+
|
|
1397
|
+
self._coordinates = backend.convert_to_tensor(coordinates)
|
|
1398
|
+
if len(identifiers) == 0:
|
|
1399
|
+
self._coordinates = backend.reshape(
|
|
1400
|
+
self._coordinates, (0, self.dimensionality)
|
|
1401
|
+
)
|
|
1402
|
+
|
|
1403
|
+
if len(identifiers) != backend.shape_tuple(self._coordinates)[0]:
|
|
1202
1404
|
raise ValueError(
|
|
1203
|
-
"
|
|
1405
|
+
"The number of identifiers must match the number of coordinates. "
|
|
1406
|
+
f"Got {len(identifiers)} identifiers and "
|
|
1407
|
+
f"{backend.shape_tuple(self._coordinates)[0]} coordinates."
|
|
1204
1408
|
)
|
|
1205
1409
|
|
|
1206
|
-
# The _build_lattice logic is simple enough to be in __init__
|
|
1207
1410
|
self._identifiers = list(identifiers)
|
|
1208
|
-
self._coordinates = [np.array(c) for c in coordinates]
|
|
1209
1411
|
self._indices = list(range(len(identifiers)))
|
|
1210
1412
|
self._ident_to_idx = {ident: idx for idx, ident in enumerate(identifiers)}
|
|
1211
1413
|
|
|
1212
|
-
|
|
1213
|
-
|
|
1214
|
-
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
1414
|
+
if (
|
|
1415
|
+
self.num_sites > 0
|
|
1416
|
+
and backend.shape_tuple(self._coordinates)[1] != dimensionality
|
|
1417
|
+
):
|
|
1418
|
+
raise ValueError(
|
|
1419
|
+
f"Coordinates tensor has dimension {backend.shape_tuple(self._coordinates)[1]}, "
|
|
1420
|
+
f"but expected dimensionality is {dimensionality}."
|
|
1421
|
+
)
|
|
1219
1422
|
|
|
1220
1423
|
logger.info(f"CustomizeLattice with {self.num_sites} sites created.")
|
|
1221
1424
|
|
|
@@ -1227,95 +1430,170 @@ class CustomizeLattice(AbstractLattice):
|
|
|
1227
1430
|
pass
|
|
1228
1431
|
|
|
1229
1432
|
def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None:
|
|
1230
|
-
"""
|
|
1231
|
-
|
|
1232
|
-
This method uses a memory-efficient approach to identify neighbors without
|
|
1233
|
-
initially computing the full N x N distance matrix. It leverages
|
|
1234
|
-
`scipy.spatial.distance.pdist` to find unique distance shells and then
|
|
1235
|
-
a `scipy.spatial.KDTree` for fast radius queries. This approach is
|
|
1236
|
-
significantly more memory-efficient during the neighbor identification phase.
|
|
1433
|
+
"""
|
|
1434
|
+
Calculates neighbor relationships using either KDTree or distance matrix methods.
|
|
1237
1435
|
|
|
1238
|
-
|
|
1239
|
-
|
|
1436
|
+
This method supports two modes:
|
|
1437
|
+
1. KDTree mode (use_kdtree=True): Fast, O(N log N) performance for large lattices
|
|
1438
|
+
but breaks differentiability due to scipy dependency
|
|
1439
|
+
2. Distance matrix mode (use_kdtree=False): Slower O(N^2) but fully differentiable
|
|
1440
|
+
and backend-agnostic
|
|
1240
1441
|
|
|
1241
|
-
:param max_k:
|
|
1242
|
-
|
|
1243
|
-
:
|
|
1244
|
-
|
|
1245
|
-
|
|
1246
|
-
:type tol: float, optional
|
|
1442
|
+
:param max_k: Maximum number of neighbor shells to compute
|
|
1443
|
+
:type max_k: int
|
|
1444
|
+
:param kwargs: Additional arguments including:
|
|
1445
|
+
- use_kdtree (bool): Whether to use KDTree optimization. Defaults to False.
|
|
1446
|
+
- tol (float): Distance tolerance for neighbor identification. Defaults to 1e-6.
|
|
1247
1447
|
"""
|
|
1248
1448
|
tol = kwargs.get("tol", 1e-6)
|
|
1249
|
-
|
|
1449
|
+
# Reviewer suggestion: prefer differentiable method by default
|
|
1450
|
+
use_kdtree = kwargs.get("use_kdtree", False)
|
|
1451
|
+
|
|
1250
1452
|
if self.num_sites < 2:
|
|
1251
1453
|
return
|
|
1252
1454
|
|
|
1253
|
-
|
|
1455
|
+
# Choose algorithm based on user preference
|
|
1456
|
+
if use_kdtree:
|
|
1457
|
+
logger.info(
|
|
1458
|
+
f"Using KDTree method for {self.num_sites} sites up to k={max_k}"
|
|
1459
|
+
)
|
|
1460
|
+
self._build_neighbors_kdtree(max_k, tol)
|
|
1461
|
+
else:
|
|
1462
|
+
logger.info(
|
|
1463
|
+
f"Using differentiable distance matrix method for {self.num_sites} sites up to k={max_k}"
|
|
1464
|
+
)
|
|
1254
1465
|
|
|
1255
|
-
|
|
1256
|
-
|
|
1257
|
-
all_distances_sq = pdist(all_coords, metric="sqeuclidean")
|
|
1258
|
-
dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol)
|
|
1466
|
+
# Use the existing distance matrix method
|
|
1467
|
+
self._build_neighbors_by_distance_matrix(max_k, tol)
|
|
1259
1468
|
|
|
1260
|
-
|
|
1261
|
-
|
|
1262
|
-
|
|
1469
|
+
def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None:
|
|
1470
|
+
"""
|
|
1471
|
+
Build neighbors using KDTree for optimal performance.
|
|
1263
1472
|
|
|
1264
|
-
|
|
1265
|
-
|
|
1266
|
-
|
|
1473
|
+
This method provides O(N log N) performance for neighbor finding but breaks
|
|
1474
|
+
differentiability due to scipy dependency. Use this method when:
|
|
1475
|
+
- Performance is critical
|
|
1476
|
+
- Differentiability is not required
|
|
1477
|
+
- Large lattices (N > 1000)
|
|
1267
1478
|
|
|
1268
|
-
|
|
1269
|
-
|
|
1270
|
-
|
|
1271
|
-
|
|
1272
|
-
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
|
|
1479
|
+
Note: This method uses numpy arrays directly and may not be compatible
|
|
1480
|
+
with all backend types (JAX, TensorFlow, etc.).
|
|
1481
|
+
"""
|
|
1482
|
+
|
|
1483
|
+
# For small lattices or cases with potential duplicate coordinates,
|
|
1484
|
+
# fall back to distance matrix method for robustness
|
|
1485
|
+
if self.num_sites < 200:
|
|
1486
|
+
logger.info(
|
|
1487
|
+
"Small lattice detected, falling back to distance matrix method for robustness"
|
|
1276
1488
|
)
|
|
1489
|
+
self._build_neighbors_by_distance_matrix(max_k, tol)
|
|
1490
|
+
return
|
|
1277
1491
|
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
current_k_map: Dict[int, List[int]] = {}
|
|
1281
|
-
for i in range(self.num_sites):
|
|
1492
|
+
# Convert coordinates to numpy for KDTree
|
|
1493
|
+
coords_np = backend.numpy(self._coordinates)
|
|
1282
1494
|
|
|
1283
|
-
|
|
1284
|
-
|
|
1285
|
-
|
|
1286
|
-
|
|
1287
|
-
|
|
1495
|
+
# Build KDTree
|
|
1496
|
+
logger.info("Building KDTree...")
|
|
1497
|
+
tree = KDTree(coords_np)
|
|
1498
|
+
# Find all distances for shell identification - use comprehensive sampling
|
|
1499
|
+
logger.info("Identifying distance shells...")
|
|
1500
|
+
distances_for_shells: List[float] = []
|
|
1288
1501
|
|
|
1289
|
-
|
|
1290
|
-
|
|
1291
|
-
|
|
1502
|
+
# For robust shell identification, query all pairwise distances for smaller lattices
|
|
1503
|
+
# or use dense sampling for larger ones
|
|
1504
|
+
if self.num_sites <= 100:
|
|
1505
|
+
# For small lattices, compute all pairwise distances for accuracy
|
|
1506
|
+
for i in range(self.num_sites):
|
|
1507
|
+
query_k = min(self.num_sites - 1, max_k * 20)
|
|
1508
|
+
if query_k > 0:
|
|
1509
|
+
dists, _ = tree.query(
|
|
1510
|
+
coords_np[i], k=query_k + 1
|
|
1511
|
+
) # +1 to exclude self
|
|
1512
|
+
if isinstance(dists, np.ndarray):
|
|
1513
|
+
distances_for_shells.extend(dists[1:]) # Skip distance to self
|
|
1514
|
+
else:
|
|
1515
|
+
distances_for_shells.append(dists) # Single distance
|
|
1516
|
+
else:
|
|
1517
|
+
# For larger lattices, use adaptive sampling but ensure we capture all shells
|
|
1518
|
+
sample_size = min(1000, self.num_sites // 2) # More conservative sampling
|
|
1519
|
+
for i in range(0, self.num_sites, max(1, self.num_sites // sample_size)):
|
|
1520
|
+
query_k = min(max_k * 20 + 50, self.num_sites - 1)
|
|
1521
|
+
if query_k > 0:
|
|
1522
|
+
dists, _ = tree.query(
|
|
1523
|
+
coords_np[i], k=query_k + 1
|
|
1524
|
+
) # +1 to exclude self
|
|
1525
|
+
if isinstance(dists, np.ndarray):
|
|
1526
|
+
distances_for_shells.extend(dists[1:]) # Skip distance to self
|
|
1527
|
+
else:
|
|
1528
|
+
distances_for_shells.append(dists) # Single distance
|
|
1292
1529
|
|
|
1293
|
-
|
|
1294
|
-
|
|
1530
|
+
# Filter out zero distances (duplicate coordinates) before shell identification
|
|
1531
|
+
ZERO_THRESHOLD = 1e-12
|
|
1532
|
+
distances_for_shells = [d for d in distances_for_shells if d > ZERO_THRESHOLD]
|
|
1295
1533
|
|
|
1296
|
-
|
|
1297
|
-
|
|
1298
|
-
|
|
1299
|
-
|
|
1300
|
-
self._distance_matrix = np.sqrt(squareform(all_distances_sq))
|
|
1534
|
+
if not distances_for_shells:
|
|
1535
|
+
logger.warning("No valid distances found for shell identification")
|
|
1536
|
+
self._neighbor_maps = {}
|
|
1537
|
+
return
|
|
1301
1538
|
|
|
1302
|
-
|
|
1539
|
+
# Use the same shell identification logic as distance matrix method
|
|
1540
|
+
distances_for_shells_sq = [d * d for d in distances_for_shells]
|
|
1541
|
+
dist_shells_sq = self._identify_distance_shells(
|
|
1542
|
+
distances_for_shells_sq, max_k, tol
|
|
1543
|
+
)
|
|
1544
|
+
dist_shells = [np.sqrt(d_sq) for d_sq in dist_shells_sq]
|
|
1303
1545
|
|
|
1304
|
-
|
|
1305
|
-
"""Computes the distance matrix from the stored coordinates.
|
|
1546
|
+
logger.info(f"Found {len(dist_shells)} distance shells: {dist_shells[:5]}...")
|
|
1306
1547
|
|
|
1307
|
-
|
|
1308
|
-
|
|
1309
|
-
full square matrix.
|
|
1310
|
-
"""
|
|
1311
|
-
if self.num_sites < 2:
|
|
1312
|
-
return cast(Coordinates, np.empty((self.num_sites, self.num_sites)))
|
|
1548
|
+
# Initialize neighbor maps
|
|
1549
|
+
self._neighbor_maps = {k: {} for k in range(1, len(dist_shells) + 1)}
|
|
1313
1550
|
|
|
1314
|
-
|
|
1315
|
-
|
|
1316
|
-
|
|
1317
|
-
|
|
1318
|
-
|
|
1551
|
+
# Build neighbor lists for each site
|
|
1552
|
+
for i in range(self.num_sites):
|
|
1553
|
+
# Query enough neighbors to capture all shells
|
|
1554
|
+
query_k = min(max_k * 20 + 50, self.num_sites - 1)
|
|
1555
|
+
if query_k > 0:
|
|
1556
|
+
distances, indices = tree.query(
|
|
1557
|
+
coords_np[i], k=query_k + 1
|
|
1558
|
+
) # +1 for self
|
|
1559
|
+
|
|
1560
|
+
# Skip the first entry (distance to self)
|
|
1561
|
+
# Handle both single value and array cases
|
|
1562
|
+
if isinstance(distances, np.ndarray) and len(distances) > 1:
|
|
1563
|
+
distances_slice = distances[1:]
|
|
1564
|
+
indices_slice = (
|
|
1565
|
+
indices[1:]
|
|
1566
|
+
if isinstance(indices, np.ndarray)
|
|
1567
|
+
else np.array([], dtype=int)
|
|
1568
|
+
)
|
|
1569
|
+
else:
|
|
1570
|
+
# Single value or empty case - no neighbors to process
|
|
1571
|
+
distances_slice = np.array([])
|
|
1572
|
+
indices_slice = np.array([], dtype=int)
|
|
1573
|
+
|
|
1574
|
+
# Filter out zero distances (duplicate coordinates)
|
|
1575
|
+
valid_pairs = [
|
|
1576
|
+
(d, idx)
|
|
1577
|
+
for d, idx in zip(distances_slice, indices_slice)
|
|
1578
|
+
if d > ZERO_THRESHOLD
|
|
1579
|
+
]
|
|
1580
|
+
|
|
1581
|
+
# Assign neighbors to shells
|
|
1582
|
+
for shell_idx, shell_dist in enumerate(dist_shells):
|
|
1583
|
+
k = shell_idx + 1
|
|
1584
|
+
shell_neighbors = []
|
|
1585
|
+
|
|
1586
|
+
for dist, neighbor_idx in valid_pairs:
|
|
1587
|
+
if abs(dist - shell_dist) <= tol:
|
|
1588
|
+
shell_neighbors.append(int(neighbor_idx))
|
|
1589
|
+
elif dist > shell_dist + tol:
|
|
1590
|
+
break # Distances are sorted, no more matches
|
|
1591
|
+
|
|
1592
|
+
if shell_neighbors:
|
|
1593
|
+
self._neighbor_maps[k][i] = sorted(shell_neighbors)
|
|
1594
|
+
|
|
1595
|
+
# Set distance matrix to None - will compute on demand
|
|
1596
|
+
self._distance_matrix = None
|
|
1319
1597
|
|
|
1320
1598
|
def _reset_computations(self) -> None:
|
|
1321
1599
|
"""Resets all cached data that depends on the lattice structure."""
|
|
@@ -1343,18 +1621,28 @@ class CustomizeLattice(AbstractLattice):
|
|
|
1343
1621
|
)
|
|
1344
1622
|
|
|
1345
1623
|
# Unzip the list of tuples into separate lists of identifiers and coordinates
|
|
1346
|
-
_, identifiers,
|
|
1624
|
+
_, identifiers, _ = zip(*all_sites_info)
|
|
1625
|
+
|
|
1626
|
+
# Detach-and-copy coordinates while remaining in tensor form to avoid
|
|
1627
|
+
# host roundtrips and device/dtype changes; this keeps CustomizeLattice
|
|
1628
|
+
# decoupled from the original graph but backend-friendly.
|
|
1629
|
+
# Some backends (e.g., NumPy) don't implement stop_gradient; fall back.
|
|
1630
|
+
try:
|
|
1631
|
+
coords_detached = backend.stop_gradient(lattice._coordinates)
|
|
1632
|
+
except NotImplementedError:
|
|
1633
|
+
coords_detached = lattice._coordinates
|
|
1634
|
+
coords_tensor = backend.copy(coords_detached)
|
|
1347
1635
|
|
|
1348
1636
|
return cls(
|
|
1349
1637
|
dimensionality=lattice.dimensionality,
|
|
1350
1638
|
identifiers=list(identifiers),
|
|
1351
|
-
coordinates=
|
|
1639
|
+
coordinates=coords_tensor,
|
|
1352
1640
|
)
|
|
1353
1641
|
|
|
1354
1642
|
def add_sites(
|
|
1355
1643
|
self,
|
|
1356
1644
|
identifiers: List[SiteIdentifier],
|
|
1357
|
-
coordinates:
|
|
1645
|
+
coordinates: Any,
|
|
1358
1646
|
) -> None:
|
|
1359
1647
|
"""Adds new sites to the lattice.
|
|
1360
1648
|
|
|
@@ -1362,21 +1650,29 @@ class CustomizeLattice(AbstractLattice):
|
|
|
1362
1650
|
previously computed neighbor information is cleared and must be
|
|
1363
1651
|
recalculated.
|
|
1364
1652
|
|
|
1365
|
-
:param identifiers: A list of unique
|
|
1653
|
+
:param identifiers: A list of unique identifiers for the new sites.
|
|
1366
1654
|
:type identifiers: List[SiteIdentifier]
|
|
1367
|
-
:param coordinates:
|
|
1368
|
-
|
|
1369
|
-
:
|
|
1370
|
-
identifier already exists in the lattice.
|
|
1655
|
+
:param coordinates: The coordinates for the new sites. Can be a list of lists,
|
|
1656
|
+
a NumPy array, or a backend-compatible tensor (e.g., jax.numpy.ndarray).
|
|
1657
|
+
:type coordinates: Any
|
|
1371
1658
|
"""
|
|
1372
|
-
if
|
|
1659
|
+
if not identifiers:
|
|
1660
|
+
return
|
|
1661
|
+
|
|
1662
|
+
new_coords_tensor = backend.convert_to_tensor(coordinates)
|
|
1663
|
+
|
|
1664
|
+
if len(identifiers) != backend.shape_tuple(new_coords_tensor)[0]:
|
|
1373
1665
|
raise ValueError(
|
|
1374
1666
|
"Identifiers and coordinates lists must have the same length."
|
|
1375
1667
|
)
|
|
1376
|
-
if not identifiers:
|
|
1377
|
-
return # Nothing to add
|
|
1378
1668
|
|
|
1379
|
-
|
|
1669
|
+
if backend.shape_tuple(new_coords_tensor)[1] != self.dimensionality:
|
|
1670
|
+
raise ValueError(
|
|
1671
|
+
f"New coordinate tensor has dimension {backend.shape_tuple(new_coords_tensor)[1]}, "
|
|
1672
|
+
f"but expected dimensionality is {self.dimensionality}."
|
|
1673
|
+
)
|
|
1674
|
+
|
|
1675
|
+
# Ensure that the new identifiers are unique and do not already exist.
|
|
1380
1676
|
existing_ids = set(self._identifiers)
|
|
1381
1677
|
new_ids = set(identifiers)
|
|
1382
1678
|
if not new_ids.isdisjoint(existing_ids):
|
|
@@ -1384,21 +1680,14 @@ class CustomizeLattice(AbstractLattice):
|
|
|
1384
1680
|
f"Duplicate identifiers found: {new_ids.intersection(existing_ids)}"
|
|
1385
1681
|
)
|
|
1386
1682
|
|
|
1387
|
-
|
|
1388
|
-
|
|
1389
|
-
|
|
1390
|
-
|
|
1391
|
-
f"New coordinate at index {i} has shape {coord_arr.shape}, "
|
|
1392
|
-
f"expected ({self.dimensionality},)"
|
|
1393
|
-
)
|
|
1394
|
-
self._coordinates.append(coord_arr)
|
|
1395
|
-
self._identifiers.append(identifiers[i])
|
|
1683
|
+
self._coordinates = backend.concat(
|
|
1684
|
+
[self._coordinates, new_coords_tensor], axis=0
|
|
1685
|
+
)
|
|
1686
|
+
self._identifiers.extend(identifiers)
|
|
1396
1687
|
|
|
1397
|
-
# Rebuild index mappings from scratch
|
|
1398
1688
|
self._indices = list(range(len(self._identifiers)))
|
|
1399
1689
|
self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)}
|
|
1400
1690
|
|
|
1401
|
-
# Invalidate any previously computed neighbors or distance matrices
|
|
1402
1691
|
self._reset_computations()
|
|
1403
1692
|
logger.info(
|
|
1404
1693
|
f"{len(identifiers)} sites added. Lattice now has {self.num_sites} sites."
|
|
@@ -1413,10 +1702,9 @@ class CustomizeLattice(AbstractLattice):
|
|
|
1413
1702
|
|
|
1414
1703
|
:param identifiers: A list of identifiers for the sites to be removed.
|
|
1415
1704
|
:type identifiers: List[SiteIdentifier]
|
|
1416
|
-
:raises ValueError: If any of the specified identifiers do not exist.
|
|
1417
1705
|
"""
|
|
1418
1706
|
if not identifiers:
|
|
1419
|
-
return
|
|
1707
|
+
return
|
|
1420
1708
|
|
|
1421
1709
|
ids_to_remove = set(identifiers)
|
|
1422
1710
|
current_ids = set(self._identifiers)
|
|
@@ -1425,24 +1713,77 @@ class CustomizeLattice(AbstractLattice):
|
|
|
1425
1713
|
f"Non-existent identifiers provided for removal: {ids_to_remove - current_ids}"
|
|
1426
1714
|
)
|
|
1427
1715
|
|
|
1428
|
-
#
|
|
1429
|
-
|
|
1430
|
-
|
|
1431
|
-
|
|
1432
|
-
if ident not in ids_to_remove
|
|
1433
|
-
|
|
1434
|
-
|
|
1716
|
+
# Find the indices of the sites that we want to keep.
|
|
1717
|
+
indices_to_keep = [
|
|
1718
|
+
idx
|
|
1719
|
+
for idx, ident in enumerate(self._identifiers)
|
|
1720
|
+
if ident not in ids_to_remove
|
|
1721
|
+
]
|
|
1722
|
+
|
|
1723
|
+
new_identifiers = [self._identifiers[i] for i in indices_to_keep]
|
|
1724
|
+
|
|
1725
|
+
self._coordinates = backend.gather1d(
|
|
1726
|
+
self._coordinates,
|
|
1727
|
+
backend.cast(backend.convert_to_tensor(indices_to_keep), "int32"),
|
|
1728
|
+
)
|
|
1435
1729
|
|
|
1436
|
-
# Replace old data with the new, filtered data
|
|
1437
1730
|
self._identifiers = new_identifiers
|
|
1438
|
-
self._coordinates = new_coordinates
|
|
1439
1731
|
|
|
1440
|
-
# Rebuild index mappings
|
|
1441
1732
|
self._indices = list(range(len(self._identifiers)))
|
|
1442
1733
|
self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)}
|
|
1443
1734
|
|
|
1444
|
-
# Invalidate caches
|
|
1445
1735
|
self._reset_computations()
|
|
1446
1736
|
logger.info(
|
|
1447
1737
|
f"{len(ids_to_remove)} sites removed. Lattice now has {self.num_sites} sites."
|
|
1448
1738
|
)
|
|
1739
|
+
|
|
1740
|
+
|
|
1741
|
+
def get_compatible_layers(bonds: List[Tuple[int, int]]) -> List[List[Tuple[int, int]]]:
    """
    Partitions a list of pairs (bonds) into compatible layers for parallel
    gate application using a greedy edge-coloring algorithm.

    This function takes a list of pairs, representing connections like
    nearest-neighbor (NN) or next-nearest-neighbor (NNN) bonds, and
    partitions them into sets ("layers") where no two pairs in a set share
    an index. This is a general utility for scheduling non-overlapping
    operations.

    Each bond is normalized to ``(min(i, j), max(i, j))`` and duplicates are
    dropped, so ``(1, 0)`` and ``(0, 1)`` count as the same bond. Layers are
    built greedily: the remaining bonds are scanned in ascending order and
    every bond whose endpoints are both still free in the current layer is
    taken. Being greedy, the number of layers is usually small but is not
    guaranteed to be the minimum possible.

    :Example:

    >>> from tensorcircuit.templates.lattice import SquareLattice
    >>> sq_lattice = SquareLattice(size=(2, 2), pbc=False)
    >>> nn_bonds = sq_lattice.get_neighbor_pairs(k=1, unique=True)

    >>> gate_layers = get_compatible_layers(nn_bonds)
    >>> print(gate_layers)
    [[(0, 1), (2, 3)], [(0, 2), (1, 3)]]

    :param bonds: A list of tuples, where each tuple represents a bond (i, j)
        of site indices to be scheduled. Orientation and duplicates are ignored.
    :type bonds: List[Tuple[int, int]]
    :return: A list of layers. Each layer is a sorted list of normalized
        ``(min, max)`` tuples, and all bonds within a layer are
        non-overlapping. An empty input yields an empty list.
    :rtype: List[List[Tuple[int, int]]]
    """
    # Normalize orientation, de-duplicate, and sort once. Every later pass
    # preserves this relative order, so each layer comes out already sorted.
    remaining: List[Tuple[int, int]] = sorted(
        {(min(bond), max(bond)) for bond in bonds}
    )
    layers: List[List[Tuple[int, int]]] = []

    # Each pass always takes the first remaining edge (nothing is occupied
    # yet), so `remaining` strictly shrinks and the loop terminates.
    while remaining:
        current_layer: List[Tuple[int, int]] = []
        qubits_in_this_layer: Set[int] = set()
        deferred: List[Tuple[int, int]] = []

        for edge in remaining:
            i, j = edge
            if i in qubits_in_this_layer or j in qubits_in_this_layer:
                # An endpoint is already busy in this layer; retry next pass.
                deferred.append(edge)
                continue
            current_layer.append(edge)
            qubits_in_this_layer.add(i)
            qubits_in_this_layer.add(j)

        layers.append(current_layer)
        remaining = deferred

    return layers
|