tensorcircuit-nightly 1.3.0.dev20250728__py3-none-any.whl → 1.4.0.dev20251103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (72)
  1. tensorcircuit/__init__.py +5 -1
  2. tensorcircuit/abstractcircuit.py +4 -0
  3. tensorcircuit/analogcircuit.py +413 -0
  4. tensorcircuit/applications/layers.py +1 -1
  5. tensorcircuit/applications/van.py +1 -1
  6. tensorcircuit/backends/abstract_backend.py +312 -5
  7. tensorcircuit/backends/cupy_backend.py +3 -1
  8. tensorcircuit/backends/jax_backend.py +92 -3
  9. tensorcircuit/backends/jax_ops.py +108 -0
  10. tensorcircuit/backends/numpy_backend.py +49 -3
  11. tensorcircuit/backends/pytorch_backend.py +92 -3
  12. tensorcircuit/backends/tensorflow_backend.py +102 -3
  13. tensorcircuit/basecircuit.py +123 -82
  14. tensorcircuit/circuit.py +67 -57
  15. tensorcircuit/cloud/local.py +1 -1
  16. tensorcircuit/cloud/quafu_provider.py +1 -1
  17. tensorcircuit/cloud/tencent.py +1 -1
  18. tensorcircuit/compiler/simple_compiler.py +2 -2
  19. tensorcircuit/cons.py +1 -0
  20. tensorcircuit/densitymatrix.py +16 -11
  21. tensorcircuit/experimental.py +7 -152
  22. tensorcircuit/fgs.py +5 -6
  23. tensorcircuit/gates.py +66 -22
  24. tensorcircuit/keras.py +3 -3
  25. tensorcircuit/mpscircuit.py +109 -61
  26. tensorcircuit/quantum.py +697 -133
  27. tensorcircuit/quditcircuit.py +733 -0
  28. tensorcircuit/quditgates.py +618 -0
  29. tensorcircuit/results/counts.py +45 -31
  30. tensorcircuit/shadows.py +1 -1
  31. tensorcircuit/simplify.py +3 -1
  32. tensorcircuit/stabilizercircuit.py +4 -2
  33. tensorcircuit/templates/blocks.py +2 -2
  34. tensorcircuit/templates/hamiltonians.py +29 -8
  35. tensorcircuit/templates/lattice.py +676 -335
  36. tensorcircuit/timeevol.py +896 -0
  37. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +50 -25
  38. tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
  39. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
  40. tensorcircuit_nightly-1.3.0.dev20250728.dist-info/RECORD +0 -122
  41. tests/__init__.py +0 -0
  42. tests/conftest.py +0 -67
  43. tests/test_backends.py +0 -1035
  44. tests/test_calibrating.py +0 -149
  45. tests/test_channels.py +0 -409
  46. tests/test_circuit.py +0 -1713
  47. tests/test_cloud.py +0 -219
  48. tests/test_compiler.py +0 -147
  49. tests/test_dmcircuit.py +0 -555
  50. tests/test_ensemble.py +0 -72
  51. tests/test_fgs.py +0 -318
  52. tests/test_gates.py +0 -156
  53. tests/test_hamiltonians.py +0 -159
  54. tests/test_interfaces.py +0 -557
  55. tests/test_keras.py +0 -160
  56. tests/test_lattice.py +0 -1666
  57. tests/test_miscs.py +0 -334
  58. tests/test_mpscircuit.py +0 -341
  59. tests/test_noisemodel.py +0 -156
  60. tests/test_qaoa.py +0 -86
  61. tests/test_qem.py +0 -152
  62. tests/test_quantum.py +0 -549
  63. tests/test_quantum_attr.py +0 -42
  64. tests/test_results.py +0 -379
  65. tests/test_shadows.py +0 -160
  66. tests/test_simplify.py +0 -46
  67. tests/test_stabilizer.py +0 -226
  68. tests/test_templates.py +0 -218
  69. tests/test_torchnn.py +0 -99
  70. tests/test_van.py +0 -102
  71. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +0 -0
  72. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/licenses/LICENSE +0 -0
@@ -15,13 +15,15 @@ from typing import (
     Union,
     TYPE_CHECKING,
     cast,
+    Set,
 )
 
-logger = logging.getLogger(__name__)
+import itertools
+import math
 
 import numpy as np
-
 from scipy.spatial import KDTree
-from scipy.spatial.distance import pdist, squareform
+
+from .. import backend
 
 
 # This block resolves a name resolution issue for the static type checker (mypy).
@@ -40,9 +42,13 @@ if TYPE_CHECKING:
     import matplotlib.axes
     from mpl_toolkits.mplot3d import Axes3D
 
+logger = logging.getLogger(__name__)
+
+Tensor = Any
 SiteIndex = int
 SiteIdentifier = Hashable
-Coordinates = np.ndarray[Any, Any]
+Coordinates = Tensor
+
 NeighborMap = Dict[SiteIndex, List[SiteIndex]]
 
 
@@ -63,13 +69,27 @@ class AbstractLattice(abc.ABC):
         """Initializes the base lattice class."""
         self._dimensionality = dimensionality
 
-        # --- Internal Data Structures (to be populated by subclasses) ---
-        self._indices: List[SiteIndex] = []
-        self._identifiers: List[SiteIdentifier] = []
-        self._coordinates: List[Coordinates] = []
-        self._ident_to_idx: Dict[SiteIdentifier, SiteIndex] = {}
-        self._neighbor_maps: Dict[int, NeighborMap] = {}
-        self._distance_matrix: Optional[Coordinates] = None
+        # Core data structures for storing site information.
+        self._indices: List[SiteIndex] = []  # List of integer indices [0, 1, ..., N-1]
+        self._identifiers: List[SiteIdentifier] = (
+            []
+        )  # List of unique, hashable site identifiers
+        # Always initialize to an empty coordinate tensor with correct dimensionality
+        # so that type checkers know this is indexable and not Optional.
+        self._coordinates: Coordinates = backend.zeros((0, dimensionality))
+
+        # Mappings for efficient lookups.
+        self._ident_to_idx: Dict[SiteIdentifier, SiteIndex] = (
+            {}
+        )  # Maps identifiers to indices
+
+        # Cached properties, computed on demand.
+        self._neighbor_maps: Dict[int, NeighborMap] = (
+            {}
+        )  # Caches neighbor info for different k
+        self._distance_matrix: Optional[Coordinates] = (
+            None  # Caches the full N x N distance matrix
+        )
 
     @property
     def num_sites(self) -> int:
@@ -94,7 +114,6 @@ class AbstractLattice(abc.ABC):
             subsequent calls. This computation can be expensive for large lattices.
         """
         if self._distance_matrix is None:
-            logger.info("Distance matrix not cached. Computing now...")
             self._distance_matrix = self._compute_distance_matrix()
         return self._distance_matrix
 
@@ -115,7 +134,8 @@ class AbstractLattice(abc.ABC):
         :rtype: Coordinates
         """
         self._validate_index(index)
-        return self._coordinates[index]
+        coords = self._coordinates[index]
+        return coords
 
     def get_identifier(self, index: SiteIndex) -> SiteIdentifier:
         """Gets the abstract identifier of a site by its integer index.
@@ -139,7 +159,8 @@ class AbstractLattice(abc.ABC):
         :rtype: SiteIndex
         """
         try:
-            return self._ident_to_idx[identifier]
+            index = self._ident_to_idx[identifier]
+            return index
         except KeyError as e:
             raise ValueError(
                 f"Identifier {identifier} not found in the lattice."
@@ -169,7 +190,7 @@ class AbstractLattice(abc.ABC):
             idx = index_or_identifier
             self._validate_index(idx)
             return idx, self._identifiers[idx], self._coordinates[idx]
-        else:  # Identifier
+        else:
             ident = index_or_identifier
             idx = self.get_index(ident)
             return idx, ident, self._coordinates[idx]
@@ -236,7 +257,6 @@ class AbstractLattice(abc.ABC):
             )
             self._build_neighbors(max_k=k)
 
-        # After attempting to build, check again. If still not found, return empty.
         if k not in self._neighbor_maps:
             return []
 
@@ -250,8 +270,28 @@ class AbstractLattice(abc.ABC):
                 pairs.append((i, j))
         return sorted(pairs)
 
-    # Sorting provides a deterministic output order
-    # --- Abstract Methods for Subclass Implementation ---
+    def get_all_pairs(self) -> List[Tuple[SiteIndex, SiteIndex]]:
+        """
+        Returns a list of all unique pairs of site indices (i, j) where i < j.
+
+        This method provides all-to-all connectivity, useful for Hamiltonians
+        where every site interacts with every other site.
+
+        Note on Differentiability:
+        This method provides a static list of index pairs and is not differentiable
+        itself. However, it is designed to be used in combination with the fully
+        differentiable ``distance_matrix`` property. By using the pairs from this
+        method to index into the ``distance_matrix``, one can construct differentiable
+        objective functions based on all-pair interactions, effectively bypassing the
+        non-differentiable ``get_neighbor_pairs`` method for geometry optimization tasks.
+
+        :return: A list of tuples, where each tuple is a unique pair of site indices.
+        :rtype: List[Tuple[SiteIndex, SiteIndex]]
+        """
+        if self.num_sites < 2:
+            return []
+        # Use itertools.combinations to efficiently generate all unique pairs (i, j) with i < j.
+        return sorted(list(itertools.combinations(range(self.num_sites), 2)))
 
     @abc.abstractmethod
     def _build_lattice(self, *args: Any, **kwargs: Any) -> None:
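The get_all_pairs / distance_matrix combination described in the docstring above can drive differentiable all-pair objectives. A minimal usage sketch (not part of the diff), assuming the default numpy backend and the CustomizeLattice constructor shown later in this file:

import numpy as np
from tensorcircuit.templates.lattice import CustomizeLattice

lat = CustomizeLattice(
    dimensionality=2,
    identifiers=[0, 1, 2],
    coordinates=np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]),
)
pairs = lat.get_all_pairs()   # static index pairs: [(0, 1), (0, 2), (1, 2)]
dmat = lat.distance_matrix    # N x N tensor, differentiable under ML backends
# All-pair objective assembled from the differentiable distance matrix.
energy = sum(1.0 / dmat[i, j] for i, j in pairs)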
@@ -280,14 +320,24 @@ class AbstractLattice(abc.ABC):
         """
         pass
 
-    @abc.abstractmethod
     def _compute_distance_matrix(self) -> Coordinates:
         """
-        Abstract method for subclasses to implement the actual matrix calculation.
-        This method is called by the `distance_matrix` property when the matrix
-        needs to be computed for the first time.
+        Default generic distance matrix computation (no periodic images).
+
+        Subclasses can override this when a specialized rule is required
+        (e.g., applying Minimum Image Convention for PBC in TILattice).
         """
-        pass
+        # Handle empty lattices and trivial 1-site lattices
+        if self.num_sites == 0:
+            return backend.zeros((0, 0))
+
+        # Vectorized pairwise Euclidean distances
+        all_coords = self._coordinates
+        displacements = backend.expand_dims(all_coords, 1) - backend.expand_dims(
+            all_coords, 0
+        )
+        dist_matrix_sq = backend.sum(displacements**2, axis=-1)
+        return backend.sqrt(dist_matrix_sq)
 
     def show(
         self,
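The new default _compute_distance_matrix relies on the (N, 1, D) - (1, N, D) broadcasting pattern. A pure-NumPy illustration of that pattern (not part of the diff):

import numpy as np

coords = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])   # (N, D)
disp = coords[:, None, :] - coords[None, :, :]            # (N, N, D)
dist = np.sqrt(np.sum(disp**2, axis=-1))                  # (N, N)
assert np.isclose(dist[0, 2], np.sqrt(2.0))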
@@ -327,13 +377,14 @@ class AbstractLattice(abc.ABC):
         try:
             import matplotlib.pyplot as plt
         except ImportError:
-            logger.error(
+            logger.warning(
                 "Matplotlib is required for visualization. "
                 "Please install it using 'pip install matplotlib'."
             )
            return
 
-        # creat "fig_created_internally" as flag
+        # Flag to track if the Matplotlib figure was created by this method.
+        # This prevents calling plt.show() if the user provided their own Axes.
         fig_created_internally = False
 
         if self.num_sites == 0:
@@ -346,7 +397,7 @@ class AbstractLattice(abc.ABC):
             return
 
         if ax is None:
-            # when ax is none, make fig_created_internally true
+            # If no Axes object is provided, create a new figure and axes.
            fig_created_internally = True
            if self.dimensionality == 3:
                fig = plt.figure(figsize=(8, 8))
@@ -357,6 +408,7 @@ class AbstractLattice(abc.ABC):
             fig = ax.figure  # type: ignore
 
         coords = np.array(self._coordinates)
+        # Prepare arguments for the scatter plot, allowing user overrides.
         scatter_args = {"s": 100, "zorder": 2}
         scatter_args.update(kwargs)
         if self.dimensionality == 1:
@@ -370,12 +422,11 @@ class AbstractLattice(abc.ABC):
         if show_indices or show_identifiers:
             for i in range(self.num_sites):
                 label = str(self._identifiers[i]) if show_identifiers else str(i)
+                # Calculate a small offset for placing text labels to avoid overlap with sites.
                 offset = (
                     0.02 * np.max(np.ptp(coords, axis=0)) if coords.size > 0 else 0.1
                 )
 
-                # Robust Logic: Decide plotting strategy based on known dimensionality.
-
                 if self.dimensionality == 1:
                     ax.text(coords[i, 0], offset, label, fontsize=9, ha="center")
                 elif self.dimensionality == 2:
@@ -397,8 +448,6 @@ class AbstractLattice(abc.ABC):
                         zorder=3,
                     )
 
-        # Note: No 'else' needed as we already check dimensionality at the start.
-
         if show_bonds_k is not None:
             if show_bonds_k not in self._neighbor_maps:
                 logger.warning(
@@ -432,7 +481,7 @@ class AbstractLattice(abc.ABC):
                     if self.dimensionality == 1:  # type: ignore
 
                         ax.plot([p1[0], p2[0]], [0, 0], **plot_bond_kwargs)  # type: ignore
-                    else:  # dimensionality == 2
+                    else:
                         ax.plot([p1[0], p2[0]], [p1[1], p2[1]], **plot_bond_kwargs)  # type: ignore
 
                 except ValueError as e:
@@ -448,7 +497,7 @@ class AbstractLattice(abc.ABC):
             ax.set_zlabel("z")
         ax.grid(True)
 
-        # 3. whether plt.show()
+        # Display the plot only if the figure was created within this function.
         if fig_created_internally:
             plt.show()
 
@@ -474,26 +523,28 @@ class AbstractLattice(abc.ABC):
         :return: A sorted list of squared distances representing the shells.
         :rtype: List[float]
         """
+        # A small threshold to filter out zero distances (site to itself).
         ZERO_THRESHOLD_SQ = 1e-12
 
-        all_distances_sq = np.asarray(all_distances_sq)
+        all_distances_sq = backend.convert_to_tensor(all_distances_sq)
         # Now, the .size call below is guaranteed to be safe.
-        if all_distances_sq.size == 0:
+        if backend.sizen(all_distances_sq) == 0:
             return []
 
-        sorted_dist = np.sort(all_distances_sq[all_distances_sq > ZERO_THRESHOLD_SQ])
+        # Filter out self-distances and sort the remaining squared distances.
+        sorted_dist = backend.sort(
+            all_distances_sq[all_distances_sq > ZERO_THRESHOLD_SQ]
+        )
 
-        if sorted_dist.size == 0:
+        if backend.sizen(sorted_dist) == 0:
             return []
 
-        # Identify shells using the user-provided tolerance.
         dist_shells = [sorted_dist[0]]
 
         for d_sq in sorted_dist[1:]:
             if len(dist_shells) >= max_k:
                 break
-            # If the current distance is notably larger than the last shell's distance
-            if d_sq > dist_shells[-1] + tol**2:
+            if backend.sqrt(d_sq) - backend.sqrt(dist_shells[-1]) > tol:
                 dist_shells.append(d_sq)
 
         return dist_shells
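The reworked shell test above compares differences of sqrt(d^2) against tol instead of comparing squared distances against tol**2. A small standalone illustration of the grouping rule (not part of the diff):

import numpy as np

d_sq = np.array([1.0, 1.0, 2.0, 1.0, 4.0, 2.0])   # squared distances, zeros already removed
tol = 1e-6
shells = []
for x in np.sort(d_sq):
    if not shells or np.sqrt(x) - np.sqrt(shells[-1]) > tol:
        shells.append(float(x))
print(shells)   # [1.0, 2.0, 4.0] -> 1st, 2nd and 3rd neighbor shells (squared)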
@@ -502,11 +553,9 @@ class AbstractLattice(abc.ABC):
         self, max_k: int = 2, tol: float = 1e-6
     ) -> None:
         """A generic, distance-based neighbor finding method.
-
         This method calculates the full N x N distance matrix to find neighbor
         shells. It is computationally expensive for large N (O(N^2)) and is
         best suited for non-periodic or custom-defined lattices.
-
         :param max_k: The maximum number of neighbor shells to
             calculate. Defaults to 2.
         :type max_k: int, optional
@@ -517,26 +566,55 @@ class AbstractLattice(abc.ABC):
         if self.num_sites < 2:
             return
 
-        all_coords = np.array(self._coordinates)
-        dist_matrix_sq = np.sum(
-            (all_coords[:, np.newaxis, :] - all_coords[np.newaxis, :, :]) ** 2, axis=-1
+        all_coords = self._coordinates
+        # Vectorized computation of the squared distance matrix:
+        # (N, 1, D) - (1, N, D) -> (N, N, D) -> (N, N)
+        displacements = backend.expand_dims(all_coords, 1) - backend.expand_dims(
+            all_coords, 0
         )
+        dist_matrix_sq = backend.sum(displacements**2, axis=-1)
 
-        all_distances_sq = dist_matrix_sq.flatten()
+        # Flatten the matrix to a list of all squared distances to identify shells.
+        all_distances_sq = backend.reshape(dist_matrix_sq, [-1])
         dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol)
 
-        self._neighbor_maps = {k: {} for k in range(1, len(dist_shells_sq) + 1)}
+        self._neighbor_maps = self._build_neighbor_map_from_distances(
+            dist_matrix_sq, dist_shells_sq, tol
+        )
+        self._distance_matrix = backend.sqrt(dist_matrix_sq)
+
+    def _build_neighbor_map_from_distances(
+        self,
+        dist_matrix_sq: Coordinates,
+        dist_shells_sq: List[float],
+        tol: float = 1e-6,
+    ) -> Dict[int, NeighborMap]:
+        """
+        Builds a neighbor map from a squared distance matrix and identified shells.
+        This is a generic helper function to reduce code duplication.
+        """
+        neighbor_maps: Dict[int, NeighborMap] = {
+            k: {} for k in range(1, len(dist_shells_sq) + 1)
+        }
         for k_idx, target_d_sq in enumerate(dist_shells_sq):
             k = k_idx + 1
             current_k_map: Dict[int, List[int]] = {}
-            for i in range(self.num_sites):
-                neighbor_indices = np.where(
-                    np.isclose(dist_matrix_sq[i], target_d_sq, rtol=0, atol=tol**2)
-                )[0]
-                if len(neighbor_indices) > 0:
-                    current_k_map[i] = sorted(neighbor_indices.tolist())
-            self._neighbor_maps[k] = current_k_map
-        self._distance_matrix = np.sqrt(dist_matrix_sq)
+            # For each shell, find all pairs of sites (i, j) with that distance.
+            is_close_matrix = backend.abs(dist_matrix_sq - target_d_sq) < tol
+            rows, cols = backend.where(is_close_matrix)
+
+            for i, j in zip(backend.numpy(rows), backend.numpy(cols)):
+                if i == j:
+                    continue
+                if i not in current_k_map:
+                    current_k_map[i] = []
+                current_k_map[i].append(j)
+
+            for i in current_k_map:
+                current_k_map[i].sort()
+
+            neighbor_maps[k] = current_k_map
+        return neighbor_maps
 
 
 class TILattice(AbstractLattice):
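The helper _build_neighbor_map_from_distances pairs sites by masking the squared distance matrix shell by shell. The same logic written in plain NumPy, as a standalone sketch (not part of the diff):

import numpy as np

coords = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
disp = coords[:, None, :] - coords[None, :, :]
d_sq = np.sum(disp**2, axis=-1)

target_d_sq, tol = 1.0, 1e-6                      # nearest-neighbor shell
rows, cols = np.where(np.abs(d_sq - target_d_sq) < tol)
nn_map = {}
for i, j in zip(rows, cols):
    if i != j:
        nn_map.setdefault(int(i), []).append(int(j))
print(nn_map)                                     # {0: [1], 1: [0, 2], 2: [1]}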
@@ -587,150 +665,197 @@ class TILattice(AbstractLattice):
     ):
         """Initializes the Translationally Invariant Lattice."""
         super().__init__(dimensionality)
-        assert lattice_vectors.shape == (
-            dimensionality,
-            dimensionality,
-        ), "Lattice vectors shape mismatch"
-        assert (
-            basis_coords.shape[1] == dimensionality
-        ), "Basis coordinates dimension mismatch"
-        assert len(size) == dimensionality, "Size tuple length mismatch"
-
-        self.lattice_vectors = lattice_vectors
-        self.basis_coords = basis_coords
-        self.num_basis = basis_coords.shape[0]
+
+        self.lattice_vectors = backend.convert_to_tensor(lattice_vectors)
+        self.basis_coords = backend.convert_to_tensor(basis_coords)
+
+        if self.lattice_vectors.shape != (dimensionality, dimensionality):
+            raise ValueError(
+                f"Lattice vectors shape {self.lattice_vectors.shape} does not match "
+                f"expected ({dimensionality}, {dimensionality})"
+            )
+        if self.basis_coords.shape[1] != dimensionality:
+            raise ValueError(
+                f"Basis coordinates dimension {self.basis_coords.shape[1]} does not "
+                f"match lattice dimensionality {dimensionality}"
+            )
+        if len(size) != dimensionality:
+            raise ValueError(
+                f"Size tuple length {len(size)} does not match dimensionality {dimensionality}"
+            )
+
+        self.num_basis = self.basis_coords.shape[0]
         self.size = size
         if isinstance(pbc, bool):
             self.pbc = tuple([pbc] * dimensionality)
         else:
-            assert len(pbc) == dimensionality, "PBC tuple length mismatch"
+            if len(pbc) != dimensionality:
+                raise ValueError(
+                    f"PBC tuple length {len(pbc)} does not match dimensionality {dimensionality}"
+                )
             self.pbc = tuple(pbc)
 
-        # Build the lattice sites and their neighbor relationships
         self._build_lattice()
         if precompute_neighbors is not None and precompute_neighbors > 0:
             logger.info(f"Pre-computing neighbors up to k={precompute_neighbors}...")
             self._build_neighbors(max_k=precompute_neighbors)
 
     def _build_lattice(self) -> None:
-        """Generates all site information for the periodic lattice.
-
-        This method iterates through each unit cell defined by `self.size`,
-        and for each unit cell, it iterates through all basis sites. It then
-        calculates the real-space coordinates and creates a unique identifier
-        for each site, populating the internal lattice data structures.
         """
-        current_index = 0
+        Generates all site information for the periodic lattice in a vectorized manner.
+        """
+        ranges = [backend.arange(s) for s in self.size]
 
-        # Iterate over all unit cell coordinates elegantly using np.ndindex
-        for cell_coord in np.ndindex(self.size):
-            cell_coord_arr = np.array(cell_coord)
-            # R = n1*a1 + n2*a2 + ...
-            cell_vector = np.dot(cell_coord_arr, self.lattice_vectors)
+        # Generate a grid of all integer unit cell coordinates.
+        grid = backend.meshgrid(*ranges, indexing="ij")
+        all_cell_coords = backend.reshape(
+            backend.stack(grid, axis=-1), (-1, self.dimensionality)
+        )
 
-            # Iterate over the basis sites within the unit cell
-            for basis_index in range(self.num_basis):
-                basis_vec = self.basis_coords[basis_index]
+        all_cell_coords = backend.cast(all_cell_coords, self.lattice_vectors.dtype)
 
-                # Calculate the real-space coordinate
-                coord = cell_vector + basis_vec
-                # Create a structured identifier
-                identifier = cell_coord + (basis_index,)
+        cell_vectors = backend.tensordot(
+            all_cell_coords, self.lattice_vectors, axes=[[1], [0]]
+        )
+
+        cell_vectors = backend.cast(cell_vectors, self.basis_coords.dtype)
+
+        # Combine cell vectors with basis coordinates to get all site positions
+        # via broadcasting: (num_cells, 1, D) + (1, num_basis, D) -> (num_cells, num_basis, D)
+        all_coords = backend.expand_dims(cell_vectors, 1) + backend.expand_dims(
+            self.basis_coords, 0
+        )
+
+        self._coordinates = backend.reshape(all_coords, (-1, self.dimensionality))
+
+        self._indices = []
+        self._identifiers = []
+        self._ident_to_idx = {}
+        current_index = 0
 
-                # Store site information
+        # Generate integer indices and tuple-based identifiers for all sites.
+        # e.g., identifier = (uc_x, uc_y, basis_idx)
+        size_ranges = [range(s) for s in self.size]
+        for cell_coord_tuple in itertools.product(*size_ranges):
+            for basis_index in range(self.num_basis):
+                identifier = cell_coord_tuple + (basis_index,)
                 self._indices.append(current_index)
                 self._identifiers.append(identifier)
-                self._coordinates.append(coord)
                 self._ident_to_idx[identifier] = current_index
                 current_index += 1
 
-    def _get_distance_matrix_with_mic(self) -> Coordinates:
+    def _get_distance_matrix_with_mic_vectorized(self) -> Coordinates:
         """
-        Computes the full N x N distance matrix, correctly applying the
-        Minimum Image Convention (MIC) for all periodic dimensions.
+        Computes the full N x N distance matrix using a fully vectorized approach
+        that correctly applies the Minimum Image Convention (MIC) for periodic
+        boundary conditions.
+
+        This method uses full vectorization for optimal performance and compatibility
+        with JIT compilation frameworks like JAX. The implementation processes all
+        site pairs simultaneously rather than iterating row-by-row, which provides:
+
+        - Better performance through vectorized operations
+        - Full compatibility with automatic differentiation
+        - JIT compilation support (e.g., JAX, TensorFlow)
+        - Consistent tensor operations throughout
+
+        The trade-off is higher memory usage compared to iterative approaches,
+        as it computes all pairwise distances simultaneously. For very large
+        lattices (N > 10^4 sites), memory usage scales as O(N^2).
+
+        :return: Distance matrix with shape (N, N) where entry (i,j) is the
+            minimum distance between sites i and j under periodic boundary conditions.
+        :rtype: Coordinates
         """
-        all_coords = np.array(self._coordinates)
-        size_arr = np.array(self.size)
-        system_vectors = self.lattice_vectors * size_arr[:, np.newaxis]
-
-        # Generate translation vectors ONLY for periodic dimensions
-        pbc_dims = [d for d in range(self.dimensionality) if self.pbc[d]]
-        translations = [np.zeros(self.dimensionality)]
-        if pbc_dims:
-            num_pbc_dims = len(pbc_dims)
-            pbc_system_vectors = system_vectors[pbc_dims, :]
-
-            # Create all 3^k - 1 non-zero shifts for k periodic dimensions
-            shift_options = [np.array([-1, 0, 1])] * num_pbc_dims
-            shifts_grid = np.meshgrid(*shift_options, indexing="ij")
-            all_shifts = np.stack(shifts_grid, axis=-1).reshape(-1, num_pbc_dims)
-            all_shifts = all_shifts[np.any(all_shifts != 0, axis=1)]
-
-            pbc_translations = all_shifts @ pbc_system_vectors
-            translations.extend(pbc_translations)
-
-        translations_arr = np.array(translations, dtype=float)
-
-        # Calculate the distance matrix applying MIC
-        dist_matrix_sq = np.full((self.num_sites, self.num_sites), np.inf, dtype=float)
-        for i in range(self.num_sites):
-            displacements = all_coords - all_coords[i]
-            image_displacements = (
-                displacements[:, np.newaxis, :] - translations_arr[np.newaxis, :, :]
-            )
-            image_d_sq = np.sum(image_displacements**2, axis=2)
-            dist_matrix_sq[i, :] = np.min(image_d_sq, axis=1)
+        # Ensure dtype consistency across backends (especially torch) by explicitly
+        # casting size and lattice_vectors to the same floating dtype used internally.
+        # Strategy: prefer existing lattice_vectors dtype; if it's an unusual dtype,
+        # fall back to float32 to avoid mixed-precision issues in vectorized ops.
+        # Note: `self.lattice_vectors` is always created via `backend.convert_to_tensor`
+        # in __init__, so `backend.dtype(...)` is reliable here and doesn't need try/except.
+        target_dt = str(backend.dtype(self.lattice_vectors))  # type: ignore
+        if target_dt not in ("float32", "float64"):
+            # fallback for unusual dtypes
+            target_dt = "float32"
+
+        size_arr = backend.cast(backend.convert_to_tensor(self.size), target_dt)
+        lattice_vecs = backend.cast(
+            backend.convert_to_tensor(self.lattice_vectors), target_dt
+        )
+        system_vectors = lattice_vecs * backend.expand_dims(size_arr, axis=1)
+
+        pbc_mask = backend.convert_to_tensor(self.pbc)
 
-        return cast(Coordinates, np.sqrt(dist_matrix_sq))
+        # Generate all 3^d possible image shifts (-1, 0, 1) for all dimensions
+        shift_options = [
+            backend.convert_to_tensor([-1.0, 0.0, 1.0])
+        ] * self.dimensionality
+        shifts_grid = backend.meshgrid(*shift_options, indexing="ij")
+        all_shifts = backend.reshape(
+            backend.stack(shifts_grid, axis=-1), (-1, self.dimensionality)
+        )
+
+        # Only apply shifts to periodic dimensions
+        masked_shifts = all_shifts * backend.cast(pbc_mask, all_shifts.dtype)
+
+        # Calculate all translation vectors due to PBC
+        translations_arr = backend.tensordot(
+            masked_shifts, system_vectors, axes=[[1], [0]]
+        )
+
+        # Vectorized computation of all displacements between any two sites
+        # Shape: (N, 1, D) - (1, N, D) -> (N, N, D)
+        displacements = backend.expand_dims(self._coordinates, 1) - backend.expand_dims(
+            self._coordinates, 0
+        )
+
+        # Consider all periodic images for each displacement
+        # Shape: (N, N, 1, D) - (1, 1, num_translations, D) -> (N, N, num_translations, D)
+        image_displacements = backend.expand_dims(
+            displacements, 2
+        ) - backend.expand_dims(backend.expand_dims(translations_arr, 0), 0)
+
+        # Sum of squares for distances
+        image_d_sq = backend.sum(image_displacements**2, axis=3)
+
+        # Find the minimum distance among all images (Minimum Image Convention)
+        min_dist_sq = backend.min(image_d_sq, axis=2)
+
+        safe_dist_matrix_sq = backend.where(min_dist_sq > 0, min_dist_sq, 0.0)
+        return backend.sqrt(safe_dist_matrix_sq)
 
     def _build_neighbors(self, max_k: int = 2, **kwargs: Any) -> None:
         """Calculates neighbor relationships for the periodic lattice.
 
-        This method calculates neighbor relationships by computing the full N x N
-        distance matrix. It robustly handles all boundary conditions (fully
-        periodic, open, or mixed) by applying the Minimum Image Convention
-        (MIC) only to the periodic dimensions.
+        This method computes neighbor information by first calculating the full
+        distance matrix using the Minimum Image Convention (MIC) to correctly
+        handle periodic boundary conditions. It then identifies unique distance
+        shells (e.g., nearest, next-nearest) and populates the neighbor maps
+        accordingly. This approach is general and works for any periodic lattice
+        geometry defined by the TILattice class.
 
-        From this distance matrix, it identifies unique neighbor shells up to
-        the specified `max_k` and populates the neighbor maps. The computed
-        distance matrix is then cached for future use.
-
-        :param max_k: The maximum number of neighbor shells to
-            calculate. Defaults to 2.
+        :param max_k: The maximum order of neighbors to compute (e.g., k=1 for
+            nearest neighbors, k=2 for next-nearest, etc.). Defaults to 2.
         :type max_k: int, optional
-        :param tol: The numerical tolerance for distance
-            comparisons. Defaults to 1e-6.
-        :type tol: float, optional
+        :param kwargs: Additional keyword arguments. May include:
+            - ``tol`` (float): The numerical tolerance used to determine if two
+              distances are equal when identifying shells. Defaults to 1e-6.
         """
         tol = kwargs.get("tol", 1e-6)
-        dist_matrix = self._get_distance_matrix_with_mic()
+        dist_matrix = self._get_distance_matrix_with_mic_vectorized()
         dist_matrix_sq = dist_matrix**2
         self._distance_matrix = dist_matrix
-        all_distances_sq = dist_matrix_sq.flatten()
+        all_distances_sq = backend.reshape(dist_matrix_sq, [-1])
         dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol)
-
-        self._neighbor_maps = {k: {} for k in range(1, len(dist_shells_sq) + 1)}
-        for k_idx, target_d_sq in enumerate(dist_shells_sq):
-            k = k_idx + 1
-            current_k_map: Dict[int, List[int]] = {}
-            match_indices = np.where(
-                np.isclose(dist_matrix_sq, target_d_sq, rtol=0, atol=tol**2)
-            )
-            for i, j in zip(*match_indices):
-                if i == j:
-                    continue
-                if i not in current_k_map:
-                    current_k_map[i] = []
-                current_k_map[i].append(j)
-
-            for i in current_k_map:
-                current_k_map[i].sort()
-
-            self._neighbor_maps[k] = current_k_map
+        self._neighbor_maps = self._build_neighbor_map_from_distances(
+            dist_matrix_sq, dist_shells_sq, tol
+        )
 
     def _compute_distance_matrix(self) -> Coordinates:
         """Computes the distance matrix using the Minimum Image Convention."""
-        return self._get_distance_matrix_with_mic()
+        if self.num_sites == 0:
+            return backend.zeros((0, 0))
+        return self._get_distance_matrix_with_mic_vectorized()
 
 
 class SquareLattice(TILattice):
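The vectorized Minimum Image Convention above generates all 3^d image shifts, masks them with the pbc flags, and takes the minimum over the resulting translations. A pure-NumPy sketch of the same idea for a periodic 1D chain (not part of the diff):

import itertools
import numpy as np

coords = np.array([[0.0], [3.0]])                  # two sites on a ring
system_length = 4.0                                # size * lattice_constant
pbc = (True,)

shifts = np.array(list(itertools.product([-1.0, 0.0, 1.0], repeat=1)))
shifts = shifts * np.array(pbc, dtype=float)       # zero out non-periodic axes
translations = shifts * system_length              # (num_images, 1)

disp = coords[:, None, :] - coords[None, :, :]     # (N, N, 1)
images = disp[:, :, None, :] - translations[None, None, :, :]
dist = np.sqrt(np.min(np.sum(images**2, axis=-1), axis=-1))
assert np.isclose(dist[0, 1], 1.0)                 # wrapped distance, not 3.0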
@@ -758,20 +883,24 @@ class SquareLattice(TILattice):
     def __init__(
         self,
         size: Tuple[int, int],
-        lattice_constant: float = 1.0,
+        lattice_constant: Union[float, Any] = 1.0,
         pbc: Union[bool, Tuple[bool, bool]] = True,
         precompute_neighbors: Optional[int] = None,
     ):
         """Initializes the SquareLattice."""
         dimensionality = 2
+        # Define orthogonal lattice vectors for a square.
+        # Avoid mixing Python floats with backend Tensors (TF would error),
+        # so first convert inputs to tensors of a unified dtype, then stack.
+        lc = backend.convert_to_tensor(lattice_constant)
+        dt = backend.dtype(lc)
+        z = backend.cast(backend.convert_to_tensor(0.0), dt)
+        row1 = backend.stack([lc, z])
+        row2 = backend.stack([z, lc])
+        lattice_vectors = backend.stack([row1, row2])
+        # A square lattice is a Bravais lattice, so it has a single-site basis.
+        basis_coords = backend.stack([backend.stack([z, z])])
 
-        # Define lattice vectors for a square lattice
-        lattice_vectors = np.array([[lattice_constant, 0.0], [0.0, lattice_constant]])
-
-        # A square lattice has a single site in its basis
-        basis_coords = np.array([[0.0, 0.0]])
-
-        # Call the parent TILattice constructor with these parameters
         super().__init__(
             dimensionality=dimensionality,
             lattice_vectors=lattice_vectors,
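With lattice constants now converted via backend.convert_to_tensor, geometry parameters can stay inside an autodiff graph. A hedged usage sketch (not part of the diff), assuming the JAX backend is active and that gradients flow through lattice construction as this change intends:

import jax
import jax.numpy as jnp
import tensorcircuit as tc
from tensorcircuit.templates.lattice import SquareLattice

tc.set_backend("jax")

def mean_pair_distance(a):
    lat = SquareLattice(size=(3, 3), lattice_constant=a, pbc=True)
    d = lat.distance_matrix
    n = lat.num_sites
    return jnp.sum(d) / (n * (n - 1))

print(jax.grad(mean_pair_distance)(jnp.array(1.0)))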
@@ -807,19 +936,28 @@ class HoneycombLattice(TILattice):
     def __init__(
         self,
         size: Tuple[int, int],
-        lattice_constant: float = 1.0,
+        lattice_constant: Union[float, Any] = 1.0,
         pbc: Union[bool, Tuple[bool, bool]] = True,
         precompute_neighbors: Optional[int] = None,
     ):
         """Initializes the HoneycombLattice."""
         dimensionality = 2
         a = lattice_constant
+        a_t = backend.convert_to_tensor(a)
+        zero = a_t * 0.0
 
-        # Define the primitive lattice vectors for the underlying triangular lattice
-        lattice_vectors = a * np.array([[1.5, np.sqrt(3) / 2], [1.5, -np.sqrt(3) / 2]])
-
-        # Define the coordinates of the two basis sites (A and B)
-        basis_coords = a * np.array([[0.0, 0.0], [1.0, 0.0]])  # Site A  # Site B
+        # Define the two primitive lattice vectors for the underlying triangular Bravais lattice.
+        rt3_over_2 = math.sqrt(3.0) / 2.0
+        lattice_vectors = backend.stack(
+            [
+                backend.stack([a_t * 1.5, a_t * rt3_over_2]),
+                backend.stack([a_t * 1.5, -a_t * rt3_over_2]),
+            ]
+        )
+        # Define the two basis sites (A and B) within the unit cell.
+        basis_coords = backend.stack(
+            [backend.stack([zero, zero]), backend.stack([a_t * 1.0, zero])]
+        )
 
         super().__init__(
             dimensionality=dimensionality,
@@ -854,19 +992,30 @@ class TriangularLattice(TILattice):
     def __init__(
         self,
         size: Tuple[int, int],
-        lattice_constant: float = 1.0,
+        lattice_constant: Union[float, Any] = 1.0,
         pbc: Union[bool, Tuple[bool, bool]] = True,
         precompute_neighbors: Optional[int] = None,
     ):
         """Initializes the TriangularLattice."""
         dimensionality = 2
         a = lattice_constant
+        a_t = backend.convert_to_tensor(a)
+        zero = a_t * 0.0
 
-        # Define the primitive lattice vectors for a triangular lattice
-        lattice_vectors = a * np.array([[1.0, 0.0], [0.5, np.sqrt(3) / 2]])
-
-        # A triangular lattice is a Bravais lattice, with a single site in its basis
-        basis_coords = np.array([[0.0, 0.0]])
+        # Define the primitive lattice vectors for a triangular lattice.
+        lattice_vectors = backend.stack(
+            [
+                backend.stack([a_t * 1.0, zero]),
+                backend.stack(
+                    [
+                        a_t * 0.5,
+                        a_t * backend.sqrt(backend.convert_to_tensor(3.0)) / 2.0,
+                    ]
+                ),
+            ]
+        )
+        # A triangular lattice is a Bravais lattice with a single-site basis.
+        basis_coords = backend.stack([backend.stack([zero, zero])])
 
         super().__init__(
             dimensionality=dimensionality,
@@ -895,13 +1044,18 @@ class ChainLattice(TILattice):
     def __init__(
         self,
         size: Tuple[int],
-        lattice_constant: float = 1.0,
+        lattice_constant: Union[float, Any] = 1.0,
         pbc: bool = True,
         precompute_neighbors: Optional[int] = None,
     ):
         dimensionality = 1
-        lattice_vectors = np.array([[lattice_constant]])
-        basis_coords = np.array([[0.0]])
+        # The lattice vector is just the lattice constant along one dimension.
+        lc = backend.convert_to_tensor(lattice_constant)
+        lattice_vectors = backend.stack([backend.stack([lc])])
+        # A simple chain is a Bravais lattice with a single-site basis.
+        zero = lc * 0.0
+        basis_coords = backend.stack([backend.stack([zero])])
+
         super().__init__(
             dimensionality=dimensionality,
             lattice_vectors=lattice_vectors,
@@ -933,15 +1087,17 @@ class DimerizedChainLattice(TILattice):
     def __init__(
         self,
         size: Tuple[int],
-        lattice_constant: float = 1.0,
+        lattice_constant: Union[float, Any] = 1.0,
         pbc: bool = True,
         precompute_neighbors: Optional[int] = None,
     ):
         dimensionality = 1
-        # The unit cell vector connects two A sites, spanning length 2*a
-        lattice_vectors = np.array([[2 * lattice_constant]])
-        # Basis has site A at origin, site B at distance 'a'
-        basis_coords = np.array([[0.0], [lattice_constant]])
+        # The unit cell is twice the bond length, as it contains two sites.
+        lc = backend.convert_to_tensor(lattice_constant)
+        lattice_vectors = backend.stack([backend.stack([2 * lc])])
+        # Two basis sites (A and B) separated by the bond length.
+        zero = lc * 0.0
+        basis_coords = backend.stack([backend.stack([zero]), backend.stack([lc])])
 
         super().__init__(
             dimensionality=dimensionality,
@@ -974,14 +1130,22 @@ class RectangularLattice(TILattice):
     def __init__(
         self,
         size: Tuple[int, int],
-        lattice_constants: Tuple[float, float] = (1.0, 1.0),
+        lattice_constants: Union[Tuple[float, float], Any] = (1.0, 1.0),
         pbc: Union[bool, Tuple[bool, bool]] = True,
         precompute_neighbors: Optional[int] = None,
     ):
         dimensionality = 2
         ax, ay = lattice_constants
-        lattice_vectors = np.array([[ax, 0.0], [0.0, ay]])
-        basis_coords = np.array([[0.0, 0.0]])
+        ax_t = backend.convert_to_tensor(ax)
+        dt = backend.dtype(ax_t)
+        ay_t = backend.cast(backend.convert_to_tensor(ay), dt)
+        z = backend.cast(backend.convert_to_tensor(0.0), dt)
+        # Orthogonal lattice vectors with potentially different lengths.
+        row1 = backend.stack([ax_t, z])
+        row2 = backend.stack([z, ay_t])
+        lattice_vectors = backend.stack([row1, row2])
+        # A rectangular lattice is a Bravais lattice with a single-site basis.
+        basis_coords = backend.stack([backend.stack([z, z])])
 
         super().__init__(
             dimensionality=dimensionality,
@@ -1012,16 +1176,26 @@ class CheckerboardLattice(TILattice):
     def __init__(
         self,
         size: Tuple[int, int],
-        lattice_constant: float = 1.0,
+        lattice_constant: Union[float, Any] = 1.0,
         pbc: Union[bool, Tuple[bool, bool]] = True,
         precompute_neighbors: Optional[int] = None,
     ):
         dimensionality = 2
         a = lattice_constant
-        # Primitive vectors for a square lattice rotated by 45 degrees.
-        lattice_vectors = a * np.array([[1.0, 1.0], [1.0, -1.0]])
-        # Two-site basis
-        basis_coords = a * np.array([[0.0, 0.0], [1.0, 0.0]])
+        a_t = backend.convert_to_tensor(a)
+        # The unit cell is a square rotated by 45 degrees.
+        lattice_vectors = backend.stack(
+            [
+                backend.stack([a_t * 1.0, a_t * 1.0]),
+                backend.stack([a_t * 1.0, a_t * -1.0]),
+            ]
+        )
+        # Two basis sites (A and B) within the unit cell.
+        zero = a_t * 0.0
+        basis_coords = backend.stack(
+            [backend.stack([zero, zero]), backend.stack([a_t * 1.0, zero])]
+        )
+
         super().__init__(
             dimensionality=dimensionality,
             lattice_vectors=lattice_vectors,
@@ -1051,16 +1225,30 @@ class KagomeLattice(TILattice):
     def __init__(
         self,
         size: Tuple[int, int],
-        lattice_constant: float = 1.0,
+        lattice_constant: Union[float, Any] = 1.0,
        pbc: Union[bool, Tuple[bool, bool]] = True,
        precompute_neighbors: Optional[int] = None,
     ):
         dimensionality = 2
         a = lattice_constant
-        # Using a rectangular unit cell definition for simplicity
-        lattice_vectors = a * np.array([[2.0, 0.0], [1.0, np.sqrt(3)]])
-        # Three-site basis
-        basis_coords = a * np.array([[0.0, 0.0], [1.0, 0.0], [0.5, np.sqrt(3) / 2.0]])
+        a_t = backend.convert_to_tensor(a)
+        # The Kagome lattice is based on a triangular Bravais lattice.
+        lattice_vectors = backend.stack(
+            [
+                backend.stack([a_t * 2.0, a_t * 0.0]),
+                backend.stack([a_t * 1.0, a_t * backend.sqrt(3.0)]),
+            ]
+        )
+        # It has a three-site basis, forming the corners of the triangles.
+        zero = a_t * 0.0
+        basis_coords = backend.stack(
+            [
+                backend.stack([zero, zero]),
+                backend.stack([a_t * 1.0, zero]),
+                backend.stack([a_t * 0.5, a_t * backend.sqrt(3.0) / 2.0]),
+            ]
+        )
+
         super().__init__(
             dimensionality=dimensionality,
             lattice_vectors=lattice_vectors,
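Usage sketch for the multi-site-basis lattices (not part of the diff), assuming the (cell_x, cell_y, basis_index) identifier convention from TILattice._build_lattice and that get_neighbor_pairs accepts the shell order k as in the hunks above:

from tensorcircuit.templates.lattice import KagomeLattice

lat = KagomeLattice(size=(2, 2), lattice_constant=1.0, pbc=True)
print(lat.num_sites)                     # 2 * 2 * 3 = 12 sites (three-site basis)
print(lat.get_identifier(0))             # (0, 0, 0)
print(lat.get_neighbor_pairs(k=1)[:4])   # nearest-neighbor bonds as sorted index pairs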
@@ -1091,29 +1279,26 @@ class LiebLattice(TILattice):
     def __init__(
         self,
         size: Tuple[int, int],
-        lattice_constant: float = 1.0,
+        lattice_constant: Union[float, Any] = 1.0,
         pbc: Union[bool, Tuple[bool, bool]] = True,
         precompute_neighbors: Optional[int] = None,
     ):
         """Initializes the LiebLattice."""
         dimensionality = 2
-        # Use a more descriptive name for clarity. In a Lieb lattice,
-        # the lattice_constant is the bond length between nearest neighbors.
         bond_length = lattice_constant
-
-        # The unit cell of a Lieb lattice is a square with side length
-        # equal to twice the bond length.
-        unit_cell_side = 2 * bond_length
-        lattice_vectors = np.array([[unit_cell_side, 0.0], [0.0, unit_cell_side]])
-
-        # The three-site basis consists of a corner site, a site on the
-        # center of the horizontal edge, and a site on the center of the vertical edge.
-        # Their coordinates are defined directly in terms of the physical bond length.
-        basis_coords = np.array(
+        bl_t = backend.convert_to_tensor(bond_length)
+        unit_cell_side_t = 2 * bl_t
+        # The Lieb lattice is based on a square Bravais lattice.
+        z = bl_t * 0.0
+        lattice_vectors = backend.stack(
+            [backend.stack([unit_cell_side_t, z]), backend.stack([z, unit_cell_side_t])]
+        )
+        # It has a three-site basis: one corner and two edge-centers.
+        basis_coords = backend.stack(
             [
-                [0.0, 0.0],  # Corner site
-                [bond_length, 0.0],  # Horizontal edge center
-                [0.0, bond_length],  # Vertical edge center
+                backend.stack([z, z]),  # Corner site
+                backend.stack([bl_t, z]),  # x-edge center
+                backend.stack([z, bl_t]),  # y-edge center
             ]
         )
 
@@ -1146,14 +1331,24 @@ class CubicLattice(TILattice):
     def __init__(
         self,
         size: Tuple[int, int, int],
-        lattice_constant: float = 1.0,
+        lattice_constant: Union[float, Any] = 1.0,
         pbc: Union[bool, Tuple[bool, bool, bool]] = True,
         precompute_neighbors: Optional[int] = None,
     ):
         dimensionality = 3
         a = lattice_constant
-        lattice_vectors = np.array([[a, 0, 0], [0, a, 0], [0, 0, a]])
-        basis_coords = np.array([[0.0, 0.0, 0.0]])
+        a_t = backend.convert_to_tensor(a)
+        # Orthogonal lattice vectors of equal length in 3D.
+        z = a_t * 0.0
+        lattice_vectors = backend.stack(
+            [
+                backend.stack([a_t, z, z]),
+                backend.stack([z, a_t, z]),
+                backend.stack([z, z, a_t]),
+            ]
+        )
+        # A simple cubic lattice is a Bravais lattice with a single-site basis.
+        basis_coords = backend.stack([backend.stack([z, z, z])])
         super().__init__(
             dimensionality=dimensionality,
             lattice_vectors=lattice_vectors,
@@ -1193,29 +1388,37 @@ class CustomizeLattice(AbstractLattice):
         self,
         dimensionality: int,
         identifiers: List[SiteIdentifier],
-        coordinates: List[Union[List[float], Coordinates]],
+        coordinates: Any,
         precompute_neighbors: Optional[int] = None,
     ):
         """Initializes the CustomizeLattice."""
         super().__init__(dimensionality)
-        if len(identifiers) != len(coordinates):
+
+        self._coordinates = backend.convert_to_tensor(coordinates)
+        if len(identifiers) == 0:
+            self._coordinates = backend.reshape(
+                self._coordinates, (0, self.dimensionality)
+            )
+
+        if len(identifiers) != backend.shape_tuple(self._coordinates)[0]:
             raise ValueError(
-                "Identifiers and coordinates lists must have the same length."
+                "The number of identifiers must match the number of coordinates. "
+                f"Got {len(identifiers)} identifiers and "
+                f"{backend.shape_tuple(self._coordinates)[0]} coordinates."
             )
 
-        # The _build_lattice logic is simple enough to be in __init__
         self._identifiers = list(identifiers)
-        self._coordinates = [np.array(c) for c in coordinates]
         self._indices = list(range(len(identifiers)))
         self._ident_to_idx = {ident: idx for idx, ident in enumerate(identifiers)}
 
-        # Validate coordinate dimensions
-        for i, coord in enumerate(self._coordinates):
-            if coord.shape != (dimensionality,):
-                raise ValueError(
-                    f"Coordinate at index {i} has shape {coord.shape}, "
-                    f"expected ({dimensionality},)"
-                )
+        if (
+            self.num_sites > 0
+            and backend.shape_tuple(self._coordinates)[1] != dimensionality
+        ):
+            raise ValueError(
+                f"Coordinates tensor has dimension {backend.shape_tuple(self._coordinates)[1]}, "
+                f"but expected dimensionality is {dimensionality}."
+            )
 
         logger.info(f"CustomizeLattice with {self.num_sites} sites created.")
 
@@ -1227,95 +1430,170 @@ class CustomizeLattice(AbstractLattice):
1227
1430
  pass
1228
1431
 
1229
1432
  def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None:
1230
- """Calculates neighbors using a KDTree for efficiency.
1231
-
1232
- This method uses a memory-efficient approach to identify neighbors without
1233
- initially computing the full N x N distance matrix. It leverages
1234
- `scipy.spatial.distance.pdist` to find unique distance shells and then
1235
- a `scipy.spatial.KDTree` for fast radius queries. This approach is
1236
- significantly more memory-efficient during the neighbor identification phase.
1433
+ """
1434
+ Calculates neighbor relationships using either KDTree or distance matrix methods.
1237
1435
 
1238
- After the neighbors are identified, the full distance matrix is computed
1239
- from the pairwise distances and cached for potential future use.
1436
+ This method supports two modes:
1437
+ 1. KDTree mode (use_kdtree=True): Fast, O(N log N) performance for large lattices
1438
+ but breaks differentiability due to scipy dependency
1439
+ 2. Distance matrix mode (use_kdtree=False): Slower O(N^2) but fully differentiable
1440
+ and backend-agnostic
1240
1441
 
1241
- :param max_k: The maximum number of neighbor shells to
1242
- calculate. Defaults to 1.
1243
- :type max_k: int, optional
1244
- :param tol: The numerical tolerance for distance
1245
- comparisons. Defaults to 1e-6.
1246
- :type tol: float, optional
1442
+ :param max_k: Maximum number of neighbor shells to compute
1443
+ :type max_k: int
1444
+ :param kwargs: Additional arguments including:
1445
+ - use_kdtree (bool): Whether to use KDTree optimization. Defaults to False.
1446
+ - tol (float): Distance tolerance for neighbor identification. Defaults to 1e-6.
1247
1447
  """
1248
1448
  tol = kwargs.get("tol", 1e-6)
1249
- logger.info(f"Building neighbors for CustomizeLattice up to k={max_k}...")
1449
+ # Reviewer suggestion: prefer differentiable method by default
1450
+ use_kdtree = kwargs.get("use_kdtree", False)
1451
+
1250
1452
  if self.num_sites < 2:
1251
1453
  return
1252
1454
 
1253
- all_coords = np.array(self._coordinates)
1455
+ # Choose algorithm based on user preference
1456
+ if use_kdtree:
1457
+ logger.info(
1458
+ f"Using KDTree method for {self.num_sites} sites up to k={max_k}"
1459
+ )
1460
+ self._build_neighbors_kdtree(max_k, tol)
1461
+ else:
1462
+ logger.info(
1463
+ f"Using differentiable distance matrix method for {self.num_sites} sites up to k={max_k}"
1464
+ )
1254
1465
 
1255
- # 1. Use pdist for memory-efficient calculation of pairwise distances
1256
- # to robustly identify the distance shells.
1257
- all_distances_sq = pdist(all_coords, metric="sqeuclidean")
1258
- dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol)
1466
+ # Use the existing distance matrix method
1467
+ self._build_neighbors_by_distance_matrix(max_k, tol)
1259
1468
 
1260
- if not dist_shells_sq:
1261
- logger.info("No distinct neighbor shells found.")
1262
- return
1469
+ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None:
1470
+ """
1471
+ Build neighbors using KDTree for optimal performance.
1263
1472
 
1264
- # 2. Build the KDTree for efficient querying.
1265
- tree = KDTree(all_coords)
1266
- self._neighbor_maps = {k: {} for k in range(1, len(dist_shells_sq) + 1)}
1473
+ This method provides O(N log N) performance for neighbor finding but breaks
1474
+ differentiability due to scipy dependency. Use this method when:
1475
+ - Performance is critical
1476
+ - Differentiability is not required
1477
+ - Large lattices (N > 1000)
1267
1478
 
1268
- # 3. Find neighbors by isolating shells using inclusion-exclusion.
1269
- # `found_indices` will store all neighbors within a given radius.
1270
- found_indices: List[set[int]] = []
1271
- for k_idx, target_d_sq in enumerate(dist_shells_sq):
1272
- radius = np.sqrt(target_d_sq) + tol
1273
- # Query for all points within the new, larger radius.
1274
- current_shell_indices = tree.query_ball_point(
1275
- all_coords, r=radius, return_sorted=True
1479
+ Note: This method uses numpy arrays directly and may not be compatible
1480
+ with all backend types (JAX, TensorFlow, etc.).
1481
+ """
1482
+
1483
+ # For small lattices or cases with potential duplicate coordinates,
1484
+ # fall back to distance matrix method for robustness
1485
+ if self.num_sites < 200:
1486
+ logger.info(
1487
+ "Small lattice detected, falling back to distance matrix method for robustness"
1276
1488
  )
1489
+ self._build_neighbors_by_distance_matrix(max_k, tol)
1490
+ return
1277
1491
 
1278
- # Now, isolate the neighbors for the current shell k
1279
- k = k_idx + 1
1280
- current_k_map: Dict[int, List[int]] = {}
1281
- for i in range(self.num_sites):
1492
+ # Convert coordinates to numpy for KDTree
1493
+ coords_np = backend.numpy(self._coordinates)
1282
1494
 
1283
- if k_idx == 0:
1284
- co_located_indices = tree.query_ball_point(all_coords[i], r=1e-12)
1285
- prev_found = set(co_located_indices)
1286
- else:
1287
- prev_found = found_indices[i]
1495
+ # Build KDTree
1496
+ logger.info("Building KDTree...")
1497
+ tree = KDTree(coords_np)
1498
+ # Find all distances for shell identification - use comprehensive sampling
1499
+ logger.info("Identifying distance shells...")
1500
+ distances_for_shells: List[float] = []
1288
1501
 
1289
- # The new neighbors are those in the current radius shell,
1290
- # excluding those already found in smaller shells.
1291
- new_neighbors = set(current_shell_indices[i]) - prev_found
1502
+ # For robust shell identification, query all pairwise distances for smaller lattices
1503
+ # or use dense sampling for larger ones
1504
+ if self.num_sites <= 100:
1505
+ # For small lattices, compute all pairwise distances for accuracy
1506
+ for i in range(self.num_sites):
1507
+ query_k = min(self.num_sites - 1, max_k * 20)
1508
+ if query_k > 0:
1509
+ dists, _ = tree.query(
1510
+ coords_np[i], k=query_k + 1
1511
+ ) # +1 to exclude self
1512
+ if isinstance(dists, np.ndarray):
1513
+ distances_for_shells.extend(dists[1:]) # Skip distance to self
1514
+ else:
1515
+ distances_for_shells.append(dists) # Single distance
1516
+ else:
1517
+            # For larger lattices, use adaptive sampling but ensure we capture all shells
+            sample_size = min(1000, self.num_sites // 2)  # More conservative sampling
+            for i in range(0, self.num_sites, max(1, self.num_sites // sample_size)):
+                query_k = min(max_k * 20 + 50, self.num_sites - 1)
+                if query_k > 0:
+                    dists, _ = tree.query(
+                        coords_np[i], k=query_k + 1
+                    )  # +1 to exclude self
+                    if isinstance(dists, np.ndarray):
+                        distances_for_shells.extend(dists[1:])  # Skip distance to self
+                    else:
+                        distances_for_shells.append(dists)  # Single distance
 
-                if new_neighbors:
-                    current_k_map[i] = sorted(list(new_neighbors))
+        # Filter out zero distances (duplicate coordinates) before shell identification
+        ZERO_THRESHOLD = 1e-12
+        distances_for_shells = [d for d in distances_for_shells if d > ZERO_THRESHOLD]
 
-            self._neighbor_maps[k] = current_k_map
-            found_indices = [
-                set(l) for l in current_shell_indices
-            ]  # Update for next iteration
-        self._distance_matrix = np.sqrt(squareform(all_distances_sq))
+        if not distances_for_shells:
+            logger.warning("No valid distances found for shell identification")
+            self._neighbor_maps = {}
+            return
 
-        logger.info("Neighbor building complete using KDTree.")
+        # Use the same shell identification logic as distance matrix method
+        distances_for_shells_sq = [d * d for d in distances_for_shells]
+        dist_shells_sq = self._identify_distance_shells(
+            distances_for_shells_sq, max_k, tol
+        )
+        dist_shells = [np.sqrt(d_sq) for d_sq in dist_shells_sq]
 
-    def _compute_distance_matrix(self) -> Coordinates:
-        """Computes the distance matrix from the stored coordinates.
+        logger.info(f"Found {len(dist_shells)} distance shells: {dist_shells[:5]}...")
 
-        This implementation uses scipy.pdist for a memory-efficient
-        calculation of pairwise distances, which is then converted to a
-        full square matrix.
-        """
-        if self.num_sites < 2:
-            return cast(Coordinates, np.empty((self.num_sites, self.num_sites)))
+        # Initialize neighbor maps
+        self._neighbor_maps = {k: {} for k in range(1, len(dist_shells) + 1)}
 
-        all_coords = np.array(self._coordinates)
-        # Use pdist for memory-efficiency, then build the full matrix.
-        all_distances_sq = pdist(all_coords, metric="sqeuclidean")
-        dist_matrix_sq = squareform(all_distances_sq)
-        return cast(Coordinates, np.sqrt(dist_matrix_sq))
+        # Build neighbor lists for each site
+        for i in range(self.num_sites):
+            # Query enough neighbors to capture all shells
+            query_k = min(max_k * 20 + 50, self.num_sites - 1)
+            if query_k > 0:
+                distances, indices = tree.query(
+                    coords_np[i], k=query_k + 1
+                )  # +1 for self
+
+                # Skip the first entry (distance to self)
+                # Handle both single value and array cases
+                if isinstance(distances, np.ndarray) and len(distances) > 1:
+                    distances_slice = distances[1:]
+                    indices_slice = (
+                        indices[1:]
+                        if isinstance(indices, np.ndarray)
+                        else np.array([], dtype=int)
+                    )
+                else:
+                    # Single value or empty case - no neighbors to process
+                    distances_slice = np.array([])
+                    indices_slice = np.array([], dtype=int)
+
+                # Filter out zero distances (duplicate coordinates)
+                valid_pairs = [
+                    (d, idx)
+                    for d, idx in zip(distances_slice, indices_slice)
+                    if d > ZERO_THRESHOLD
+                ]
+
+                # Assign neighbors to shells
+                for shell_idx, shell_dist in enumerate(dist_shells):
+                    k = shell_idx + 1
+                    shell_neighbors = []
+
+                    for dist, neighbor_idx in valid_pairs:
+                        if abs(dist - shell_dist) <= tol:
+                            shell_neighbors.append(int(neighbor_idx))
+                        elif dist > shell_dist + tol:
+                            break  # Distances are sorted, no more matches
+
+                    if shell_neighbors:
+                        self._neighbor_maps[k][i] = sorted(shell_neighbors)
+
+        # Set distance matrix to None - will compute on demand
+        self._distance_matrix = None
 
     def _reset_computations(self) -> None:
         """Resets all cached data that depends on the lattice structure."""
@@ -1343,18 +1621,28 @@ class CustomizeLattice(AbstractLattice):
         )
 
         # Unzip the list of tuples into separate lists of identifiers and coordinates
-        _, identifiers, coordinates = zip(*all_sites_info)
+        _, identifiers, _ = zip(*all_sites_info)
+
+        # Detach-and-copy coordinates while remaining in tensor form to avoid
+        # host roundtrips and device/dtype changes; this keeps CustomizeLattice
+        # decoupled from the original graph but backend-friendly.
+        # Some backends (e.g., NumPy) don't implement stop_gradient; fall back.
+        try:
+            coords_detached = backend.stop_gradient(lattice._coordinates)
+        except NotImplementedError:
+            coords_detached = lattice._coordinates
+        coords_tensor = backend.copy(coords_detached)
 
         return cls(
             dimensionality=lattice.dimensionality,
             identifiers=list(identifiers),
-            coordinates=list(coordinates),
+            coordinates=coords_tensor,
         )
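The detach-and-copy block above keeps the coordinates in tensor form on the original backend while severing any autodiff history. A rough sketch of the same fallback pattern in isolation (assuming the backend methods used in this diff, stop_gradient and copy, are available on the active backend):

# Sketch of the detach-and-copy fallback shown above; coords stands in for
# lattice._coordinates. Assumes the backend methods used in this diff exist.
import tensorcircuit as tc

tc.set_backend("numpy")
backend = tc.backend

coords = backend.convert_to_tensor([[0.0, 0.0], [1.0, 0.0]])
try:
    detached = backend.stop_gradient(coords)  # cut the autodiff graph if supported
except NotImplementedError:
    detached = coords  # e.g. the NumPy backend tracks no graph
coords_copy = backend.copy(detached)  # independent buffer for the new lattice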
 
     def add_sites(
         self,
         identifiers: List[SiteIdentifier],
-        coordinates: List[Union[List[float], Coordinates]],
+        coordinates: Any,
     ) -> None:
         """Adds new sites to the lattice.
 
@@ -1362,21 +1650,29 @@ class CustomizeLattice(AbstractLattice):
         previously computed neighbor information is cleared and must be
         recalculated.
 
-        :param identifiers: A list of unique, hashable identifiers for the new sites.
+        :param identifiers: A list of unique identifiers for the new sites.
         :type identifiers: List[SiteIdentifier]
-        :param coordinates: A list of coordinates for the new sites.
-        :type coordinates: List[Union[List[float], np.ndarray]]
-        :raises ValueError: If input lists have mismatched lengths, or if any new
-            identifier already exists in the lattice.
+        :param coordinates: The coordinates for the new sites. Can be a list of lists,
+            a NumPy array, or a backend-compatible tensor (e.g., jax.numpy.ndarray).
+        :type coordinates: Any
         """
-        if len(identifiers) != len(coordinates):
+        if not identifiers:
+            return
+
+        new_coords_tensor = backend.convert_to_tensor(coordinates)
+
+        if len(identifiers) != backend.shape_tuple(new_coords_tensor)[0]:
             raise ValueError(
                 "Identifiers and coordinates lists must have the same length."
             )
-        if not identifiers:
-            return  # Nothing to add
 
-        # Check for duplicate identifiers before making any changes
+        if backend.shape_tuple(new_coords_tensor)[1] != self.dimensionality:
+            raise ValueError(
+                f"New coordinate tensor has dimension {backend.shape_tuple(new_coords_tensor)[1]}, "
+                f"but expected dimensionality is {self.dimensionality}."
+            )
+
+        # Ensure that the new identifiers are unique and do not already exist.
         existing_ids = set(self._identifiers)
         new_ids = set(identifiers)
         if not new_ids.isdisjoint(existing_ids):
@@ -1384,21 +1680,14 @@ class CustomizeLattice(AbstractLattice):
                 f"Duplicate identifiers found: {new_ids.intersection(existing_ids)}"
             )
 
-        for i, coord in enumerate(coordinates):
-            coord_arr = np.asarray(coord)
-            if coord_arr.shape != (self.dimensionality,):
-                raise ValueError(
-                    f"New coordinate at index {i} has shape {coord_arr.shape}, "
-                    f"expected ({self.dimensionality},)"
-                )
-            self._coordinates.append(coord_arr)
-            self._identifiers.append(identifiers[i])
+        self._coordinates = backend.concat(
+            [self._coordinates, new_coords_tensor], axis=0
+        )
+        self._identifiers.extend(identifiers)
 
-        # Rebuild index mappings from scratch
         self._indices = list(range(len(self._identifiers)))
         self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)}
 
-        # Invalidate any previously computed neighbors or distance matrices
         self._reset_computations()
         logger.info(
             f"{len(identifiers)} sites added. Lattice now has {self.num_sites} sites."
@@ -1413,10 +1702,9 @@ class CustomizeLattice(AbstractLattice):
 
         :param identifiers: A list of identifiers for the sites to be removed.
         :type identifiers: List[SiteIdentifier]
-        :raises ValueError: If any of the specified identifiers do not exist.
         """
         if not identifiers:
-            return  # Nothing to remove
+            return
 
         ids_to_remove = set(identifiers)
         current_ids = set(self._identifiers)
@@ -1425,24 +1713,77 @@ class CustomizeLattice(AbstractLattice):
                 f"Non-existent identifiers provided for removal: {ids_to_remove - current_ids}"
             )
 
-        # Create new lists containing only the sites to keep
-        new_identifiers: List[SiteIdentifier] = []
-        new_coordinates: List[Coordinates] = []
-        for ident, coord in zip(self._identifiers, self._coordinates):
-            if ident not in ids_to_remove:
-                new_identifiers.append(ident)
-                new_coordinates.append(coord)
+        # Find the indices of the sites that we want to keep.
+        indices_to_keep = [
+            idx
+            for idx, ident in enumerate(self._identifiers)
+            if ident not in ids_to_remove
+        ]
+
+        new_identifiers = [self._identifiers[i] for i in indices_to_keep]
+
+        self._coordinates = backend.gather1d(
+            self._coordinates,
+            backend.cast(backend.convert_to_tensor(indices_to_keep), "int32"),
+        )
 
-        # Replace old data with the new, filtered data
         self._identifiers = new_identifiers
-        self._coordinates = new_coordinates
 
-        # Rebuild index mappings
         self._indices = list(range(len(self._identifiers)))
         self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)}
 
-        # Invalidate caches
         self._reset_computations()
         logger.info(
             f"{len(ids_to_remove)} sites removed. Lattice now has {self.num_sites} sites."
         )
+
+
+def get_compatible_layers(bonds: List[Tuple[int, int]]) -> List[List[Tuple[int, int]]]:
+    """
+    Partitions a list of pairs (bonds) into compatible layers for parallel
+    gate application using a greedy edge-coloring algorithm.
+
+    This function takes a list of pairs, representing connections like
+    nearest-neighbor (NN) or next-nearest-neighbor (NNN) bonds, and
+    partitions them into a small number of sets ("layers") such that no two
+    pairs in a set share an index. This is a general utility for scheduling
+    non-overlapping operations.
+
+    :Example:
+
+    >>> from tensorcircuit.templates.lattice import SquareLattice
+    >>> sq_lattice = SquareLattice(size=(2, 2), pbc=False)
+    >>> nn_bonds = sq_lattice.get_neighbor_pairs(k=1, unique=True)
+
+    >>> gate_layers = get_compatible_layers(nn_bonds)
+    >>> print(gate_layers)
+    [[(0, 1), (2, 3)], [(0, 2), (1, 3)]]
+
+    :param bonds: A list of tuples, where each tuple represents a bond (i, j)
+        of site indices to be scheduled.
+    :type bonds: List[Tuple[int, int]]
+    :return: A list of layers. Each layer is a list of tuples, where each
+        tuple represents a bond. All bonds within a layer are non-overlapping.
+    :rtype: List[List[Tuple[int, int]]]
+    """
+    uncolored_edges: Set[Tuple[int, int]] = {(min(bond), max(bond)) for bond in bonds}
+
+    layers: List[List[Tuple[int, int]]] = []
+
+    while uncolored_edges:
+        current_layer: List[Tuple[int, int]] = []
+        qubits_in_this_layer: Set[int] = set()
+
+        edges_to_process = sorted(list(uncolored_edges))
+
+        for edge in edges_to_process:
+            i, j = edge
+            if i not in qubits_in_this_layer and j not in qubits_in_this_layer:
+                current_layer.append(edge)
+                qubits_in_this_layer.add(i)
+                qubits_in_this_layer.add(j)
+
+        uncolored_edges -= set(current_layer)
+        layers.append(sorted(current_layer))
+
+    return layers
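Building on the docstring example, the returned layers can drive a brickwork-style circuit in which all gates of one layer act on disjoint qubits. A hedged usage sketch (gate choice is arbitrary; it relies only on APIs referenced in this file plus the standard tc.Circuit interface):

# Sketch: apply one entangling gate per bond, layer by layer; gates within a
# layer touch disjoint qubits, so they could be executed in parallel.
import tensorcircuit as tc
from tensorcircuit.templates.lattice import SquareLattice, get_compatible_layers

lattice = SquareLattice(size=(2, 2), pbc=False)
bonds = lattice.get_neighbor_pairs(k=1, unique=True)
layers = get_compatible_layers(bonds)

c = tc.Circuit(lattice.num_sites)
for layer in layers:
    for i, j in layer:
        c.cnot(int(i), int(j))
print(len(layers))  # 2 layers for the 2x2 open-boundary square lattice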