tensorcircuit-nightly 1.3.0.dev20250815__py3-none-any.whl → 1.3.0.dev20250817__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -18,11 +18,12 @@ from typing import (
18
18
  Set,
19
19
  )
20
20
 
21
- logger = logging.getLogger(__name__)
21
+ import itertools
22
+ import math
22
23
  import numpy as np
23
-
24
24
  from scipy.spatial import KDTree
25
- from scipy.spatial.distance import pdist, squareform
25
+
26
+ from .. import backend
26
27
 
27
28
 
28
29
  # This block resolves a name resolution issue for the static type checker (mypy).
@@ -41,9 +42,13 @@ if TYPE_CHECKING:
41
42
  import matplotlib.axes
42
43
  from mpl_toolkits.mplot3d import Axes3D
43
44
 
45
+ logger = logging.getLogger(__name__)
46
+
47
+ Tensor = Any
44
48
  SiteIndex = int
45
49
  SiteIdentifier = Hashable
46
- Coordinates = np.ndarray[Any, Any]
50
+ Coordinates = Tensor
51
+
47
52
  NeighborMap = Dict[SiteIndex, List[SiteIndex]]
48
53
 
49
54
 
@@ -64,13 +69,27 @@ class AbstractLattice(abc.ABC):
64
69
  """Initializes the base lattice class."""
65
70
  self._dimensionality = dimensionality
66
71
 
67
- # --- Internal Data Structures (to be populated by subclasses) ---
68
- self._indices: List[SiteIndex] = []
69
- self._identifiers: List[SiteIdentifier] = []
70
- self._coordinates: List[Coordinates] = []
71
- self._ident_to_idx: Dict[SiteIdentifier, SiteIndex] = {}
72
- self._neighbor_maps: Dict[int, NeighborMap] = {}
73
- self._distance_matrix: Optional[Coordinates] = None
72
+ # Core data structures for storing site information.
73
+ self._indices: List[SiteIndex] = [] # List of integer indices [0, 1, ..., N-1]
74
+ self._identifiers: List[SiteIdentifier] = (
75
+ []
76
+ ) # List of unique, hashable site identifiers
77
+ # Always initialize to an empty coordinate tensor with correct dimensionality
78
+ # so that type checkers know this is indexable and not Optional.
79
+ self._coordinates: Coordinates = backend.zeros((0, dimensionality))
80
+
81
+ # Mappings for efficient lookups.
82
+ self._ident_to_idx: Dict[SiteIdentifier, SiteIndex] = (
83
+ {}
84
+ ) # Maps identifiers to indices
85
+
86
+ # Cached properties, computed on demand.
87
+ self._neighbor_maps: Dict[int, NeighborMap] = (
88
+ {}
89
+ ) # Caches neighbor info for different k
90
+ self._distance_matrix: Optional[Coordinates] = (
91
+ None # Caches the full N x N distance matrix
92
+ )
74
93
 
75
94
  @property
76
95
  def num_sites(self) -> int:
@@ -95,7 +114,6 @@ class AbstractLattice(abc.ABC):
95
114
  subsequent calls. This computation can be expensive for large lattices.
96
115
  """
97
116
  if self._distance_matrix is None:
98
- logger.info("Distance matrix not cached. Computing now...")
99
117
  self._distance_matrix = self._compute_distance_matrix()
100
118
  return self._distance_matrix
101
119
 
@@ -116,7 +134,8 @@ class AbstractLattice(abc.ABC):
116
134
  :rtype: Coordinates
117
135
  """
118
136
  self._validate_index(index)
119
- return self._coordinates[index]
137
+ coords = self._coordinates[index]
138
+ return coords
120
139
 
121
140
  def get_identifier(self, index: SiteIndex) -> SiteIdentifier:
122
141
  """Gets the abstract identifier of a site by its integer index.
@@ -140,7 +159,8 @@ class AbstractLattice(abc.ABC):
140
159
  :rtype: SiteIndex
141
160
  """
142
161
  try:
143
- return self._ident_to_idx[identifier]
162
+ index = self._ident_to_idx[identifier]
163
+ return index
144
164
  except KeyError as e:
145
165
  raise ValueError(
146
166
  f"Identifier {identifier} not found in the lattice."
@@ -170,7 +190,7 @@ class AbstractLattice(abc.ABC):
170
190
  idx = index_or_identifier
171
191
  self._validate_index(idx)
172
192
  return idx, self._identifiers[idx], self._coordinates[idx]
173
- else: # Identifier
193
+ else:
174
194
  ident = index_or_identifier
175
195
  idx = self.get_index(ident)
176
196
  return idx, ident, self._coordinates[idx]
@@ -237,7 +257,6 @@ class AbstractLattice(abc.ABC):
237
257
  )
238
258
  self._build_neighbors(max_k=k)
239
259
 
240
- # After attempting to build, check again. If still not found, return empty.
241
260
  if k not in self._neighbor_maps:
242
261
  return []
243
262
 
@@ -251,8 +270,28 @@ class AbstractLattice(abc.ABC):
251
270
  pairs.append((i, j))
252
271
  return sorted(pairs)
253
272
 
254
- # Sorting provides a deterministic output order
255
- # --- Abstract Methods for Subclass Implementation ---
273
+ def get_all_pairs(self) -> List[Tuple[SiteIndex, SiteIndex]]:
274
+ """
275
+ Returns a list of all unique pairs of site indices (i, j) where i < j.
276
+
277
+ This method provides all-to-all connectivity, useful for Hamiltonians
278
+ where every site interacts with every other site.
279
+
280
+ Note on Differentiability:
281
+ This method provides a static list of index pairs and is not differentiable
282
+ itself. However, it is designed to be used in combination with the fully
283
+ differentiable ``distance_matrix`` property. By using the pairs from this
284
+ method to index into the ``distance_matrix``, one can construct differentiable
285
+ objective functions based on all-pair interactions, effectively bypassing the
286
+ non-differentiable ``get_neighbor_pairs`` method for geometry optimization tasks.
287
+
288
+ :return: A list of tuples, where each tuple is a unique pair of site indices.
289
+ :rtype: List[Tuple[SiteIndex, SiteIndex]]
290
+ """
291
+ if self.num_sites < 2:
292
+ return []
293
+ # Use itertools.combinations to efficiently generate all unique pairs (i, j) with i < j.
294
+ return sorted(list(itertools.combinations(range(self.num_sites), 2)))
256
295
 
257
296
  @abc.abstractmethod
258
297
  def _build_lattice(self, *args: Any, **kwargs: Any) -> None:
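
The note on differentiability above suggests combining the static pair list from get_all_pairs with the differentiable distance_matrix property. A minimal NumPy sketch of that pattern (illustrative only; the 1/r interaction and the helper name are placeholders, not part of this module):

import numpy as np

def all_pair_interaction(dist_matrix, pairs):
    # dist_matrix: N x N matrix as returned by `lattice.distance_matrix`
    # (differentiable when built with an autodiff backend);
    # pairs: output of `lattice.get_all_pairs()`, a list of (i, j) with i < j.
    rows = np.array([i for i, _ in pairs])
    cols = np.array([j for _, j in pairs])
    # Example objective: sum of inverse distances over all unique pairs.
    return np.sum(1.0 / dist_matrix[rows, cols])
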
@@ -281,14 +320,24 @@ class AbstractLattice(abc.ABC):
281
320
  """
282
321
  pass
283
322
 
284
- @abc.abstractmethod
285
323
  def _compute_distance_matrix(self) -> Coordinates:
286
324
  """
287
- Abstract method for subclasses to implement the actual matrix calculation.
288
- This method is called by the `distance_matrix` property when the matrix
289
- needs to be computed for the first time.
325
+ Default generic distance matrix computation (no periodic images).
326
+
327
+ Subclasses can override this when a specialized rule is required
328
+ (e.g., applying Minimum Image Convention for PBC in TILattice).
290
329
  """
291
- pass
330
+ # Handle empty lattices and trivial 1-site lattices
331
+ if self.num_sites == 0:
332
+ return backend.zeros((0, 0))
333
+
334
+ # Vectorized pairwise Euclidean distances
335
+ all_coords = self._coordinates
336
+ displacements = backend.expand_dims(all_coords, 1) - backend.expand_dims(
337
+ all_coords, 0
338
+ )
339
+ dist_matrix_sq = backend.sum(displacements**2, axis=-1)
340
+ return backend.sqrt(dist_matrix_sq)
292
341
 
293
342
  def show(
294
343
  self,
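
The default _compute_distance_matrix above relies on a standard broadcasting pattern: expanding the (N, D) coordinate array to (N, 1, D) and (1, N, D) and subtracting gives all pairwise displacements at once. The same pattern in plain NumPy, as a sketch:

import numpy as np

coords = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])   # (N, D) site coordinates
disp = coords[:, None, :] - coords[None, :, :]             # (N, N, D) pairwise displacements
dist = np.sqrt(np.sum(disp ** 2, axis=-1))                 # (N, N) Euclidean distance matrix
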
@@ -328,13 +377,14 @@ class AbstractLattice(abc.ABC):
328
377
  try:
329
378
  import matplotlib.pyplot as plt
330
379
  except ImportError:
331
- logger.error(
380
+ logger.warning(
332
381
  "Matplotlib is required for visualization. "
333
382
  "Please install it using 'pip install matplotlib'."
334
383
  )
335
384
  return
336
385
 
337
- # creat "fig_created_internally" as flag
386
+ # Flag to track if the Matplotlib figure was created by this method.
387
+ # This prevents calling plt.show() if the user provided their own Axes.
338
388
  fig_created_internally = False
339
389
 
340
390
  if self.num_sites == 0:
@@ -347,7 +397,7 @@ class AbstractLattice(abc.ABC):
347
397
  return
348
398
 
349
399
  if ax is None:
350
- # when ax is none, make fig_created_internally true
400
+ # If no Axes object is provided, create a new figure and axes.
351
401
  fig_created_internally = True
352
402
  if self.dimensionality == 3:
353
403
  fig = plt.figure(figsize=(8, 8))
@@ -358,6 +408,7 @@ class AbstractLattice(abc.ABC):
358
408
  fig = ax.figure # type: ignore
359
409
 
360
410
  coords = np.array(self._coordinates)
411
+ # Prepare arguments for the scatter plot, allowing user overrides.
361
412
  scatter_args = {"s": 100, "zorder": 2}
362
413
  scatter_args.update(kwargs)
363
414
  if self.dimensionality == 1:
@@ -371,12 +422,11 @@ class AbstractLattice(abc.ABC):
371
422
  if show_indices or show_identifiers:
372
423
  for i in range(self.num_sites):
373
424
  label = str(self._identifiers[i]) if show_identifiers else str(i)
425
+ # Calculate a small offset for placing text labels to avoid overlap with sites.
374
426
  offset = (
375
427
  0.02 * np.max(np.ptp(coords, axis=0)) if coords.size > 0 else 0.1
376
428
  )
377
429
 
378
- # Robust Logic: Decide plotting strategy based on known dimensionality.
379
-
380
430
  if self.dimensionality == 1:
381
431
  ax.text(coords[i, 0], offset, label, fontsize=9, ha="center")
382
432
  elif self.dimensionality == 2:
@@ -398,8 +448,6 @@ class AbstractLattice(abc.ABC):
398
448
  zorder=3,
399
449
  )
400
450
 
401
- # Note: No 'else' needed as we already check dimensionality at the start.
402
-
403
451
  if show_bonds_k is not None:
404
452
  if show_bonds_k not in self._neighbor_maps:
405
453
  logger.warning(
@@ -433,7 +481,7 @@ class AbstractLattice(abc.ABC):
433
481
  if self.dimensionality == 1: # type: ignore
434
482
 
435
483
  ax.plot([p1[0], p2[0]], [0, 0], **plot_bond_kwargs) # type: ignore
436
- else: # dimensionality == 2
484
+ else:
437
485
  ax.plot([p1[0], p2[0]], [p1[1], p2[1]], **plot_bond_kwargs) # type: ignore
438
486
 
439
487
  except ValueError as e:
@@ -449,7 +497,7 @@ class AbstractLattice(abc.ABC):
449
497
  ax.set_zlabel("z")
450
498
  ax.grid(True)
451
499
 
452
- # 3. whether plt.show()
500
+ # Display the plot only if the figure was created within this function.
453
501
  if fig_created_internally:
454
502
  plt.show()
455
503
 
@@ -475,26 +523,28 @@ class AbstractLattice(abc.ABC):
475
523
  :return: A sorted list of squared distances representing the shells.
476
524
  :rtype: List[float]
477
525
  """
526
+ # A small threshold to filter out zero distances (site to itself).
478
527
  ZERO_THRESHOLD_SQ = 1e-12
479
528
 
480
- all_distances_sq = np.asarray(all_distances_sq)
529
+ all_distances_sq = backend.convert_to_tensor(all_distances_sq)
481
530
  # After conversion to a tensor, the size check below is guaranteed to be safe.
482
- if all_distances_sq.size == 0:
531
+ if backend.sizen(all_distances_sq) == 0:
483
532
  return []
484
533
 
485
- sorted_dist = np.sort(all_distances_sq[all_distances_sq > ZERO_THRESHOLD_SQ])
534
+ # Filter out self-distances and sort the remaining squared distances.
535
+ sorted_dist = backend.sort(
536
+ all_distances_sq[all_distances_sq > ZERO_THRESHOLD_SQ]
537
+ )
486
538
 
487
- if sorted_dist.size == 0:
539
+ if backend.sizen(sorted_dist) == 0:
488
540
  return []
489
541
 
490
- # Identify shells using the user-provided tolerance.
491
542
  dist_shells = [sorted_dist[0]]
492
543
 
493
544
  for d_sq in sorted_dist[1:]:
494
545
  if len(dist_shells) >= max_k:
495
546
  break
496
- # If the current distance is notably larger than the last shell's distance
497
- if d_sq > dist_shells[-1] + tol**2:
547
+ if backend.sqrt(d_sq) - backend.sqrt(dist_shells[-1]) > tol:
498
548
  dist_shells.append(d_sq)
499
549
 
500
550
  return dist_shells
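
The shell detection above compares distances rather than squared distances against tol: a new shell starts only when sqrt(d_sq) exceeds the previous shell's distance by more than tol. A small worked example of the same loop in NumPy (max_k handling omitted):

import numpy as np

sorted_dist_sq = np.array([1.0, 1.0, 2.0, 2.0 + 1e-9, 4.0])  # zero self-distances already removed
tol = 1e-6
shells = [sorted_dist_sq[0]]
for d_sq in sorted_dist_sq[1:]:
    if np.sqrt(d_sq) - np.sqrt(shells[-1]) > tol:
        shells.append(d_sq)
# shells -> [1.0, 2.0, 4.0]; the value 2.0 + 1e-9 is absorbed into the sqrt(2) shell.
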
@@ -503,11 +553,9 @@ class AbstractLattice(abc.ABC):
503
553
  self, max_k: int = 2, tol: float = 1e-6
504
554
  ) -> None:
505
555
  """A generic, distance-based neighbor finding method.
506
-
507
556
  This method calculates the full N x N distance matrix to find neighbor
508
557
  shells. It is computationally expensive for large N (O(N^2)) and is
509
558
  best suited for non-periodic or custom-defined lattices.
510
-
511
559
  :param max_k: The maximum number of neighbor shells to
512
560
  calculate. Defaults to 2.
513
561
  :type max_k: int, optional
@@ -518,26 +566,55 @@ class AbstractLattice(abc.ABC):
518
566
  if self.num_sites < 2:
519
567
  return
520
568
 
521
- all_coords = np.array(self._coordinates)
522
- dist_matrix_sq = np.sum(
523
- (all_coords[:, np.newaxis, :] - all_coords[np.newaxis, :, :]) ** 2, axis=-1
569
+ all_coords = self._coordinates
570
+ # Vectorized computation of the squared distance matrix:
571
+ # (N, 1, D) - (1, N, D) -> (N, N, D) -> (N, N)
572
+ displacements = backend.expand_dims(all_coords, 1) - backend.expand_dims(
573
+ all_coords, 0
524
574
  )
575
+ dist_matrix_sq = backend.sum(displacements**2, axis=-1)
525
576
 
526
- all_distances_sq = dist_matrix_sq.flatten()
577
+ # Flatten the matrix to a list of all squared distances to identify shells.
578
+ all_distances_sq = backend.reshape(dist_matrix_sq, [-1])
527
579
  dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol)
528
580
 
529
- self._neighbor_maps = {k: {} for k in range(1, len(dist_shells_sq) + 1)}
581
+ self._neighbor_maps = self._build_neighbor_map_from_distances(
582
+ dist_matrix_sq, dist_shells_sq, tol
583
+ )
584
+ self._distance_matrix = backend.sqrt(dist_matrix_sq)
585
+
586
+ def _build_neighbor_map_from_distances(
587
+ self,
588
+ dist_matrix_sq: Coordinates,
589
+ dist_shells_sq: List[float],
590
+ tol: float = 1e-6,
591
+ ) -> Dict[int, NeighborMap]:
592
+ """
593
+ Builds a neighbor map from a squared distance matrix and identified shells.
594
+ This is a generic helper function to reduce code duplication.
595
+ """
596
+ neighbor_maps: Dict[int, NeighborMap] = {
597
+ k: {} for k in range(1, len(dist_shells_sq) + 1)
598
+ }
530
599
  for k_idx, target_d_sq in enumerate(dist_shells_sq):
531
600
  k = k_idx + 1
532
601
  current_k_map: Dict[int, List[int]] = {}
533
- for i in range(self.num_sites):
534
- neighbor_indices = np.where(
535
- np.isclose(dist_matrix_sq[i], target_d_sq, rtol=0, atol=tol**2)
536
- )[0]
537
- if len(neighbor_indices) > 0:
538
- current_k_map[i] = sorted(neighbor_indices.tolist())
539
- self._neighbor_maps[k] = current_k_map
540
- self._distance_matrix = np.sqrt(dist_matrix_sq)
602
+ # For each shell, find all pairs of sites (i, j) with that distance.
603
+ is_close_matrix = backend.abs(dist_matrix_sq - target_d_sq) < tol
604
+ rows, cols = backend.where(is_close_matrix)
605
+
606
+ for i, j in zip(backend.numpy(rows), backend.numpy(cols)):
607
+ if i == j:
608
+ continue
609
+ if i not in current_k_map:
610
+ current_k_map[i] = []
611
+ current_k_map[i].append(j)
612
+
613
+ for i in current_k_map:
614
+ current_k_map[i].sort()
615
+
616
+ neighbor_maps[k] = current_k_map
617
+ return neighbor_maps
541
618
 
542
619
 
543
620
  class TILattice(AbstractLattice):
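
_build_neighbor_map_from_distances turns the squared distance matrix plus the detected shells into one adjacency map per shell index k. For a 3-site open chain at x = 0, 1, 2, the same logic in NumPy (a sketch, not the module's code) produces the maps shown in the trailing comments:

import numpy as np

d_sq = np.array([[0.0, 1.0, 4.0],
                 [1.0, 0.0, 1.0],
                 [4.0, 1.0, 0.0]])
shells_sq, tol = [1.0, 4.0], 1e-6
neighbor_maps = {}
for k, target in enumerate(shells_sq, start=1):
    rows, cols = np.where(np.abs(d_sq - target) < tol)
    k_map = {}
    for i, j in zip(rows.tolist(), cols.tolist()):
        if i != j:
            k_map.setdefault(i, []).append(j)
    neighbor_maps[k] = {i: sorted(js) for i, js in k_map.items()}
# neighbor_maps[1] == {0: [1], 1: [0, 2], 2: [1]}   (nearest neighbors)
# neighbor_maps[2] == {0: [2], 2: [0]}              (next-nearest neighbors)
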
@@ -588,150 +665,197 @@ class TILattice(AbstractLattice):
588
665
  ):
589
666
  """Initializes the Translationally Invariant Lattice."""
590
667
  super().__init__(dimensionality)
591
- assert lattice_vectors.shape == (
592
- dimensionality,
593
- dimensionality,
594
- ), "Lattice vectors shape mismatch"
595
- assert (
596
- basis_coords.shape[1] == dimensionality
597
- ), "Basis coordinates dimension mismatch"
598
- assert len(size) == dimensionality, "Size tuple length mismatch"
599
-
600
- self.lattice_vectors = lattice_vectors
601
- self.basis_coords = basis_coords
602
- self.num_basis = basis_coords.shape[0]
668
+
669
+ self.lattice_vectors = backend.convert_to_tensor(lattice_vectors)
670
+ self.basis_coords = backend.convert_to_tensor(basis_coords)
671
+
672
+ if self.lattice_vectors.shape != (dimensionality, dimensionality):
673
+ raise ValueError(
674
+ f"Lattice vectors shape {self.lattice_vectors.shape} does not match "
675
+ f"expected ({dimensionality}, {dimensionality})"
676
+ )
677
+ if self.basis_coords.shape[1] != dimensionality:
678
+ raise ValueError(
679
+ f"Basis coordinates dimension {self.basis_coords.shape[1]} does not "
680
+ f"match lattice dimensionality {dimensionality}"
681
+ )
682
+ if len(size) != dimensionality:
683
+ raise ValueError(
684
+ f"Size tuple length {len(size)} does not match dimensionality {dimensionality}"
685
+ )
686
+
687
+ self.num_basis = self.basis_coords.shape[0]
603
688
  self.size = size
604
689
  if isinstance(pbc, bool):
605
690
  self.pbc = tuple([pbc] * dimensionality)
606
691
  else:
607
- assert len(pbc) == dimensionality, "PBC tuple length mismatch"
692
+ if len(pbc) != dimensionality:
693
+ raise ValueError(
694
+ f"PBC tuple length {len(pbc)} does not match dimensionality {dimensionality}"
695
+ )
608
696
  self.pbc = tuple(pbc)
609
697
 
610
- # Build the lattice sites and their neighbor relationships
611
698
  self._build_lattice()
612
699
  if precompute_neighbors is not None and precompute_neighbors > 0:
613
700
  logger.info(f"Pre-computing neighbors up to k={precompute_neighbors}...")
614
701
  self._build_neighbors(max_k=precompute_neighbors)
615
702
 
616
703
  def _build_lattice(self) -> None:
617
- """Generates all site information for the periodic lattice.
618
-
619
- This method iterates through each unit cell defined by `self.size`,
620
- and for each unit cell, it iterates through all basis sites. It then
621
- calculates the real-space coordinates and creates a unique identifier
622
- for each site, populating the internal lattice data structures.
623
704
  """
624
- current_index = 0
705
+ Generates all site information for the periodic lattice in a vectorized manner.
706
+ """
707
+ ranges = [backend.arange(s) for s in self.size]
625
708
 
626
- # Iterate over all unit cell coordinates elegantly using np.ndindex
627
- for cell_coord in np.ndindex(self.size):
628
- cell_coord_arr = np.array(cell_coord)
629
- # R = n1*a1 + n2*a2 + ...
630
- cell_vector = np.dot(cell_coord_arr, self.lattice_vectors)
709
+ # Generate a grid of all integer unit cell coordinates.
710
+ grid = backend.meshgrid(*ranges, indexing="ij")
711
+ all_cell_coords = backend.reshape(
712
+ backend.stack(grid, axis=-1), (-1, self.dimensionality)
713
+ )
631
714
 
632
- # Iterate over the basis sites within the unit cell
633
- for basis_index in range(self.num_basis):
634
- basis_vec = self.basis_coords[basis_index]
715
+ all_cell_coords = backend.cast(all_cell_coords, self.lattice_vectors.dtype)
716
+
717
+ cell_vectors = backend.tensordot(
718
+ all_cell_coords, self.lattice_vectors, axes=[[1], [0]]
719
+ )
720
+
721
+ cell_vectors = backend.cast(cell_vectors, self.basis_coords.dtype)
635
722
 
636
- # Calculate the real-space coordinate
637
- coord = cell_vector + basis_vec
638
- # Create a structured identifier
639
- identifier = cell_coord + (basis_index,)
723
+ # Combine cell vectors with basis coordinates to get all site positions
724
+ # via broadcasting: (num_cells, 1, D) + (1, num_basis, D) -> (num_cells, num_basis, D)
725
+ all_coords = backend.expand_dims(cell_vectors, 1) + backend.expand_dims(
726
+ self.basis_coords, 0
727
+ )
728
+
729
+ self._coordinates = backend.reshape(all_coords, (-1, self.dimensionality))
730
+
731
+ self._indices = []
732
+ self._identifiers = []
733
+ self._ident_to_idx = {}
734
+ current_index = 0
640
735
 
641
- # Store site information
736
+ # Generate integer indices and tuple-based identifiers for all sites.
737
+ # e.g., identifier = (uc_x, uc_y, basis_idx)
738
+ size_ranges = [range(s) for s in self.size]
739
+ for cell_coord_tuple in itertools.product(*size_ranges):
740
+ for basis_index in range(self.num_basis):
741
+ identifier = cell_coord_tuple + (basis_index,)
642
742
  self._indices.append(current_index)
643
743
  self._identifiers.append(identifier)
644
- self._coordinates.append(coord)
645
744
  self._ident_to_idx[identifier] = current_index
646
745
  current_index += 1
647
746
 
648
- def _get_distance_matrix_with_mic(self) -> Coordinates:
747
+ def _get_distance_matrix_with_mic_vectorized(self) -> Coordinates:
649
748
  """
650
- Computes the full N x N distance matrix, correctly applying the
651
- Minimum Image Convention (MIC) for all periodic dimensions.
749
+ Computes the full N x N distance matrix using a fully vectorized approach
750
+ that correctly applies the Minimum Image Convention (MIC) for periodic
751
+ boundary conditions.
752
+
753
+ This method uses full vectorization for optimal performance and compatibility
754
+ with JIT compilation frameworks like JAX. The implementation processes all
755
+ site pairs simultaneously rather than iterating row-by-row, which provides:
756
+
757
+ - Better performance through vectorized operations
758
+ - Full compatibility with automatic differentiation
759
+ - JIT compilation support (e.g., JAX, TensorFlow)
760
+ - Consistent tensor operations throughout
761
+
762
+ The trade-off is higher memory usage compared to iterative approaches,
763
+ as it computes all pairwise distances simultaneously. For very large
764
+ lattices (N > 10^4 sites), memory usage scales as O(N^2).
765
+
766
+ :return: Distance matrix with shape (N, N) where entry (i,j) is the
767
+ minimum distance between sites i and j under periodic boundary conditions.
768
+ :rtype: Coordinates
652
769
  """
653
- all_coords = np.array(self._coordinates)
654
- size_arr = np.array(self.size)
655
- system_vectors = self.lattice_vectors * size_arr[:, np.newaxis]
656
-
657
- # Generate translation vectors ONLY for periodic dimensions
658
- pbc_dims = [d for d in range(self.dimensionality) if self.pbc[d]]
659
- translations = [np.zeros(self.dimensionality)]
660
- if pbc_dims:
661
- num_pbc_dims = len(pbc_dims)
662
- pbc_system_vectors = system_vectors[pbc_dims, :]
663
-
664
- # Create all 3^k - 1 non-zero shifts for k periodic dimensions
665
- shift_options = [np.array([-1, 0, 1])] * num_pbc_dims
666
- shifts_grid = np.meshgrid(*shift_options, indexing="ij")
667
- all_shifts = np.stack(shifts_grid, axis=-1).reshape(-1, num_pbc_dims)
668
- all_shifts = all_shifts[np.any(all_shifts != 0, axis=1)]
669
-
670
- pbc_translations = all_shifts @ pbc_system_vectors
671
- translations.extend(pbc_translations)
672
-
673
- translations_arr = np.array(translations, dtype=float)
674
-
675
- # Calculate the distance matrix applying MIC
676
- dist_matrix_sq = np.full((self.num_sites, self.num_sites), np.inf, dtype=float)
677
- for i in range(self.num_sites):
678
- displacements = all_coords - all_coords[i]
679
- image_displacements = (
680
- displacements[:, np.newaxis, :] - translations_arr[np.newaxis, :, :]
681
- )
682
- image_d_sq = np.sum(image_displacements**2, axis=2)
683
- dist_matrix_sq[i, :] = np.min(image_d_sq, axis=1)
770
+ # Ensure dtype consistency across backends (especially torch) by explicitly
771
+ # casting size and lattice_vectors to the same floating dtype used internally.
772
+ # Strategy: prefer existing lattice_vectors dtype; if it's an unusual dtype,
773
+ # fall back to float32 to avoid mixed-precision issues in vectorized ops.
774
+ # Note: `self.lattice_vectors` is always created via `backend.convert_to_tensor`
775
+ # in __init__, so `backend.dtype(...)` is reliable here and doesn't need try/except.
776
+ target_dt = str(backend.dtype(self.lattice_vectors)) # type: ignore
777
+ if target_dt not in ("float32", "float64"):
778
+ # fallback for unusual dtypes
779
+ target_dt = "float32"
780
+
781
+ size_arr = backend.cast(backend.convert_to_tensor(self.size), target_dt)
782
+ lattice_vecs = backend.cast(
783
+ backend.convert_to_tensor(self.lattice_vectors), target_dt
784
+ )
785
+ system_vectors = lattice_vecs * backend.expand_dims(size_arr, axis=1)
684
786
 
685
- return cast(Coordinates, np.sqrt(dist_matrix_sq))
787
+ pbc_mask = backend.convert_to_tensor(self.pbc)
788
+
789
+ # Generate all 3^d possible image shifts (-1, 0, 1) for all dimensions
790
+ shift_options = [
791
+ backend.convert_to_tensor([-1.0, 0.0, 1.0])
792
+ ] * self.dimensionality
793
+ shifts_grid = backend.meshgrid(*shift_options, indexing="ij")
794
+ all_shifts = backend.reshape(
795
+ backend.stack(shifts_grid, axis=-1), (-1, self.dimensionality)
796
+ )
797
+
798
+ # Only apply shifts to periodic dimensions
799
+ masked_shifts = all_shifts * backend.cast(pbc_mask, all_shifts.dtype)
800
+
801
+ # Calculate all translation vectors due to PBC
802
+ translations_arr = backend.tensordot(
803
+ masked_shifts, system_vectors, axes=[[1], [0]]
804
+ )
805
+
806
+ # Vectorized computation of all displacements between any two sites
807
+ # Shape: (N, 1, D) - (1, N, D) -> (N, N, D)
808
+ displacements = backend.expand_dims(self._coordinates, 1) - backend.expand_dims(
809
+ self._coordinates, 0
810
+ )
811
+
812
+ # Consider all periodic images for each displacement
813
+ # Shape: (N, N, 1, D) - (1, 1, num_translations, D) -> (N, N, num_translations, D)
814
+ image_displacements = backend.expand_dims(
815
+ displacements, 2
816
+ ) - backend.expand_dims(backend.expand_dims(translations_arr, 0), 0)
817
+
818
+ # Sum of squares for distances
819
+ image_d_sq = backend.sum(image_displacements**2, axis=3)
820
+
821
+ # Find the minimum distance among all images (Minimum Image Convention)
822
+ min_dist_sq = backend.min(image_d_sq, axis=2)
823
+
824
+ safe_dist_matrix_sq = backend.where(min_dist_sq > 0, min_dist_sq, 0.0)
825
+ return backend.sqrt(safe_dist_matrix_sq)
686
826
 
687
827
  def _build_neighbors(self, max_k: int = 2, **kwargs: Any) -> None:
688
828
  """Calculates neighbor relationships for the periodic lattice.
689
829
 
690
- This method calculates neighbor relationships by computing the full N x N
691
- distance matrix. It robustly handles all boundary conditions (fully
692
- periodic, open, or mixed) by applying the Minimum Image Convention
693
- (MIC) only to the periodic dimensions.
694
-
695
- From this distance matrix, it identifies unique neighbor shells up to
696
- the specified `max_k` and populates the neighbor maps. The computed
697
- distance matrix is then cached for future use.
830
+ This method computes neighbor information by first calculating the full
831
+ distance matrix using the Minimum Image Convention (MIC) to correctly
832
+ handle periodic boundary conditions. It then identifies unique distance
833
+ shells (e.g., nearest, next-nearest) and populates the neighbor maps
834
+ accordingly. This approach is general and works for any periodic lattice
835
+ geometry defined by the TILattice class.
698
836
 
699
- :param max_k: The maximum number of neighbor shells to
700
- calculate. Defaults to 2.
837
+ :param max_k: The maximum order of neighbors to compute (e.g., k=1 for
838
+ nearest neighbors, k=2 for next-nearest, etc.). Defaults to 2.
701
839
  :type max_k: int, optional
702
- :param tol: The numerical tolerance for distance
703
- comparisons. Defaults to 1e-6.
704
- :type tol: float, optional
840
+ :param kwargs: Additional keyword arguments. May include:
841
+ - ``tol`` (float): The numerical tolerance used to determine if two
842
+ distances are equal when identifying shells. Defaults to 1e-6.
705
843
  """
706
844
  tol = kwargs.get("tol", 1e-6)
707
- dist_matrix = self._get_distance_matrix_with_mic()
845
+ dist_matrix = self._get_distance_matrix_with_mic_vectorized()
708
846
  dist_matrix_sq = dist_matrix**2
709
847
  self._distance_matrix = dist_matrix
710
- all_distances_sq = dist_matrix_sq.flatten()
848
+ all_distances_sq = backend.reshape(dist_matrix_sq, [-1])
711
849
  dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol)
712
-
713
- self._neighbor_maps = {k: {} for k in range(1, len(dist_shells_sq) + 1)}
714
- for k_idx, target_d_sq in enumerate(dist_shells_sq):
715
- k = k_idx + 1
716
- current_k_map: Dict[int, List[int]] = {}
717
- match_indices = np.where(
718
- np.isclose(dist_matrix_sq, target_d_sq, rtol=0, atol=tol**2)
719
- )
720
- for i, j in zip(*match_indices):
721
- if i == j:
722
- continue
723
- if i not in current_k_map:
724
- current_k_map[i] = []
725
- current_k_map[i].append(j)
726
-
727
- for i in current_k_map:
728
- current_k_map[i].sort()
729
-
730
- self._neighbor_maps[k] = current_k_map
850
+ self._neighbor_maps = self._build_neighbor_map_from_distances(
851
+ dist_matrix_sq, dist_shells_sq, tol
852
+ )
731
853
 
732
854
  def _compute_distance_matrix(self) -> Coordinates:
733
855
  """Computes the distance matrix using the Minimum Image Convention."""
734
- return self._get_distance_matrix_with_mic()
856
+ if self.num_sites == 0:
857
+ return backend.zeros((0, 0))
858
+ return self._get_distance_matrix_with_mic_vectorized()
735
859
 
736
860
 
737
861
  class SquareLattice(TILattice):
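
The vectorized MIC routine above picks, for every site pair, the closest of all periodic images. A one-dimensional NumPy sketch on a 4-site periodic chain (spacing 1, system length 4) shows the wrap-around effect:

import numpy as np

coords = np.arange(4.0)[:, None]                      # (N, 1) positions 0, 1, 2, 3
L = 4.0                                               # periodic system length
disp = coords[:, None, :] - coords[None, :, :]        # (N, N, 1) raw displacements
images = np.array([-L, 0.0, L])[None, None, :, None]  # candidate periodic image shifts
d_sq = np.sum((disp[:, :, None, :] - images) ** 2, axis=-1)   # (N, N, 3)
dist = np.sqrt(np.min(d_sq, axis=2))                  # minimum over images
# dist[0, 3] == 1.0 (the pair wraps around the boundary), not 3.0.
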
@@ -759,20 +883,24 @@ class SquareLattice(TILattice):
759
883
  def __init__(
760
884
  self,
761
885
  size: Tuple[int, int],
762
- lattice_constant: float = 1.0,
886
+ lattice_constant: Union[float, Any] = 1.0,
763
887
  pbc: Union[bool, Tuple[bool, bool]] = True,
764
888
  precompute_neighbors: Optional[int] = None,
765
889
  ):
766
890
  """Initializes the SquareLattice."""
767
891
  dimensionality = 2
892
+ # Define orthogonal lattice vectors for a square.
893
+ # Avoid mixing Python floats with backend Tensors (TF would error),
894
+ # so first convert inputs to tensors of a unified dtype, then stack.
895
+ lc = backend.convert_to_tensor(lattice_constant)
896
+ dt = backend.dtype(lc)
897
+ z = backend.cast(backend.convert_to_tensor(0.0), dt)
898
+ row1 = backend.stack([lc, z])
899
+ row2 = backend.stack([z, lc])
900
+ lattice_vectors = backend.stack([row1, row2])
901
+ # A square lattice is a Bravais lattice, so it has a single-site basis.
902
+ basis_coords = backend.stack([backend.stack([z, z])])
768
903
 
769
- # Define lattice vectors for a square lattice
770
- lattice_vectors = np.array([[lattice_constant, 0.0], [0.0, lattice_constant]])
771
-
772
- # A square lattice has a single site in its basis
773
- basis_coords = np.array([[0.0, 0.0]])
774
-
775
- # Call the parent TILattice constructor with these parameters
776
904
  super().__init__(
777
905
  dimensionality=dimensionality,
778
906
  lattice_vectors=lattice_vectors,
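
A usage sketch for the constructor above. Only the class and parameter names come from this diff; the import path and the chosen backend are assumptions:

import tensorcircuit as tc
# Hypothetical import path; the diff shows the module contents, not its location in the package.
from tensorcircuit.templates.lattice import SquareLattice

tc.set_backend("numpy")
# lattice_constant may also be passed as a backend tensor (the annotation is Union[float, Any]).
lat = SquareLattice(size=(2, 2), lattice_constant=1.0, pbc=False)
print(lat.num_sites)        # 4
print(lat.get_all_pairs())  # all 6 unique (i, j) pairs with i < j
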
@@ -808,19 +936,28 @@ class HoneycombLattice(TILattice):
808
936
  def __init__(
809
937
  self,
810
938
  size: Tuple[int, int],
811
- lattice_constant: float = 1.0,
939
+ lattice_constant: Union[float, Any] = 1.0,
812
940
  pbc: Union[bool, Tuple[bool, bool]] = True,
813
941
  precompute_neighbors: Optional[int] = None,
814
942
  ):
815
943
  """Initializes the HoneycombLattice."""
816
944
  dimensionality = 2
817
945
  a = lattice_constant
946
+ a_t = backend.convert_to_tensor(a)
947
+ zero = a_t * 0.0
818
948
 
819
- # Define the primitive lattice vectors for the underlying triangular lattice
820
- lattice_vectors = a * np.array([[1.5, np.sqrt(3) / 2], [1.5, -np.sqrt(3) / 2]])
821
-
822
- # Define the coordinates of the two basis sites (A and B)
823
- basis_coords = a * np.array([[0.0, 0.0], [1.0, 0.0]]) # Site A # Site B
949
+ # Define the two primitive lattice vectors for the underlying triangular Bravais lattice.
950
+ rt3_over_2 = math.sqrt(3.0) / 2.0
951
+ lattice_vectors = backend.stack(
952
+ [
953
+ backend.stack([a_t * 1.5, a_t * rt3_over_2]),
954
+ backend.stack([a_t * 1.5, -a_t * rt3_over_2]),
955
+ ]
956
+ )
957
+ # Define the two basis sites (A and B) within the unit cell.
958
+ basis_coords = backend.stack(
959
+ [backend.stack([zero, zero]), backend.stack([a_t * 1.0, zero])]
960
+ )
824
961
 
825
962
  super().__init__(
826
963
  dimensionality=dimensionality,
@@ -855,19 +992,30 @@ class TriangularLattice(TILattice):
855
992
  def __init__(
856
993
  self,
857
994
  size: Tuple[int, int],
858
- lattice_constant: float = 1.0,
995
+ lattice_constant: Union[float, Any] = 1.0,
859
996
  pbc: Union[bool, Tuple[bool, bool]] = True,
860
997
  precompute_neighbors: Optional[int] = None,
861
998
  ):
862
999
  """Initializes the TriangularLattice."""
863
1000
  dimensionality = 2
864
1001
  a = lattice_constant
1002
+ a_t = backend.convert_to_tensor(a)
1003
+ zero = a_t * 0.0
865
1004
 
866
- # Define the primitive lattice vectors for a triangular lattice
867
- lattice_vectors = a * np.array([[1.0, 0.0], [0.5, np.sqrt(3) / 2]])
868
-
869
- # A triangular lattice is a Bravais lattice, with a single site in its basis
870
- basis_coords = np.array([[0.0, 0.0]])
1005
+ # Define the primitive lattice vectors for a triangular lattice.
1006
+ lattice_vectors = backend.stack(
1007
+ [
1008
+ backend.stack([a_t * 1.0, zero]),
1009
+ backend.stack(
1010
+ [
1011
+ a_t * 0.5,
1012
+ a_t * backend.sqrt(backend.convert_to_tensor(3.0)) / 2.0,
1013
+ ]
1014
+ ),
1015
+ ]
1016
+ )
1017
+ # A triangular lattice is a Bravais lattice with a single-site basis.
1018
+ basis_coords = backend.stack([backend.stack([zero, zero])])
871
1019
 
872
1020
  super().__init__(
873
1021
  dimensionality=dimensionality,
@@ -896,13 +1044,18 @@ class ChainLattice(TILattice):
896
1044
  def __init__(
897
1045
  self,
898
1046
  size: Tuple[int],
899
- lattice_constant: float = 1.0,
1047
+ lattice_constant: Union[float, Any] = 1.0,
900
1048
  pbc: bool = True,
901
1049
  precompute_neighbors: Optional[int] = None,
902
1050
  ):
903
1051
  dimensionality = 1
904
- lattice_vectors = np.array([[lattice_constant]])
905
- basis_coords = np.array([[0.0]])
1052
+ # The lattice vector is just the lattice constant along one dimension.
1053
+ lc = backend.convert_to_tensor(lattice_constant)
1054
+ lattice_vectors = backend.stack([backend.stack([lc])])
1055
+ # A simple chain is a Bravais lattice with a single-site basis.
1056
+ zero = lc * 0.0
1057
+ basis_coords = backend.stack([backend.stack([zero])])
1058
+
906
1059
  super().__init__(
907
1060
  dimensionality=dimensionality,
908
1061
  lattice_vectors=lattice_vectors,
@@ -934,15 +1087,17 @@ class DimerizedChainLattice(TILattice):
934
1087
  def __init__(
935
1088
  self,
936
1089
  size: Tuple[int],
937
- lattice_constant: float = 1.0,
1090
+ lattice_constant: Union[float, Any] = 1.0,
938
1091
  pbc: bool = True,
939
1092
  precompute_neighbors: Optional[int] = None,
940
1093
  ):
941
1094
  dimensionality = 1
942
- # The unit cell vector connects two A sites, spanning length 2*a
943
- lattice_vectors = np.array([[2 * lattice_constant]])
944
- # Basis has site A at origin, site B at distance 'a'
945
- basis_coords = np.array([[0.0], [lattice_constant]])
1095
+ # The unit cell is twice the bond length, as it contains two sites.
1096
+ lc = backend.convert_to_tensor(lattice_constant)
1097
+ lattice_vectors = backend.stack([backend.stack([2 * lc])])
1098
+ # Two basis sites (A and B) separated by the bond length.
1099
+ zero = lc * 0.0
1100
+ basis_coords = backend.stack([backend.stack([zero]), backend.stack([lc])])
946
1101
 
947
1102
  super().__init__(
948
1103
  dimensionality=dimensionality,
@@ -975,14 +1130,22 @@ class RectangularLattice(TILattice):
975
1130
  def __init__(
976
1131
  self,
977
1132
  size: Tuple[int, int],
978
- lattice_constants: Tuple[float, float] = (1.0, 1.0),
1133
+ lattice_constants: Union[Tuple[float, float], Any] = (1.0, 1.0),
979
1134
  pbc: Union[bool, Tuple[bool, bool]] = True,
980
1135
  precompute_neighbors: Optional[int] = None,
981
1136
  ):
982
1137
  dimensionality = 2
983
1138
  ax, ay = lattice_constants
984
- lattice_vectors = np.array([[ax, 0.0], [0.0, ay]])
985
- basis_coords = np.array([[0.0, 0.0]])
1139
+ ax_t = backend.convert_to_tensor(ax)
1140
+ dt = backend.dtype(ax_t)
1141
+ ay_t = backend.cast(backend.convert_to_tensor(ay), dt)
1142
+ z = backend.cast(backend.convert_to_tensor(0.0), dt)
1143
+ # Orthogonal lattice vectors with potentially different lengths.
1144
+ row1 = backend.stack([ax_t, z])
1145
+ row2 = backend.stack([z, ay_t])
1146
+ lattice_vectors = backend.stack([row1, row2])
1147
+ # A rectangular lattice is a Bravais lattice with a single-site basis.
1148
+ basis_coords = backend.stack([backend.stack([z, z])])
986
1149
 
987
1150
  super().__init__(
988
1151
  dimensionality=dimensionality,
@@ -1013,16 +1176,26 @@ class CheckerboardLattice(TILattice):
1013
1176
  def __init__(
1014
1177
  self,
1015
1178
  size: Tuple[int, int],
1016
- lattice_constant: float = 1.0,
1179
+ lattice_constant: Union[float, Any] = 1.0,
1017
1180
  pbc: Union[bool, Tuple[bool, bool]] = True,
1018
1181
  precompute_neighbors: Optional[int] = None,
1019
1182
  ):
1020
1183
  dimensionality = 2
1021
1184
  a = lattice_constant
1022
- # Primitive vectors for a square lattice rotated by 45 degrees.
1023
- lattice_vectors = a * np.array([[1.0, 1.0], [1.0, -1.0]])
1024
- # Two-site basis
1025
- basis_coords = a * np.array([[0.0, 0.0], [1.0, 0.0]])
1185
+ a_t = backend.convert_to_tensor(a)
1186
+ # The unit cell is a square rotated by 45 degrees.
1187
+ lattice_vectors = backend.stack(
1188
+ [
1189
+ backend.stack([a_t * 1.0, a_t * 1.0]),
1190
+ backend.stack([a_t * 1.0, a_t * -1.0]),
1191
+ ]
1192
+ )
1193
+ # Two basis sites (A and B) within the unit cell.
1194
+ zero = a_t * 0.0
1195
+ basis_coords = backend.stack(
1196
+ [backend.stack([zero, zero]), backend.stack([a_t * 1.0, zero])]
1197
+ )
1198
+
1026
1199
  super().__init__(
1027
1200
  dimensionality=dimensionality,
1028
1201
  lattice_vectors=lattice_vectors,
@@ -1052,16 +1225,30 @@ class KagomeLattice(TILattice):
1052
1225
  def __init__(
1053
1226
  self,
1054
1227
  size: Tuple[int, int],
1055
- lattice_constant: float = 1.0,
1228
+ lattice_constant: Union[float, Any] = 1.0,
1056
1229
  pbc: Union[bool, Tuple[bool, bool]] = True,
1057
1230
  precompute_neighbors: Optional[int] = None,
1058
1231
  ):
1059
1232
  dimensionality = 2
1060
1233
  a = lattice_constant
1061
- # Using a rectangular unit cell definition for simplicity
1062
- lattice_vectors = a * np.array([[2.0, 0.0], [1.0, np.sqrt(3)]])
1063
- # Three-site basis
1064
- basis_coords = a * np.array([[0.0, 0.0], [1.0, 0.0], [0.5, np.sqrt(3) / 2.0]])
1234
+ a_t = backend.convert_to_tensor(a)
1235
+ # The Kagome lattice is based on a triangular Bravais lattice.
1236
+ lattice_vectors = backend.stack(
1237
+ [
1238
+ backend.stack([a_t * 2.0, a_t * 0.0]),
1239
+ backend.stack([a_t * 1.0, a_t * backend.sqrt(3.0)]),
1240
+ ]
1241
+ )
1242
+ # It has a three-site basis, forming the corners of the triangles.
1243
+ zero = a_t * 0.0
1244
+ basis_coords = backend.stack(
1245
+ [
1246
+ backend.stack([zero, zero]),
1247
+ backend.stack([a_t * 1.0, zero]),
1248
+ backend.stack([a_t * 0.5, a_t * backend.sqrt(3.0) / 2.0]),
1249
+ ]
1250
+ )
1251
+
1065
1252
  super().__init__(
1066
1253
  dimensionality=dimensionality,
1067
1254
  lattice_vectors=lattice_vectors,
@@ -1092,29 +1279,26 @@ class LiebLattice(TILattice):
1092
1279
  def __init__(
1093
1280
  self,
1094
1281
  size: Tuple[int, int],
1095
- lattice_constant: float = 1.0,
1282
+ lattice_constant: Union[float, Any] = 1.0,
1096
1283
  pbc: Union[bool, Tuple[bool, bool]] = True,
1097
1284
  precompute_neighbors: Optional[int] = None,
1098
1285
  ):
1099
1286
  """Initializes the LiebLattice."""
1100
1287
  dimensionality = 2
1101
- # Use a more descriptive name for clarity. In a Lieb lattice,
1102
- # the lattice_constant is the bond length between nearest neighbors.
1103
1288
  bond_length = lattice_constant
1104
-
1105
- # The unit cell of a Lieb lattice is a square with side length
1106
- # equal to twice the bond length.
1107
- unit_cell_side = 2 * bond_length
1108
- lattice_vectors = np.array([[unit_cell_side, 0.0], [0.0, unit_cell_side]])
1109
-
1110
- # The three-site basis consists of a corner site, a site on the
1111
- # center of the horizontal edge, and a site on the center of the vertical edge.
1112
- # Their coordinates are defined directly in terms of the physical bond length.
1113
- basis_coords = np.array(
1289
+ bl_t = backend.convert_to_tensor(bond_length)
1290
+ unit_cell_side_t = 2 * bl_t
1291
+ # The Lieb lattice is based on a square Bravais lattice.
1292
+ z = bl_t * 0.0
1293
+ lattice_vectors = backend.stack(
1294
+ [backend.stack([unit_cell_side_t, z]), backend.stack([z, unit_cell_side_t])]
1295
+ )
1296
+ # It has a three-site basis: one corner and two edge-centers.
1297
+ basis_coords = backend.stack(
1114
1298
  [
1115
- [0.0, 0.0], # Corner site
1116
- [bond_length, 0.0], # Horizontal edge center
1117
- [0.0, bond_length], # Vertical edge center
1299
+ backend.stack([z, z]), # Corner site
1300
+ backend.stack([bl_t, z]), # x-edge center
1301
+ backend.stack([z, bl_t]), # y-edge center
1118
1302
  ]
1119
1303
  )
1120
1304
 
@@ -1147,14 +1331,24 @@ class CubicLattice(TILattice):
1147
1331
  def __init__(
1148
1332
  self,
1149
1333
  size: Tuple[int, int, int],
1150
- lattice_constant: float = 1.0,
1334
+ lattice_constant: Union[float, Any] = 1.0,
1151
1335
  pbc: Union[bool, Tuple[bool, bool, bool]] = True,
1152
1336
  precompute_neighbors: Optional[int] = None,
1153
1337
  ):
1154
1338
  dimensionality = 3
1155
1339
  a = lattice_constant
1156
- lattice_vectors = np.array([[a, 0, 0], [0, a, 0], [0, 0, a]])
1157
- basis_coords = np.array([[0.0, 0.0, 0.0]])
1340
+ a_t = backend.convert_to_tensor(a)
1341
+ # Orthogonal lattice vectors of equal length in 3D.
1342
+ z = a_t * 0.0
1343
+ lattice_vectors = backend.stack(
1344
+ [
1345
+ backend.stack([a_t, z, z]),
1346
+ backend.stack([z, a_t, z]),
1347
+ backend.stack([z, z, a_t]),
1348
+ ]
1349
+ )
1350
+ # A simple cubic lattice is a Bravais lattice with a single-site basis.
1351
+ basis_coords = backend.stack([backend.stack([z, z, z])])
1158
1352
  super().__init__(
1159
1353
  dimensionality=dimensionality,
1160
1354
  lattice_vectors=lattice_vectors,
@@ -1194,29 +1388,37 @@ class CustomizeLattice(AbstractLattice):
1194
1388
  self,
1195
1389
  dimensionality: int,
1196
1390
  identifiers: List[SiteIdentifier],
1197
- coordinates: List[Union[List[float], Coordinates]],
1391
+ coordinates: Any,
1198
1392
  precompute_neighbors: Optional[int] = None,
1199
1393
  ):
1200
1394
  """Initializes the CustomizeLattice."""
1201
1395
  super().__init__(dimensionality)
1202
- if len(identifiers) != len(coordinates):
1396
+
1397
+ self._coordinates = backend.convert_to_tensor(coordinates)
1398
+ if len(identifiers) == 0:
1399
+ self._coordinates = backend.reshape(
1400
+ self._coordinates, (0, self.dimensionality)
1401
+ )
1402
+
1403
+ if len(identifiers) != backend.shape_tuple(self._coordinates)[0]:
1203
1404
  raise ValueError(
1204
- "Identifiers and coordinates lists must have the same length."
1405
+ "The number of identifiers must match the number of coordinates. "
1406
+ f"Got {len(identifiers)} identifiers and "
1407
+ f"{backend.shape_tuple(self._coordinates)[0]} coordinates."
1205
1408
  )
1206
1409
 
1207
- # The _build_lattice logic is simple enough to be in __init__
1208
1410
  self._identifiers = list(identifiers)
1209
- self._coordinates = [np.array(c) for c in coordinates]
1210
1411
  self._indices = list(range(len(identifiers)))
1211
1412
  self._ident_to_idx = {ident: idx for idx, ident in enumerate(identifiers)}
1212
1413
 
1213
- # Validate coordinate dimensions
1214
- for i, coord in enumerate(self._coordinates):
1215
- if coord.shape != (dimensionality,):
1216
- raise ValueError(
1217
- f"Coordinate at index {i} has shape {coord.shape}, "
1218
- f"expected ({dimensionality},)"
1219
- )
1414
+ if (
1415
+ self.num_sites > 0
1416
+ and backend.shape_tuple(self._coordinates)[1] != dimensionality
1417
+ ):
1418
+ raise ValueError(
1419
+ f"Coordinates tensor has dimension {backend.shape_tuple(self._coordinates)[1]}, "
1420
+ f"but expected dimensionality is {dimensionality}."
1421
+ )
1220
1422
 
1221
1423
  logger.info(f"CustomizeLattice with {self.num_sites} sites created.")
1222
1424
 
@@ -1228,95 +1430,170 @@ class CustomizeLattice(AbstractLattice):
1228
1430
  pass
1229
1431
 
1230
1432
  def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None:
1231
- """Calculates neighbors using a KDTree for efficiency.
1232
-
1233
- This method uses a memory-efficient approach to identify neighbors without
1234
- initially computing the full N x N distance matrix. It leverages
1235
- `scipy.spatial.distance.pdist` to find unique distance shells and then
1236
- a `scipy.spatial.KDTree` for fast radius queries. This approach is
1237
- significantly more memory-efficient during the neighbor identification phase.
1433
+ """
1434
+ Calculates neighbor relationships using either KDTree or distance matrix methods.
1238
1435
 
1239
- After the neighbors are identified, the full distance matrix is computed
1240
- from the pairwise distances and cached for potential future use.
1436
+ This method supports two modes:
1437
+ 1. KDTree mode (use_kdtree=True): Fast, O(N log N) performance for large lattices
1438
+ but breaks differentiability due to the scipy dependency
1439
+ 2. Distance matrix mode (use_kdtree=False): Slower O(N²) but fully differentiable
1440
+ and backend-agnostic
1241
1441
 
1242
- :param max_k: The maximum number of neighbor shells to
1243
- calculate. Defaults to 1.
1244
- :type max_k: int, optional
1245
- :param tol: The numerical tolerance for distance
1246
- comparisons. Defaults to 1e-6.
1247
- :type tol: float, optional
1442
+ :param max_k: Maximum number of neighbor shells to compute
1443
+ :type max_k: int
1444
+ :param kwargs: Additional arguments including:
1445
+ - use_kdtree (bool): Whether to use KDTree optimization. Defaults to False.
1446
+ - tol (float): Distance tolerance for neighbor identification. Defaults to 1e-6.
1248
1447
  """
1249
1448
  tol = kwargs.get("tol", 1e-6)
1250
- logger.info(f"Building neighbors for CustomizeLattice up to k={max_k}...")
1449
+ # Default to the differentiable distance-matrix method; the KDTree path is opt-in.
1450
+ use_kdtree = kwargs.get("use_kdtree", False)
1451
+
1251
1452
  if self.num_sites < 2:
1252
1453
  return
1253
1454
 
1254
- all_coords = np.array(self._coordinates)
1455
+ # Choose algorithm based on user preference
1456
+ if use_kdtree:
1457
+ logger.info(
1458
+ f"Using KDTree method for {self.num_sites} sites up to k={max_k}"
1459
+ )
1460
+ self._build_neighbors_kdtree(max_k, tol)
1461
+ else:
1462
+ logger.info(
1463
+ f"Using differentiable distance matrix method for {self.num_sites} sites up to k={max_k}"
1464
+ )
1255
1465
 
1256
- # 1. Use pdist for memory-efficient calculation of pairwise distances
1257
- # to robustly identify the distance shells.
1258
- all_distances_sq = pdist(all_coords, metric="sqeuclidean")
1259
- dist_shells_sq = self._identify_distance_shells(all_distances_sq, max_k, tol)
1466
+ # Use the existing distance matrix method
1467
+ self._build_neighbors_by_distance_matrix(max_k, tol)
1260
1468
 
1261
- if not dist_shells_sq:
1262
- logger.info("No distinct neighbor shells found.")
1263
- return
1469
+ def _build_neighbors_kdtree(self, max_k: int, tol: float) -> None:
1470
+ """
1471
+ Build neighbors using KDTree for optimal performance.
1264
1472
 
1265
- # 2. Build the KDTree for efficient querying.
1266
- tree = KDTree(all_coords)
1267
- self._neighbor_maps = {k: {} for k in range(1, len(dist_shells_sq) + 1)}
1473
+ This method provides O(N log N) performance for neighbor finding but breaks
1474
+ differentiability due to scipy dependency. Use this method when:
1475
+ - Performance is critical
1476
+ - Differentiability is not required
1477
+ - Large lattices (N > 1000)
1268
1478
 
1269
- # 3. Find neighbors by isolating shells using inclusion-exclusion.
1270
- # `found_indices` will store all neighbors within a given radius.
1271
- found_indices: List[set[int]] = []
1272
- for k_idx, target_d_sq in enumerate(dist_shells_sq):
1273
- radius = np.sqrt(target_d_sq) + tol
1274
- # Query for all points within the new, larger radius.
1275
- current_shell_indices = tree.query_ball_point(
1276
- all_coords, r=radius, return_sorted=True
1479
+ Note: This method uses numpy arrays directly and may not be compatible
1480
+ with all backend types (JAX, TensorFlow, etc.).
1481
+ """
1482
+
1483
+ # For small lattices or cases with potential duplicate coordinates,
1484
+ # fall back to distance matrix method for robustness
1485
+ if self.num_sites < 200:
1486
+ logger.info(
1487
+ "Small lattice detected, falling back to distance matrix method for robustness"
1277
1488
  )
1489
+ self._build_neighbors_by_distance_matrix(max_k, tol)
1490
+ return
1278
1491
 
1279
- # Now, isolate the neighbors for the current shell k
1280
- k = k_idx + 1
1281
- current_k_map: Dict[int, List[int]] = {}
1282
- for i in range(self.num_sites):
1492
+ # Convert coordinates to numpy for KDTree
1493
+ coords_np = backend.numpy(self._coordinates)
1283
1494
 
1284
- if k_idx == 0:
1285
- co_located_indices = tree.query_ball_point(all_coords[i], r=1e-12)
1286
- prev_found = set(co_located_indices)
1287
- else:
1288
- prev_found = found_indices[i]
1495
+ # Build KDTree
1496
+ logger.info("Building KDTree...")
1497
+ tree = KDTree(coords_np)
1498
+ # Find all distances for shell identification - use comprehensive sampling
1499
+ logger.info("Identifying distance shells...")
1500
+ distances_for_shells: List[float] = []
1289
1501
 
1290
- # The new neighbors are those in the current radius shell,
1291
- # excluding those already found in smaller shells.
1292
- new_neighbors = set(current_shell_indices[i]) - prev_found
1502
+ # For robust shell identification, query all pairwise distances for smaller lattices
1503
+ # or use dense sampling for larger ones
1504
+ if self.num_sites <= 100:
1505
+ # For small lattices, compute all pairwise distances for accuracy
1506
+ for i in range(self.num_sites):
1507
+ query_k = min(self.num_sites - 1, max_k * 20)
1508
+ if query_k > 0:
1509
+ dists, _ = tree.query(
1510
+ coords_np[i], k=query_k + 1
1511
+ ) # +1 to exclude self
1512
+ if isinstance(dists, np.ndarray):
1513
+ distances_for_shells.extend(dists[1:]) # Skip distance to self
1514
+ else:
1515
+ distances_for_shells.append(dists) # Single distance
1516
+ else:
1517
+ # For larger lattices, use adaptive sampling but ensure we capture all shells
1518
+ sample_size = min(1000, self.num_sites // 2) # More conservative sampling
1519
+ for i in range(0, self.num_sites, max(1, self.num_sites // sample_size)):
1520
+ query_k = min(max_k * 20 + 50, self.num_sites - 1)
1521
+ if query_k > 0:
1522
+ dists, _ = tree.query(
1523
+ coords_np[i], k=query_k + 1
1524
+ ) # +1 to exclude self
1525
+ if isinstance(dists, np.ndarray):
1526
+ distances_for_shells.extend(dists[1:]) # Skip distance to self
1527
+ else:
1528
+ distances_for_shells.append(dists) # Single distance
1293
1529
 
1294
- if new_neighbors:
1295
- current_k_map[i] = sorted(list(new_neighbors))
1530
+ # Filter out zero distances (duplicate coordinates) before shell identification
1531
+ ZERO_THRESHOLD = 1e-12
1532
+ distances_for_shells = [d for d in distances_for_shells if d > ZERO_THRESHOLD]
1296
1533
 
1297
- self._neighbor_maps[k] = current_k_map
1298
- found_indices = [
1299
- set(l) for l in current_shell_indices
1300
- ] # Update for next iteration
1301
- self._distance_matrix = np.sqrt(squareform(all_distances_sq))
1534
+ if not distances_for_shells:
1535
+ logger.warning("No valid distances found for shell identification")
1536
+ self._neighbor_maps = {}
1537
+ return
1302
1538
 
1303
- logger.info("Neighbor building complete using KDTree.")
1539
+ # Use the same shell identification logic as distance matrix method
1540
+ distances_for_shells_sq = [d * d for d in distances_for_shells]
1541
+ dist_shells_sq = self._identify_distance_shells(
1542
+ distances_for_shells_sq, max_k, tol
1543
+ )
1544
+ dist_shells = [np.sqrt(d_sq) for d_sq in dist_shells_sq]
1304
1545
 
1305
- def _compute_distance_matrix(self) -> Coordinates:
1306
- """Computes the distance matrix from the stored coordinates.
1546
+ logger.info(f"Found {len(dist_shells)} distance shells: {dist_shells[:5]}...")
1307
1547
 
1308
- This implementation uses scipy.pdist for a memory-efficient
1309
- calculation of pairwise distances, which is then converted to a
1310
- full square matrix.
1311
- """
1312
- if self.num_sites < 2:
1313
- return cast(Coordinates, np.empty((self.num_sites, self.num_sites)))
1548
+ # Initialize neighbor maps
1549
+ self._neighbor_maps = {k: {} for k in range(1, len(dist_shells) + 1)}
1314
1550
 
1315
- all_coords = np.array(self._coordinates)
1316
- # Use pdist for memory-efficiency, then build the full matrix.
1317
- all_distances_sq = pdist(all_coords, metric="sqeuclidean")
1318
- dist_matrix_sq = squareform(all_distances_sq)
1319
- return cast(Coordinates, np.sqrt(dist_matrix_sq))
1551
+ # Build neighbor lists for each site
1552
+ for i in range(self.num_sites):
1553
+ # Query enough neighbors to capture all shells
1554
+ query_k = min(max_k * 20 + 50, self.num_sites - 1)
1555
+ if query_k > 0:
1556
+ distances, indices = tree.query(
1557
+ coords_np[i], k=query_k + 1
1558
+ ) # +1 for self
1559
+
1560
+ # Skip the first entry (distance to self)
1561
+ # Handle both single value and array cases
1562
+ if isinstance(distances, np.ndarray) and len(distances) > 1:
1563
+ distances_slice = distances[1:]
1564
+ indices_slice = (
1565
+ indices[1:]
1566
+ if isinstance(indices, np.ndarray)
1567
+ else np.array([], dtype=int)
1568
+ )
1569
+ else:
1570
+ # Single value or empty case - no neighbors to process
1571
+ distances_slice = np.array([])
1572
+ indices_slice = np.array([], dtype=int)
1573
+
1574
+ # Filter out zero distances (duplicate coordinates)
1575
+ valid_pairs = [
1576
+ (d, idx)
1577
+ for d, idx in zip(distances_slice, indices_slice)
1578
+ if d > ZERO_THRESHOLD
1579
+ ]
1580
+
1581
+ # Assign neighbors to shells
1582
+ for shell_idx, shell_dist in enumerate(dist_shells):
1583
+ k = shell_idx + 1
1584
+ shell_neighbors = []
1585
+
1586
+ for dist, neighbor_idx in valid_pairs:
1587
+ if abs(dist - shell_dist) <= tol:
1588
+ shell_neighbors.append(int(neighbor_idx))
1589
+ elif dist > shell_dist + tol:
1590
+ break # Distances are sorted, no more matches
1591
+
1592
+ if shell_neighbors:
1593
+ self._neighbor_maps[k][i] = sorted(shell_neighbors)
1594
+
1595
+ # Set distance matrix to None - will compute on demand
1596
+ self._distance_matrix = None
1320
1597
 
1321
1598
  def _reset_computations(self) -> None:
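
A sketch of the two neighbor-building modes described above for CustomizeLattice. The class name, the use_kdtree and tol keywords, and the small-lattice fallback come from this diff; the import path is an assumption, and _build_neighbors is an internal method shown here only for illustration:

# Hypothetical import path:
from tensorcircuit.templates.lattice import CustomizeLattice

coords = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
lat = CustomizeLattice(dimensionality=2, identifiers=[0, 1, 2, 3], coordinates=coords)

# Default: differentiable, backend-agnostic O(N^2) distance-matrix method.
lat._build_neighbors(max_k=1)

# Opt-in scipy KDTree path; lattices with fewer than 200 sites fall back
# to the distance-matrix method for robustness anyway.
lat._build_neighbors(max_k=1, use_kdtree=True, tol=1e-6)
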
1322
1599
  """Resets all cached data that depends on the lattice structure."""
@@ -1344,18 +1621,28 @@ class CustomizeLattice(AbstractLattice):
1344
1621
  )
1345
1622
 
1346
1623
  # Unzip the list of tuples into separate lists of identifiers and coordinates
1347
- _, identifiers, coordinates = zip(*all_sites_info)
1624
+ _, identifiers, _ = zip(*all_sites_info)
1625
+
1626
+ # Detach-and-copy coordinates while remaining in tensor form to avoid
1627
+ # host roundtrips and device/dtype changes; this keeps CustomizeLattice
1628
+ # decoupled from the original graph but backend-friendly.
1629
+ # Some backends (e.g., NumPy) don't implement stop_gradient; fall back.
1630
+ try:
1631
+ coords_detached = backend.stop_gradient(lattice._coordinates)
1632
+ except NotImplementedError:
1633
+ coords_detached = lattice._coordinates
1634
+ coords_tensor = backend.copy(coords_detached)
1348
1635
 
1349
1636
  return cls(
1350
1637
  dimensionality=lattice.dimensionality,
1351
1638
  identifiers=list(identifiers),
1352
- coordinates=list(coordinates),
1639
+ coordinates=coords_tensor,
1353
1640
  )
1354
1641
 
1355
1642
  def add_sites(
1356
1643
  self,
1357
1644
  identifiers: List[SiteIdentifier],
1358
- coordinates: List[Union[List[float], Coordinates]],
1645
+ coordinates: Any,
1359
1646
  ) -> None:
1360
1647
  """Adds new sites to the lattice.
1361
1648
 
@@ -1363,21 +1650,29 @@ class CustomizeLattice(AbstractLattice):
1363
1650
  previously computed neighbor information is cleared and must be
1364
1651
  recalculated.
1365
1652
 
1366
- :param identifiers: A list of unique, hashable identifiers for the new sites.
1653
+ :param identifiers: A list of unique identifiers for the new sites.
1367
1654
  :type identifiers: List[SiteIdentifier]
1368
- :param coordinates: A list of coordinates for the new sites.
1369
- :type coordinates: List[Union[List[float], np.ndarray]]
1370
- :raises ValueError: If input lists have mismatched lengths, or if any new
1371
- identifier already exists in the lattice.
1655
+ :param coordinates: The coordinates for the new sites. Can be a list of lists,
1656
+ a NumPy array, or a backend-compatible tensor (e.g., jax.numpy.ndarray).
1657
+ :type coordinates: Any
1372
1658
  """
1373
- if len(identifiers) != len(coordinates):
1659
+ if not identifiers:
1660
+ return
1661
+
1662
+ new_coords_tensor = backend.convert_to_tensor(coordinates)
1663
+
1664
+ if len(identifiers) != backend.shape_tuple(new_coords_tensor)[0]:
1374
1665
  raise ValueError(
1375
1666
  "Identifiers and coordinates lists must have the same length."
1376
1667
  )
1377
- if not identifiers:
1378
- return # Nothing to add
1379
1668
 
1380
- # Check for duplicate identifiers before making any changes
1669
+ if backend.shape_tuple(new_coords_tensor)[1] != self.dimensionality:
1670
+ raise ValueError(
1671
+ f"New coordinate tensor has dimension {backend.shape_tuple(new_coords_tensor)[1]}, "
1672
+ f"but expected dimensionality is {self.dimensionality}."
1673
+ )
1674
+
1675
+ # Ensure that the new identifiers are unique and do not already exist.
1381
1676
  existing_ids = set(self._identifiers)
1382
1677
  new_ids = set(identifiers)
1383
1678
  if not new_ids.isdisjoint(existing_ids):
@@ -1385,21 +1680,14 @@ class CustomizeLattice(AbstractLattice):
1385
1680
  f"Duplicate identifiers found: {new_ids.intersection(existing_ids)}"
1386
1681
  )
1387
1682
 
1388
- for i, coord in enumerate(coordinates):
1389
- coord_arr = np.asarray(coord)
1390
- if coord_arr.shape != (self.dimensionality,):
1391
- raise ValueError(
1392
- f"New coordinate at index {i} has shape {coord_arr.shape}, "
1393
- f"expected ({self.dimensionality},)"
1394
- )
1395
- self._coordinates.append(coord_arr)
1396
- self._identifiers.append(identifiers[i])
1683
+ self._coordinates = backend.concat(
1684
+ [self._coordinates, new_coords_tensor], axis=0
1685
+ )
1686
+ self._identifiers.extend(identifiers)
1397
1687
 
1398
- # Rebuild index mappings from scratch
1399
1688
  self._indices = list(range(len(self._identifiers)))
1400
1689
  self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)}
1401
1690
 
1402
- # Invalidate any previously computed neighbors or distance matrices
1403
1691
  self._reset_computations()
1404
1692
  logger.info(
1405
1693
  f"{len(identifiers)} sites added. Lattice now has {self.num_sites} sites."
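
A short sketch of add_sites on the CustomizeLattice from this diff (import path assumed as before); adding sites extends the coordinate tensor and resets cached neighbor maps and the distance matrix:

# Hypothetical import path, as in the earlier sketch:
from tensorcircuit.templates.lattice import CustomizeLattice

lat = CustomizeLattice(dimensionality=2, identifiers=["a", "b"],
                       coordinates=[[0.0, 0.0], [1.0, 0.0]])
lat.add_sites(identifiers=["c"], coordinates=[[2.0, 0.0]])
print(lat.num_sites)   # 3; neighbor maps and the distance matrix are recomputed on demand
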
@@ -1414,10 +1702,9 @@ class CustomizeLattice(AbstractLattice):
1414
1702
 
1415
1703
  :param identifiers: A list of identifiers for the sites to be removed.
1416
1704
  :type identifiers: List[SiteIdentifier]
1417
- :raises ValueError: If any of the specified identifiers do not exist.
1418
1705
  """
1419
1706
  if not identifiers:
1420
- return # Nothing to remove
1707
+ return
1421
1708
 
1422
1709
  ids_to_remove = set(identifiers)
1423
1710
  current_ids = set(self._identifiers)
@@ -1426,23 +1713,25 @@ class CustomizeLattice(AbstractLattice):
1426
1713
  f"Non-existent identifiers provided for removal: {ids_to_remove - current_ids}"
1427
1714
  )
1428
1715
 
1429
- # Create new lists containing only the sites to keep
1430
- new_identifiers: List[SiteIdentifier] = []
1431
- new_coordinates: List[Coordinates] = []
1432
- for ident, coord in zip(self._identifiers, self._coordinates):
1433
- if ident not in ids_to_remove:
1434
- new_identifiers.append(ident)
1435
- new_coordinates.append(coord)
1716
+ # Find the indices of the sites that we want to keep.
1717
+ indices_to_keep = [
1718
+ idx
1719
+ for idx, ident in enumerate(self._identifiers)
1720
+ if ident not in ids_to_remove
1721
+ ]
1722
+
1723
+ new_identifiers = [self._identifiers[i] for i in indices_to_keep]
1724
+
1725
+ self._coordinates = backend.gather1d(
1726
+ self._coordinates,
1727
+ backend.cast(backend.convert_to_tensor(indices_to_keep), "int32"),
1728
+ )
1436
1729
 
1437
- # Replace old data with the new, filtered data
1438
1730
  self._identifiers = new_identifiers
1439
- self._coordinates = new_coordinates
1440
1731
 
1441
- # Rebuild index mappings
1442
1732
  self._indices = list(range(len(self._identifiers)))
1443
1733
  self._ident_to_idx = {ident: idx for idx, ident in enumerate(self._identifiers)}
1444
1734
 
1445
- # Invalidate caches
1446
1735
  self._reset_computations()
1447
1736
  logger.info(
1448
1737
  f"{len(ids_to_remove)} sites removed. Lattice now has {self.num_sites} sites."