tensorcircuit-nightly 1.3.0.dev20250728__py3-none-any.whl → 1.4.0.dev20251103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tensorcircuit-nightly has been flagged as potentially problematic; see the registry's release page for more details.

Files changed (72):
  1. tensorcircuit/__init__.py +5 -1
  2. tensorcircuit/abstractcircuit.py +4 -0
  3. tensorcircuit/analogcircuit.py +413 -0
  4. tensorcircuit/applications/layers.py +1 -1
  5. tensorcircuit/applications/van.py +1 -1
  6. tensorcircuit/backends/abstract_backend.py +312 -5
  7. tensorcircuit/backends/cupy_backend.py +3 -1
  8. tensorcircuit/backends/jax_backend.py +92 -3
  9. tensorcircuit/backends/jax_ops.py +108 -0
  10. tensorcircuit/backends/numpy_backend.py +49 -3
  11. tensorcircuit/backends/pytorch_backend.py +92 -3
  12. tensorcircuit/backends/tensorflow_backend.py +102 -3
  13. tensorcircuit/basecircuit.py +123 -82
  14. tensorcircuit/circuit.py +67 -57
  15. tensorcircuit/cloud/local.py +1 -1
  16. tensorcircuit/cloud/quafu_provider.py +1 -1
  17. tensorcircuit/cloud/tencent.py +1 -1
  18. tensorcircuit/compiler/simple_compiler.py +2 -2
  19. tensorcircuit/cons.py +1 -0
  20. tensorcircuit/densitymatrix.py +16 -11
  21. tensorcircuit/experimental.py +7 -152
  22. tensorcircuit/fgs.py +5 -6
  23. tensorcircuit/gates.py +66 -22
  24. tensorcircuit/keras.py +3 -3
  25. tensorcircuit/mpscircuit.py +109 -61
  26. tensorcircuit/quantum.py +697 -133
  27. tensorcircuit/quditcircuit.py +733 -0
  28. tensorcircuit/quditgates.py +618 -0
  29. tensorcircuit/results/counts.py +45 -31
  30. tensorcircuit/shadows.py +1 -1
  31. tensorcircuit/simplify.py +3 -1
  32. tensorcircuit/stabilizercircuit.py +4 -2
  33. tensorcircuit/templates/blocks.py +2 -2
  34. tensorcircuit/templates/hamiltonians.py +29 -8
  35. tensorcircuit/templates/lattice.py +676 -335
  36. tensorcircuit/timeevol.py +896 -0
  37. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +50 -25
  38. tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
  39. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
  40. tensorcircuit_nightly-1.3.0.dev20250728.dist-info/RECORD +0 -122
  41. tests/__init__.py +0 -0
  42. tests/conftest.py +0 -67
  43. tests/test_backends.py +0 -1035
  44. tests/test_calibrating.py +0 -149
  45. tests/test_channels.py +0 -409
  46. tests/test_circuit.py +0 -1713
  47. tests/test_cloud.py +0 -219
  48. tests/test_compiler.py +0 -147
  49. tests/test_dmcircuit.py +0 -555
  50. tests/test_ensemble.py +0 -72
  51. tests/test_fgs.py +0 -318
  52. tests/test_gates.py +0 -156
  53. tests/test_hamiltonians.py +0 -159
  54. tests/test_interfaces.py +0 -557
  55. tests/test_keras.py +0 -160
  56. tests/test_lattice.py +0 -1666
  57. tests/test_miscs.py +0 -334
  58. tests/test_mpscircuit.py +0 -341
  59. tests/test_noisemodel.py +0 -156
  60. tests/test_qaoa.py +0 -86
  61. tests/test_qem.py +0 -152
  62. tests/test_quantum.py +0 -549
  63. tests/test_quantum_attr.py +0 -42
  64. tests/test_results.py +0 -379
  65. tests/test_shadows.py +0 -160
  66. tests/test_simplify.py +0 -46
  67. tests/test_stabilizer.py +0 -226
  68. tests/test_templates.py +0 -218
  69. tests/test_torchnn.py +0 -99
  70. tests/test_van.py +0 -102
  71. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +0 -0
  72. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/licenses/LICENSE +0 -0
tests/test_lattice.py DELETED
@@ -1,1666 +0,0 @@
1
- from unittest.mock import patch
2
- import logging
3
-
4
- # import time
5
-
6
- import matplotlib
7
-
8
- matplotlib.use("Agg")
9
-
10
-
11
- import pytest
12
- import numpy as np
13
-
14
- from tensorcircuit.templates.lattice import (
15
- ChainLattice,
16
- CheckerboardLattice,
17
- CubicLattice,
18
- CustomizeLattice,
19
- DimerizedChainLattice,
20
- HoneycombLattice,
21
- KagomeLattice,
22
- LiebLattice,
23
- RectangularLattice,
24
- SquareLattice,
25
- TriangularLattice,
26
- )
27
-
28
-
29
@pytest.fixture
def simple_square_lattice() -> CustomizeLattice:
    """
    Build a 2x2 square ``CustomizeLattice`` used by the neighbor tests.

    Site indices (which double as identifiers) are laid out as::

        2--3
        |  |
        0--1

    Neighbor shells up to k=2 are built eagerly so tests can query
    nearest and next-nearest neighbors immediately.
    """
    corner_points = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    square = CustomizeLattice(
        dimensionality=2,
        identifiers=list(range(len(corner_points))),
        coordinates=corner_points,
    )
    # Fill the neighbor cache for the first two shells (NN and NNN).
    square._build_neighbors(max_k=2)
    return square
44
-
45
-
46
@pytest.fixture
def kagome_lattice_fragment() -> CustomizeLattice:
    """
    Provide a six-site Kagome fragment as a ``CustomizeLattice``.

    Sites (identifier == index):

    * 0, 1, 2 - first unit triangle with unit side length;
    * 3, 4    - together with site 1, a second triangle shifted by one unit
      along x;
    * 5       - the top site, forming a triangle with sites 2 and 4.

    Shared by several tests so the geometry is defined in one place.
    """
    kag_coords = [
        [0.0, 0.0],
        [1.0, 0.0],
        [0.5, np.sqrt(3) / 2],  # Triangle 1
        # Written as floats for consistency with every other coordinate.
        [2.0, 0.0],
        [1.5, np.sqrt(3) / 2],  # Triangle 2 (shifted basis)
        [1.0, np.sqrt(3)],  # Top site
    ]
    kag_ids = list(range(len(kag_coords)))
    return CustomizeLattice(
        dimensionality=2, identifiers=kag_ids, coordinates=kag_coords
    )
65
-
66
-
67
class TestCustomizeLattice:
    """
    Tests for ``CustomizeLattice``: construction and input validation,
    neighbor finding and pair enumeration, site iteration and the
    ``show()`` plotting helper.
    """

    def test_initialization_and_properties(self, kagome_lattice_fragment):
        """Basic properties of a successfully constructed lattice."""
        lattice = kagome_lattice_fragment

        assert lattice.dimensionality == 2
        assert lattice.num_sites == 6
        assert len(lattice) == 6  # also exercises __len__

        # Coordinates are numpy arrays, so compare them element-wise.
        expected_coord = np.array([0.5, np.sqrt(3) / 2])
        np.testing.assert_array_equal(lattice.get_coordinates(2), expected_coord)

        # Identifiers equal indices in this fixture, so both lookups agree.
        assert lattice.get_identifier(4) == 4
        assert lattice.get_index(4) == 4

    def test_input_validation_mismatched_lengths(self):
        """Mismatched identifier/coordinate list lengths raise ValueError."""
        coords = [[0.0, 0.0], [1.0, 0.0]]  # 2 coordinates
        ids = [0, 1, 2]  # 3 identifiers

        with pytest.raises(
            ValueError,
            match="Identifiers and coordinates lists must have the same length.",
        ):
            CustomizeLattice(dimensionality=2, identifiers=ids, coordinates=coords)

    def test_input_validation_wrong_dimension(self):
        """A coordinate whose length differs from ``dimensionality`` raises."""
        coords_wrong_dim = [[0.0, 0.0], [1.0, 0.0, 0.0]]  # a mix of 2D and 3D
        ids_ok = [0, 1]

        # Raw string: the parentheses in the message are escaped for the regex.
        with pytest.raises(
            ValueError, match=r"Coordinate at index 1 has shape \(3,\), expected \(2,\)"
        ):
            CustomizeLattice(
                dimensionality=2, identifiers=ids_ok, coordinates=coords_wrong_dim
            )

    def test_neighbor_finding(self, simple_square_lattice):
        """k-th nearest neighbors via ``get_neighbors`` on the 2x2 square."""
        lattice = simple_square_lattice

        # k=1: edge-sharing sites. Sets make the comparison order-independent.
        assert set(lattice.get_neighbors(0, k=1)) == {1, 2}
        assert set(lattice.get_neighbors(1, k=1)) == {0, 3}
        assert set(lattice.get_neighbors(2, k=1)) == {0, 3}
        assert set(lattice.get_neighbors(3, k=1)) == {1, 2}

        # k=2: the diagonal sites.
        assert set(lattice.get_neighbors(0, k=2)) == {3}
        assert set(lattice.get_neighbors(1, k=2)) == {2}

    def test_neighbor_pairs(self, simple_square_lattice):
        """Unique bonds from ``get_neighbor_pairs(unique=True)``."""
        lattice = simple_square_lattice

        # k=1 bonds, compared as a set of tuples to ignore ordering.
        nn_pairs = lattice.get_neighbor_pairs(k=1, unique=True)
        expected_nn_pairs = {(0, 1), (0, 2), (1, 3), (2, 3)}
        assert set(map(tuple, nn_pairs)) == expected_nn_pairs

        # k=2 bonds are the two diagonals.
        nnn_pairs = lattice.get_neighbor_pairs(k=2, unique=True)
        expected_nnn_pairs = {(0, 3), (1, 2)}
        assert set(map(tuple, nnn_pairs)) == expected_nnn_pairs

    def test_neighbor_pairs_non_unique(self, simple_square_lattice):
        """
        ``get_neighbor_pairs(unique=False)`` returns both directions of
        every bond on the 2x2 square::

            2--3
            |  |
            0--1
        """
        lattice = simple_square_lattice

        nn_pairs = lattice.get_neighbor_pairs(k=1, unique=False)

        # 4 bonds -> 4 * 2 = 8 directed pairs.
        assert len(nn_pairs) == 8

        # The pairs are expected in sorted order, so compare against an
        # explicitly sorted list for an exact match.
        expected_pairs = sorted(
            [(0, 1), (1, 0), (0, 2), (2, 0), (1, 3), (3, 1), (2, 3), (3, 2)]
        )
        assert nn_pairs == expected_pairs

    @patch("matplotlib.pyplot.show")
    def test_show_method_runs_and_calls_plt_show(
        self, mock_show, simple_square_lattice
    ):
        """
        Smoke test: ``show()`` completes and calls ``plt.show()`` once.
        ``plt.show`` is patched so no window opens during the test run.
        """
        # Any exception raised here fails the test with its full traceback.
        simple_square_lattice.show()

        mock_show.assert_called_once()

    def test_sites_iterator(self, simple_square_lattice):
        """``sites()`` yields one (index, identifier, coords) tuple per site."""
        lattice = simple_square_lattice
        expected_num_sites = 4

        # sites() is an iterator; materialize it to count and index it.
        all_sites = list(lattice.sites())
        assert len(all_sites) == expected_num_sites

        # Site 3 of the fixture has identifier 3 and coordinates [1, 1].
        idx, ident, coords = all_sites[3]
        assert idx == 3
        assert ident == 3
        np.testing.assert_array_equal(coords, np.array([1, 1]))

    def test_get_site_info_with_identifier(self, simple_square_lattice):
        """``get_site_info()`` also accepts a site identifier, not just an index."""
        lattice = simple_square_lattice
        # In this fixture the identifier of the site at index 2 is also 2.
        identifier_to_test = 2
        expected_index = 2
        expected_coords = np.array([0, 1])

        idx, ident, coords = lattice.get_site_info(identifier_to_test)

        assert idx == expected_index
        assert ident == identifier_to_test
        np.testing.assert_array_equal(coords, expected_coords)

    @patch("matplotlib.pyplot.show")
    def test_show_method_with_labels(self, mock_show, simple_square_lattice):
        """``show()`` runs with index and identifier labels enabled."""
        # Any exception raised here fails the test with its full traceback.
        simple_square_lattice.show(show_indices=True, show_identifiers=True)

        mock_show.assert_called_once()

    def test_get_neighbors_logs_info_for_uncached_k(
        self, simple_square_lattice, caplog
    ):
        """
        Asking for a ``k`` that was not pre-computed logs an INFO message
        and triggers on-demand computation.
        """
        lattice = simple_square_lattice  # neighbors built up to k=2
        k_to_test = 99  # clearly not cached
        caplog.set_level(logging.INFO)

        _ = lattice.get_neighbors(0, k=k_to_test)

        expected_log = (
            f"Neighbors for k={k_to_test} not pre-computed. "
            f"Building now up to max_k={k_to_test}."
        )
        assert expected_log in caplog.text

    @patch("matplotlib.pyplot.show")
    def test_show_prints_warning_for_uncached_bonds(
        self, mock_show, simple_square_lattice, caplog
    ):
        """
        Asking ``show()`` to draw bond layer ``k`` for a shell that has not
        been computed logs a message instead of drawing.
        """
        lattice = simple_square_lattice  # neighbors built up to k=2
        k_to_test = 99  # clearly not cached

        lattice.show(show_bonds_k=k_to_test)

        assert (
            f"Cannot draw bonds. k={k_to_test} neighbors have not been calculated"
            in caplog.text
        )

    @patch("matplotlib.pyplot.show")
    def test_show_method_for_3d_lattice(self, mock_show):
        """``show()`` on a 3D lattice completes without raising."""
        coords_3d = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
        ids_3d = [0, 1]
        lattice_3d = CustomizeLattice(
            dimensionality=3, identifiers=ids_3d, coordinates=coords_3d
        )
        assert lattice_3d.dimensionality == 3

        # Any exception raised here fails the test with its full traceback.
        lattice_3d.show(show_indices=True, show_bonds_k=None)

        mock_show.assert_called_once()

    @patch("matplotlib.pyplot.subplots")
    def test_show_method_actually_draws_2d_labels(
        self, mock_subplots, simple_square_lattice
    ):
        """With ``show_indices=True``, ``ax.text`` is called once per 2D site."""
        # Import the submodules explicitly instead of relying on them having
        # been loaded as a side effect of importing matplotlib.pyplot.
        import matplotlib.axes
        import matplotlib.figure

        # Return a real Figure/Axes pair from the patched subplots() so the
        # calls show() makes on the axes can be inspected.
        fig = matplotlib.figure.Figure()
        ax = matplotlib.axes.Axes(fig, [0.0, 0.0, 1.0, 1.0])
        mock_subplots.return_value = (fig, ax)

        with patch.object(ax, "text") as mock_text_method:
            lattice = simple_square_lattice

            lattice.show(show_indices=True)

            # One label per site: 4 calls for the 4-site fixture.
            assert mock_text_method.call_count == lattice.num_sites

    def test_custom_irregular_geometry_neighbors(self):
        """
        Neighbor shells on a non-grid "star" geometry: a centre, an inner
        ring at distance 1 and an outer ring at distance 2.
        """
        coords = [
            [0.0, 0.0],  # Site 0: centre
            [1.0, 0.0],
            [0.0, 1.0],
            [-1.0, 0.0],
            [0.0, -1.0],  # Sites 1-4: inner ring (dist=1)
            [2.0, 0.0],
            [0.0, 2.0],
            [-2.0, 0.0],
            [0.0, -2.0],  # Sites 5-8: outer ring (dist=2)
        ]
        ids = list(range(len(coords)))
        lattice = CustomizeLattice(
            dimensionality=2, identifiers=ids, coordinates=coords
        )
        lattice._build_neighbors(max_k=3)

        # Shells are global distances (1, sqrt(2), 2, ...). From the centre
        # the inner ring is shell 1 and the outer ring (d^2=4) is shell 3.
        assert set(lattice.get_neighbors(0, k=1)) == {1, 2, 3, 4}
        assert set(lattice.get_neighbors(0, k=3)) == {5, 6, 7, 8}

        # No site lies at distance sqrt(2) from the centre.
        assert lattice.get_neighbors(0, k=2) == []

        # Site 1 ([1, 0]): the centre and site 5 are both at distance 1.
        assert set(lattice.get_neighbors(1, k=1)) == {0, 5}

        # Sites 2 and 4 are both at distance sqrt(2) from site 1.
        assert set(lattice.get_neighbors(1, k=2)) == {2, 4}

    def test_customizelattice_max_k_precomputation_and_ondemand(self):
        """
        ``precompute_neighbors=k`` builds exactly shells 1..k at construction.
        A later request for a larger k extends the cache to shells
        1..k_ondemand.
        """
        coords = [
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [-1.0, 0.0],
            [0.0, -1.0],
            [1.0, 1.0],
            [-1.0, 1.0],
            [-1.0, -1.0],
            [1.0, -1.0],
            [2.0, 0.0],
            [0.0, 2.0],
            [-2.0, 0.0],
            [0.0, -2.0],
        ]
        ids = list(range(len(coords)))
        k_precompute = 2

        lattice = CustomizeLattice(
            dimensionality=2,
            identifiers=ids,
            coordinates=coords,
            precompute_neighbors=k_precompute,
        )

        computed_shells = sorted(lattice._neighbor_maps.keys())
        expected_shells = list(range(1, k_precompute + 1))

        assert computed_shells == expected_shells, (
            f"TEST FAILED for CustomizeLattice with k={k_precompute}. "
            f"Expected shells {expected_shells}, but found {computed_shells}."
        )

        k_ondemand = 3
        _ = lattice.get_neighbors(0, k=k_ondemand)

        computed_shells_after = sorted(lattice._neighbor_maps.keys())
        expected_shells_after = list(range(1, k_ondemand + 1))

        assert computed_shells_after == expected_shells_after, (
            f"ON-DEMAND TEST FAILED for CustomizeLattice. "
            f"Expected shells {expected_shells_after} after demanding k={k_ondemand}, "
            f"but found {computed_shells_after}."
        )
477
-
478
-
479
@pytest.fixture
def obc_square_lattice() -> SquareLattice:
    """Return a 3x3 ``SquareLattice`` with open (non-periodic) boundaries."""
    open_grid = SquareLattice(size=(3, 3), pbc=False)
    return open_grid
483
-
484
-
485
@pytest.fixture
def pbc_square_lattice() -> SquareLattice:
    """Return a 3x3 ``SquareLattice`` that wraps around in both directions."""
    periodic_grid = SquareLattice(size=(3, 3), pbc=True)
    return periodic_grid
489
-
490
-
491
class TestSquareLattice:
    """
    Tests for ``SquareLattice``; by extension these also exercise the shared
    ``TILattice`` machinery it inherits.
    """

    def test_initialization_and_properties(self, obc_square_lattice):
        """A 3x3 square lattice is two-dimensional and holds nine sites."""
        sq = obc_square_lattice
        assert sq.dimensionality == 2
        assert sq.num_sites == 9  # 3 x 3
        assert len(sq) == 9

    def test_site_info_and_identifiers(self, obc_square_lattice):
        """Identifiers and coordinates reported for the centre and a corner."""
        sq = obc_square_lattice

        middle = sq.get_index((1, 1, 0))
        assert middle == 4

        _, middle_ident, middle_xy = sq.get_site_info(middle)
        assert middle_ident == (1, 1, 0)
        np.testing.assert_array_equal(middle_xy, np.array([1.0, 1.0]))

        _, origin_ident, origin_xy = sq.get_site_info(0)
        assert origin_ident == (0, 0, 0)
        np.testing.assert_array_equal(origin_xy, np.array([0.0, 0.0]))

    def test_neighbors_with_open_boundaries(self, obc_square_lattice):
        """
        Nearest neighbors under open boundaries. Row-major indices::

            0 1 2
            3 4 5
            6 7 8
        """
        sq = obc_square_lattice
        expectations = [
            (4, {1, 3, 5, 7}),  # centre (1, 1, 0): full coordination
            (0, {1, 3}),  # corner (0, 0, 0)
            (3, {0, 4, 6}),  # edge (1, 0, 0)
        ]
        for site, wanted in expectations:
            assert set(sq.get_neighbors(site, k=1)) == wanted

    def test_neighbors_with_periodic_boundaries(self, pbc_square_lattice):
        """Nearest and next-nearest neighbors of a corner with wrap-around."""
        sq = pbc_square_lattice
        origin = sq.get_index((0, 0, 0))

        def idents_at(shell):
            return {sq.get_identifier(i) for i in sq.get_neighbors(origin, k=shell)}

        # Wrap-around makes the opposite edges adjacent to the corner.
        assert idents_at(1) == {(1, 0, 0), (0, 1, 0), (2, 0, 0), (0, 2, 0)}
        # Diagonal shell, again including the wrapped sites.
        assert idents_at(2) == {(1, 1, 0), (2, 1, 0), (1, 2, 0), (2, 2, 0)}
560
-
561
-
562
- # --- Tests for HoneycombLattice ---
563
-
564
-
565
@pytest.fixture
def pbc_honeycomb_lattice() -> HoneycombLattice:
    """Return a periodic ``HoneycombLattice`` spanning 2x2 unit cells."""
    honeycomb = HoneycombLattice(size=(2, 2), pbc=True)
    return honeycomb
569
-
570
-
571
class TestHoneycombLattice:
    """Tests for ``HoneycombLattice``, focusing on its two-site basis."""

    def test_initialization_and_properties(self, pbc_honeycomb_lattice):
        """2x2 unit cells with a two-site basis give 2 * 2 * 2 = 8 sites."""
        lattice = pbc_honeycomb_lattice
        assert lattice.num_sites == 8
        assert lattice.num_basis == 2

    def test_honeycomb_neighbors(self, pbc_honeycomb_lattice):
        """Every site of the periodic honeycomb lattice has 3 nearest neighbors."""
        lattice = pbc_honeycomb_lattice

        # Explicit spot checks: one site from each sublattice, addressed
        # by identifier.
        site_a_idx = lattice.get_index((0, 0, 0))
        assert len(lattice.get_neighbors(site_a_idx, k=1)) == 3

        site_b_idx = lattice.get_index((0, 0, 1))
        assert len(lattice.get_neighbors(site_b_idx, k=1)) == 3

        # The docstring claims this for *every* site. With periodic
        # boundaries all sites are equivalent to one of the two checked
        # above, so verify it across the whole lattice.
        for idx in range(lattice.num_sites):
            assert len(lattice.get_neighbors(idx, k=1)) == 3, f"site {idx}"
594
-
595
-
596
- # --- Tests for TriangularLattice ---
597
-
598
-
599
@pytest.fixture
def pbc_triangular_lattice() -> TriangularLattice:
    """
    Return a periodic 3x3 ``TriangularLattice``.

    3x3 is the smallest size for which all six nearest neighbors of a site
    are distinct sites.
    """
    triangular = TriangularLattice(size=(3, 3), pbc=True)
    return triangular
606
-
607
-
608
class TestTriangularLattice:
    """Tests for ``TriangularLattice``, focusing on its coordination number."""

    def test_initialization_and_properties(self, pbc_triangular_lattice):
        """A 3x3 triangular lattice has 3 * 3 = 9 sites."""
        lattice = pbc_triangular_lattice
        assert lattice.num_sites == 9

    def test_triangular_neighbors(self, pbc_triangular_lattice):
        """Every site in the periodic triangular lattice has 6 nearest neighbors."""
        lattice = pbc_triangular_lattice

        # Explicit check of the first site ...
        site_idx = 0
        assert len(lattice.get_neighbors(site_idx, k=1)) == 6

        # ... and, since all sites are equivalent under periodic boundaries,
        # the same coordination for every site, as the docstring states.
        for idx in range(lattice.num_sites):
            assert len(lattice.get_neighbors(idx, k=1)) == 6, f"site {idx}"
627
-
628
-
629
- # --- Tests for New TILattice Implementations ---
630
-
631
-
632
class TestRectangularLattice:
    """Tests for the two-dimensional ``RectangularLattice``."""

    def test_rectangular_properties_and_neighbors(self):
        """Site count, dimensionality and open-boundary coordination of a 3x4 grid."""
        rect = RectangularLattice(size=(3, 4), pbc=False)
        assert rect.num_sites == 12  # 3 * 4
        assert rect.dimensionality == 2

        # (identifier, expected nearest-neighbor count) with open boundaries.
        cases = (
            ((1, 1, 0), 4),  # interior site
            ((0, 0, 0), 2),  # corner site
            ((0, 1, 0), 3),  # edge site
        )
        indices = [rect.get_index(ident) for ident, _ in cases]
        for idx, (_, count) in zip(indices, cases):
            assert len(rect.get_neighbors(idx, k=1)) == count
649
-
650
-
651
class TestTILatticeEdgeCases:
    """
    Tests of ``TILattice`` and its subclasses under less common conditions:
    1D chains, non-square grids and plotting of a 1D lattice.
    """

    @pytest.fixture
    def obc_1d_chain(self) -> ChainLattice:
        """
        Provide a 5-site 1D chain with open boundary conditions::

            0--1--2--3--4
        """
        return ChainLattice(size=(5,), pbc=False)

    def test_1d_chain_properties_and_neighbors(self, obc_1d_chain):
        """Open 5-site chain: the endpoint has one neighbor, interior sites have two."""
        lattice = obc_1d_chain

        assert lattice.num_sites == 5
        assert lattice.dimensionality == 1

        # Endpoint (site 0) has exactly one neighbor, site 1.
        endpoint_idx = lattice.get_index((0, 0))
        assert lattice.get_neighbors(endpoint_idx, k=1) == [1]

        # Interior site 2 is bonded to sites 1 and 3.
        middle_idx = lattice.get_index((2, 0))
        assert len(lattice.get_neighbors(middle_idx, k=1)) == 2
        assert set(lattice.get_neighbors(middle_idx, k=1)) == {1, 3}

    @pytest.fixture
    def nonsquare_lattice(self) -> SquareLattice:
        """Provide a non-square 2x3 lattice (open boundaries) to test indexing."""
        return SquareLattice(size=(2, 3), pbc=False)

    def test_nonsquare_lattice_indexing(self, nonsquare_lattice):
        """
        Indexing and coordinates on a non-square (2x3) lattice, so that the
        two dimension lengths differ. Sites are indexed row by row::

            (0,0) (0,1) (0,2) -> indices 0, 1, 2
            (1,0) (1,1) (1,2) -> indices 3, 4, 5
        """
        lattice = nonsquare_lattice

        assert lattice.num_sites == 6  # 2 * 3

        # Last row, last column: a site that only maps correctly if both
        # dimension lengths are handled separately.
        ident = (1, 2, 0)
        expected_idx = 5
        expected_coords = np.array([1.0, 2.0])

        idx = lattice.get_index(ident)
        assert idx == expected_idx

        _, _, coords = lattice.get_site_info(idx)
        np.testing.assert_array_equal(coords, expected_coords)

    @patch("matplotlib.pyplot.show")
    def test_show_method_for_1d_lattice(self, mock_show, obc_1d_chain):
        """``show()`` on a 1D chain completes and calls ``plt.show()`` once."""
        lattice_1d = obc_1d_chain
        assert lattice_1d.num_sites == 5

        # Any exception raised here fails the test with its full traceback.
        lattice_1d.show(show_indices=True)

        mock_show.assert_called_once()
737
-
738
-
739
- # --- Tests for API Robustness / Negative Cases ---
740
-
741
-
742
class TestApiRobustness:
    """
    Groups tests that verify the API's behavior with invalid or degenerate inputs.
    This ensures the lattice classes fail gracefully and predictably.
    """

    def test_access_with_out_of_bounds_index(self, simple_square_lattice):
        """
        Tests that an IndexError is raised when accessing a site index
        that is out of the valid range (0 to num_sites-1).
        """
        # Arrange
        lattice = simple_square_lattice  # This lattice has 4 sites (indices 0, 1, 2, 3)
        invalid_index = 999

        # Act & Assert: every index-based accessor must reject the index.
        with pytest.raises(IndexError):
            lattice.get_coordinates(invalid_index)

        with pytest.raises(IndexError):
            lattice.get_identifier(invalid_index)

        with pytest.raises(IndexError):
            # get_site_info should also raise IndexError for an invalid index
            lattice.get_site_info(invalid_index)

    def test_empty_lattice_handles_gracefully(self, caplog):
        """
        Tests that an empty lattice initializes correctly and that methods
        like .show() and ._build_neighbors() handle the zero-site case
        gracefully without crashing.
        """
        # Arrange: Create an empty CustomizeLattice instance.
        empty_lattice = CustomizeLattice(
            dimensionality=2, identifiers=[], coordinates=[]
        )

        # Assert: Verify basic properties.
        assert empty_lattice.num_sites == 0
        assert len(empty_lattice) == 0

        # Act & Assert for .show(): it should log (not print) the message below
        # without crashing. Capture records from INFO level upwards.
        caplog.set_level(logging.INFO)

        empty_lattice.show()
        assert "Lattice is empty, nothing to show." in caplog.text

        # Act & Assert for neighbor finding: Verify these calls run without errors.
        empty_lattice._build_neighbors()
        assert empty_lattice.get_neighbor_pairs(k=1) == []

    def test_single_site_lattice_handles_gracefully(self):
        """
        Tests that a lattice with a single site correctly handles neighbor
        finding (i.e., returns no neighbors).
        """
        # Arrange: Create a CustomizeLattice with a single site.
        single_site_lattice = CustomizeLattice(
            dimensionality=2, identifiers=[0], coordinates=[[0.0, 0.0]]
        )

        # Assert: Verify basic properties.
        assert single_site_lattice.num_sites == 1

        # Act: Attempt to build neighbor relationships.
        single_site_lattice._build_neighbors(max_k=1)

        # Assert: The single site should have no neighbors.
        assert single_site_lattice.get_neighbors(0, k=1) == []

    def test_access_with_non_existent_identifier(self, simple_square_lattice):
        """
        Tests that a ValueError is raised when accessing a site
        with an identifier that does not exist in the lattice.
        """
        # Arrange
        lattice = simple_square_lattice
        invalid_identifier = "non_existent_site"

        # Act & Assert
        # Identifier lookups raise ValueError; `match` additionally pins the
        # wording of the error message.
        with pytest.raises(ValueError, match="not found in the lattice"):
            lattice.get_index(invalid_identifier)

        with pytest.raises(ValueError, match="not found in the lattice"):
            lattice.get_site_info(invalid_identifier)

    def test_show_warning_for_unsupported_dimension(self, caplog):
        """
        Tests that .show() logs a warning, instead of raising, when called on
        a lattice whose dimensionality it does not support for plotting
        (e.g., 4D).
        """
        # Arrange: Create a simple lattice with an unsupported dimension.
        lattice_4d = CustomizeLattice(
            dimensionality=4, identifiers=[0], coordinates=[[0, 0, 0, 0]]
        )

        # Act
        lattice_4d.show()

        # Assert: The message goes through logging, so inspect caplog (not stdout).
        assert "show() is not implemented for 4D lattices." in caplog.text

    def test_disconnected_lattice_neighbor_finding(self):
        """
        Tests that neighbor finding algorithms work correctly for a lattice
        composed of multiple, physically disconnected components.
        """
        # Arrange: Create a lattice with two disconnected 2x2 squares,
        # separated by a large distance.
        # Component 1: sites with indices 0, 1, 2, 3
        # Component 2: sites with indices 4, 5, 6, 7
        coords = [
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],  # Square 1
            [100.0, 0.0],
            [101.0, 0.0],
            [100.0, 1.0],
            [101.0, 1.0],  # Square 2
        ]
        ids = list(range(len(coords)))
        lattice = CustomizeLattice(
            dimensionality=2, identifiers=ids, coordinates=coords
        )
        lattice._build_neighbors(max_k=1)  # Explicitly build neighbors

        # --- Test 1: get_neighbors() ---
        # Act: Get neighbors for a site in the first component.
        neighbors_of_site_0 = lattice.get_neighbors(0, k=1)

        # Assert: Its neighbors must only be within the first component.
        assert set(neighbors_of_site_0) == {1, 2}

        # --- Test 2: get_neighbor_pairs() ---
        # Act: Get all unique bonds for the entire lattice.
        all_bonds = lattice.get_neighbor_pairs(k=1, unique=True)

        # Assert: No bond should connect a site from Component 1 to Component 2.
        for i, j in all_bonds:
            # A bond is valid only if both its sites are in the same component,
            # i.e. both indices fall in the same index range.
            is_in_first_component = i < 4 and j < 4
            is_in_second_component = i >= 4 and j >= 4

            assert is_in_first_component or is_in_second_component, (
                f"Found an invalid bond { (i,j) } that incorrectly connects "
                "two separate components of the lattice."
            )

    def test_lattice_with_duplicate_coordinates(self):
        """
        Tests a pathological case where multiple sites share the exact same coordinates.
        The neighbor-finding logic must still treat them as distinct sites and
        correctly identify neighbors based on other non-overlapping sites.
        """
        # Arrange
        # Site 'A' and 'B' are at the same position (0,0).
        # Site 'C' is at (1,0), which should be a neighbor to both 'A' and 'B'.
        ids = ["A", "B", "C"]
        coords = [[0.0, 0.0], [0.0, 0.0], [1.0, 0.0]]

        lattice = CustomizeLattice(
            dimensionality=2, identifiers=ids, coordinates=coords
        )
        lattice._build_neighbors(max_k=1)  # Build nearest neighbors

        # Act
        idx_A = lattice.get_index("A")
        idx_B = lattice.get_index("B")
        idx_C = lattice.get_index("C")

        neighbors_A = lattice.get_neighbors(idx_A, k=1)
        neighbors_B = lattice.get_neighbors(idx_B, k=1)

        # Assert
        # 1. The distance between the overlapping points 'A' and 'B' is 0,
        #    so they should NOT be considered neighbors of each other.
        assert (
            idx_B not in neighbors_A
        ), "Overlapping sites should not be their own neighbors."
        assert (
            idx_A not in neighbors_B
        ), "Overlapping sites should not be their own neighbors."

        # 2. Both 'A' and 'B' should correctly identify 'C' as their neighbor.
        #    This is the key test of robustness.
        assert neighbors_A == [
            idx_C
        ], "Site 'A' failed to find its correct neighbor 'C'."
        assert neighbors_B == [
            idx_C
        ], "Site 'B' failed to find its correct neighbor 'C'."

        # 3. Conversely, 'C' should identify both 'A' and 'B' as its neighbors.
        neighbors_C = lattice.get_neighbors(idx_C, k=1)
        assert set(neighbors_C) == {
            idx_A,
            idx_B,
        }, "Site 'C' failed to find both overlapping neighbors."

    def test_neighbor_shells_with_tiny_separation(self):
        """
        Tests the numerical stability of neighbor shell identification.

        Builds a 1D lattice whose first two neighbor distances from site 0
        differ by only 1e-8, then checks that the shell tolerance decides the
        outcome: a tolerance larger than the gap (1e-7) merges both sites into
        a single k=1 shell (with no k=2 shell), while a smaller tolerance
        (1e-9) keeps them as distinct k=1 and k=2 shells.
        """
        # Arrange
        # Let d1 be the distance to the first neighbor shell.
        d1 = 1.0
        # Let d2 be the distance to the second shell, which is extremely close to d1.
        epsilon = 1e-8  # A tiny separation
        d2 = d1 + epsilon

        # Create a 1D lattice with these specific distances.
        # Site 0 is origin. Site 1 is at d1. Site 2 is at d2.
        ids = [0, 1, 2]
        coords = [[0.0], [d1], [d2]]

        # A tolerance LARGER than the separation SHOULD merge the shells.
        lattice_merged = CustomizeLattice(
            dimensionality=1, identifiers=ids, coordinates=coords
        )
        lattice_merged._build_neighbors(max_k=2, tol=1e-7)

        # A tolerance SMALLER than the separation SHOULD keep the shells distinct.
        lattice_distinct = CustomizeLattice(
            dimensionality=1, identifiers=ids, coordinates=coords
        )
        lattice_distinct._build_neighbors(max_k=2, tol=1e-9)

        # Assert for the merged case
        # With a large tolerance, site 1 and 2 should both be in the k=1 shell.
        merged_neighbors_k1 = lattice_merged.get_neighbors(0, k=1)
        assert set(merged_neighbors_k1) == {
            1,
            2,
        }, "Shells were not merged with a large tolerance."
        # There should be no k=2 shell.
        merged_neighbors_k2 = lattice_merged.get_neighbors(0, k=2)
        assert (
            merged_neighbors_k2 == []
        ), "A k=2 shell should not exist when shells are merged."

        # Assert for the distinct case
        # With a small tolerance, only site 1 should be in the k=1 shell.
        distinct_neighbors_k1 = lattice_distinct.get_neighbors(0, k=1)
        assert distinct_neighbors_k1 == [
            1
        ], "k=1 shell is incorrect with a small tolerance."
        # Site 2 should now be in its own k=2 shell.
        distinct_neighbors_k2 = lattice_distinct.get_neighbors(0, k=2)
        assert distinct_neighbors_k2 == [
            2
        ], "k=2 shell is incorrect with a small tolerance."
-
1005
-
1006
class TestTILattice:
    """
    A dedicated class for testing the Translationally Invariant Lattice (TILattice)
    and its subclasses like SquareLattice.
    """

    def test_init_with_mismatched_shapes_raises_error(self):
        """
        Tests that TILattice raises AssertionError if the 'size' parameter's
        length does not match the dimensionality.
        """
        # Act & Assert:
        # Pass a 'size' tuple with 3 elements to a 2D SquareLattice.
        # This should trigger the AssertionError from the parent TILattice class.
        with pytest.raises(AssertionError, match="Size tuple length mismatch"):
            SquareLattice(size=(2, 2, 2))

    def test_init_with_tuple_pbc(self):
        """
        Tests that TILattice correctly handles a per-dimension tuple for the
        'pbc' (periodic boundary conditions) parameter. This exercises the
        non-bool ('else') branch of the pbc handling.
        """
        # Arrange
        pbc_tuple = (True, False)

        # Act: Initialize a lattice with a tuple for pbc.
        lattice = SquareLattice(size=(3, 3), pbc=pbc_tuple)

        # Assert: The public 'pbc' attribute should equal the tuple we passed.
        assert lattice.pbc == pbc_tuple

    @pytest.mark.parametrize(
        "LatticeClass, init_args, k_precompute",
        [
            (HoneycombLattice, {"size": (4, 5), "pbc": True}, 1),
            (SquareLattice, {"size": (5, 5), "pbc": True}, 2),
            (SquareLattice, {"size": (5, 5), "pbc": False}, 1),
            (KagomeLattice, {"size": (3, 3), "pbc": True}, 1),
        ],
    )
    def test_tilattice_max_k_precomputation_and_ondemand(
        self, LatticeClass, init_args, k_precompute
    ):
        """
        Verifies, across several TILattice types and boundary conditions, that
        `precompute_neighbors=k` builds exactly the shells 1..k at construction
        time, and that a later request for shell k+1 builds it on demand so the
        cached shells become exactly 1..k+1.
        """
        lattice = LatticeClass(**init_args, precompute_neighbors=k_precompute)

        # White-box check of the neighbor cache: its keys are the shell indices.
        computed_shells = sorted(lattice._neighbor_maps)
        expected_shells = list(range(1, k_precompute + 1))

        assert computed_shells == expected_shells, (
            f"TEST FAILED for {LatticeClass.__name__} with k={k_precompute}. "
            f"Expected shells {expected_shells}, but found {computed_shells}."
        )

        # Requesting one shell beyond the precomputed range must extend the cache.
        k_ondemand = k_precompute + 1

        _ = lattice.get_neighbors(0, k=k_ondemand)

        computed_shells_after = sorted(lattice._neighbor_maps)
        expected_shells_after = list(range(1, k_ondemand + 1))

        assert computed_shells_after == expected_shells_after, (
            f"ON-DEMAND TEST FAILED for {LatticeClass.__name__}. "
            f"Expected shells {expected_shells_after} after demanding k={k_ondemand}, "
            f"but found {computed_shells_after}."
        )
-
1079
-
1080
class TestLongRangeNeighborFinding:
    """
    Tests neighbor finding on larger lattices and for longer-range interactions (large k),
    addressing suggestions from code review.
    """

    @pytest.fixture(scope="class")
    def large_pbc_square_lattice(self) -> SquareLattice:
        """
        Provides a single 7x9 SquareLattice with PBC for all tests in this class.
        Using scope="class" makes it more efficient as it's created only once.

        NOTE: the instance is shared, so tests that call `_build_neighbors` on
        it (e.g. `test_requesting_k_beyond_max_possible_shell`) change the
        neighbor cache seen by later tests in this class.
        """
        # We choose a non-square size to catch potential bugs with non-uniform dimensions.
        return SquareLattice(size=(7, 9), pbc=True)

    def test_neighbor_shell_structure_on_large_lattice(self, large_pbc_square_lattice):
        """
        Tests the coordination number of various neighbor shells (k) on a large
        periodic lattice. In a PBC square lattice, every site is identical, so
        the number of neighbors for each shell k should be the same for all sites.

        Shell distances squared and their coordination numbers for a 2D square lattice:
        - k=1: dist_sq=1  (e.g., (1,0)) -> 4 neighbors
        - k=2: dist_sq=2  (e.g., (1,1)) -> 4 neighbors
        - k=3: dist_sq=4  (e.g., (2,0)) -> 4 neighbors
        - k=4: dist_sq=5  (e.g., (2,1)) -> 8 neighbors
        - k=5: dist_sq=8  (e.g., (2,2)) -> 4 neighbors
        - k=6: dist_sq=9  (e.g., (3,0)) -> 4 neighbors
        - k=7: dist_sq=10 (e.g., (3,1)) -> 8 neighbors

        Both lattice dimensions (7 and 9) exceed 2*3, so the +3 and -3 offsets
        still land on distinct sites and these counts are not reduced by
        wrap-around.
        """
        lattice = large_pbc_square_lattice
        # Pick an arbitrary site, e.g., index 0.
        site_idx = 0

        # Expected coordination numbers for the first few shells.
        expected_coordinations = {1: 4, 2: 4, 3: 4, 4: 8, 5: 4, 6: 4, 7: 8}

        for k, expected_count in expected_coordinations.items():
            neighbors = lattice.get_neighbors(site_idx, k=k)
            assert (
                len(neighbors) == expected_count
            ), f"Failed for k={k}. Expected {expected_count}, got {len(neighbors)}"

    def test_requesting_k_beyond_max_possible_shell(self, large_pbc_square_lattice):
        """
        Tests that requesting a neighbor shell 'k' that is larger than any
        possible shell in the finite lattice returns an empty list, and does
        not raise an error.
        """
        lattice = large_pbc_square_lattice
        site_idx = 0

        # 1. First, find out the maximum number of shells that *do* exist by
        #    building with a very large max_k. This is "white-box" testing
        #    (it reads the private cache) but is needed to find the true max k.
        lattice._build_neighbors(max_k=100)
        max_k_found = len(lattice._neighbor_maps)

        # 2. Assert that the last valid shell is not empty.
        last_shell_neighbors = lattice.get_neighbors(site_idx, k=max_k_found)
        assert len(last_shell_neighbors) > 0

        # 3. Core of the test: a shell just beyond the last valid one is empty.
        non_existent_shell_neighbors = lattice.get_neighbors(
            site_idx, k=max_k_found + 1
        )
        assert non_existent_shell_neighbors == []

    @patch("matplotlib.pyplot.subplots")
    def test_show_method_with_custom_bond_kwargs(
        self, mock_subplots, simple_square_lattice
    ):
        """
        Tests that .show() correctly uses the `bond_kwargs` parameter
        to customize the appearance of every neighbor bond.
        """
        # Arrange:
        # 1. Set up mock Figure and Axes objects, similar to other show() tests.
        mock_fig = matplotlib.figure.Figure()
        mock_ax = matplotlib.axes.Axes(mock_fig, [0.0, 0.0, 1.0, 1.0])
        mock_subplots.return_value = (mock_fig, mock_ax)

        # 2. Define our custom styles and the expected final styles.
        lattice = simple_square_lattice
        custom_bond_kwargs = {"color": "red", "linestyle": ":", "linewidth": 2}

        # The final dictionary should contain the defaults updated by our custom arguments.
        expected_plot_kwargs = {
            "color": "red",  # Overridden
            "linestyle": ":",  # Overridden
            "linewidth": 2,  # A new key
            "alpha": 0.6,  # From default
            "zorder": 1,  # From default
        }

        # 3. We specifically mock the `plot` method on our mock `ax` object.
        with patch.object(mock_ax, "plot") as mock_plot_method:
            # Act: Call the show method with our custom bond styles.
            lattice.show(show_bonds_k=1, bond_kwargs=custom_bond_kwargs)

            # Assert:
            # For a 2x2 square there are 4 NN bonds, one plot() call each.
            assert mock_plot_method.call_count == 4

            # Every bond must be drawn with the merged style. Each entry of
            # call_args_list is a call object whose `.kwargs` holds that
            # call's keyword arguments (call_args alone is only the LAST call).
            for plot_call in mock_plot_method.call_args_list:
                assert plot_call.kwargs == expected_plot_kwargs

    def test_mixed_boundary_conditions(self):
        """
        Tests neighbor finding with mixed PBC (periodic in x, open in y).
        This verifies that the neighbor finding logic correctly handles
        anisotropy in periodic boundary conditions and returns sorted indices.
        """
        # Arrange: Create a 3x3 square lattice, periodic in x, open in y.
        lattice = SquareLattice(size=(3, 3), pbc=(True, False))

        # We will test a site on the corner of the open boundary: (0, 0)
        corner_site_idx = lattice.get_index((0, 0, 0))

        # --- Test corner site (0, 0, 0) ---
        # Act
        corner_neighbors = lattice.get_neighbors(corner_site_idx, k=1)

        # Assert: The expected neighbors are (1,0,0), (2,0,0) [periodic], and (0,1,0).
        # We get their indices and sort them to create the expected output.
        expected_indices = sorted(
            [
                lattice.get_index((1, 0, 0)),  # Right neighbor
                lattice.get_index((2, 0, 0)),  # "Left" neighbor (wraps around)
                lattice.get_index((0, 1, 0)),  # "Up" neighbor
            ]
        )

        # The list returned by get_neighbors should be identical to our sorted list.
        assert (
            corner_neighbors == expected_indices
        ), "Failed for corner site with mixed BC."

        # --- Test middle site on the edge (1, 0, 0) ---
        edge_site_idx = lattice.get_index((1, 0, 0))

        # Act
        edge_neighbors = lattice.get_neighbors(edge_site_idx, k=1)

        # Assert
        expected_edge_indices = sorted(
            [
                lattice.get_index((0, 0, 0)),  # Left neighbor
                lattice.get_index((2, 0, 0)),  # Right neighbor
                lattice.get_index((1, 1, 0)),  # "Up" neighbor
            ]
        )
        assert (
            edge_neighbors == expected_edge_indices
        ), "Failed for edge site with mixed BC."
-
1242
-
1243
class TestAllTILattices:
    """
    A parameterized test class to verify the basic properties and coordination
    numbers for all implemented TILattice subclasses. This avoids code duplication.
    """

    # --- Test data in a structured and readable format ---
    # Format:
    # (
    #   LatticeClass,          # The lattice class to test
    #   {"size": ..., ...},    # Arguments for the constructor
    #   expected_num_sites,    # Expected total number of sites
    #   expected_num_basis,    # Expected number of sites in the basis
    #   {site_repr: count}     # Dict of {representative_site: neighbor_count}
    # )
    # For `site_repr`:
    # - For simple lattices (basis=1), it's the integer index of the site.
    # - For composite lattices (basis>1), it's the *basis index* to test.
    lattice_test_cases = [
        # 1D Lattices
        (ChainLattice, {"size": (5,), "pbc": True}, 5, 1, {0: 2, 2: 2}),
        (ChainLattice, {"size": (5,), "pbc": False}, 5, 1, {0: 1, 2: 2}),
        (DimerizedChainLattice, {"size": (3,), "pbc": True}, 6, 2, {0: 2, 1: 2}),
        # 2D Lattices
        (
            RectangularLattice,
            {"size": (3, 4), "pbc": False},
            12,
            1,
            {5: 4, 0: 2, 4: 3},
        ),  # center, corner, edge
        (HoneycombLattice, {"size": (2, 2), "pbc": True}, 8, 2, {0: 3, 1: 3}),
        (TriangularLattice, {"size": (3, 3), "pbc": True}, 9, 1, {0: 6}),
        (CheckerboardLattice, {"size": (2, 2), "pbc": True}, 8, 2, {0: 4, 1: 4}),
        (KagomeLattice, {"size": (2, 2), "pbc": True}, 12, 3, {0: 4, 1: 4, 2: 4}),
        (LiebLattice, {"size": (2, 2), "pbc": True}, 12, 3, {0: 4, 1: 2, 2: 2}),
        # 3D Lattices
        (CubicLattice, {"size": (3, 3, 3), "pbc": True}, 27, 1, {0: 6, 13: 6}),
    ]

    @pytest.mark.parametrize(
        "LatticeClass, init_args, num_sites, num_basis, coordination_numbers",
        lattice_test_cases,
    )
    def test_lattice_properties_and_coordination(
        self,
        LatticeClass,
        init_args,
        num_sites,
        num_basis,
        coordination_numbers,
    ):
        """
        A single, parameterized test to validate all TILattice types.

        Checks the site count, basis size and dimensionality, then the
        nearest-neighbor (k=1) coordination number of every representative
        site listed in `coordination_numbers`.
        """
        # --- Arrange ---
        # Create the lattice instance dynamically from the test data.
        lattice = LatticeClass(**init_args)

        # --- Assert: Basic properties ---
        assert lattice.num_sites == num_sites
        assert lattice.num_basis == num_basis
        assert lattice.dimensionality == len(init_args["size"])

        # `LatticeClass` is a class object, not an instance, so this must be
        # issubclass(); isinstance(LatticeClass, ChainLattice) is always False
        # and would silently skip the open-chain end-site check below.
        is_open_chain = issubclass(LatticeClass, ChainLattice) and not init_args.get(
            "pbc"
        )

        # --- Assert: Coordination numbers (nearest neighbors, k=1) ---
        for site_repr, expected_count in coordination_numbers.items():
            if lattice.num_basis > 1:
                # For composite lattices, site_repr is the basis_index.
                # We find the index of this basis site in the first unit cell.
                uc_coord = (0,) * lattice.dimensionality
                test_site_idx = lattice.get_index(uc_coord + (site_repr,))
            else:
                # For simple lattices, site_repr is the absolute site index.
                test_site_idx = site_repr

            neighbors = lattice.get_neighbors(test_site_idx, k=1)
            assert len(neighbors) == expected_count
            # In an open chain, the left end (site 0) must be bonded to site 1.
            if is_open_chain and test_site_idx == 0:
                assert 1 in neighbors
-
1326
-
1327
class TestCustomizeLatticeDynamic:
    """Tests the dynamic modification capabilities of CustomizeLattice."""

    @pytest.fixture
    def initial_lattice(self) -> CustomizeLattice:
        """Provides a basic 3-site lattice (A=(0,0), B=(1,0), C=(0,1)) for modification tests."""
        return CustomizeLattice(
            dimensionality=2,
            identifiers=["A", "B", "C"],
            coordinates=[[0, 0], [1, 0], [0, 1]],
        )

    def test_from_lattice_conversion(self):
        """Tests creating a CustomizeLattice from a TILattice."""
        # Arrange
        sq_lattice = SquareLattice(size=(2, 2), pbc=False)

        # Act
        custom_lattice = CustomizeLattice.from_lattice(sq_lattice)

        # Assert
        assert isinstance(custom_lattice, CustomizeLattice)
        assert custom_lattice.num_sites == sq_lattice.num_sites
        assert custom_lattice.dimensionality == sq_lattice.dimensionality
        # Spot-check one site's coordinates and identifier.
        np.testing.assert_array_equal(
            custom_lattice.get_coordinates(3), sq_lattice.get_coordinates(3)
        )
        assert custom_lattice.get_identifier(3) == sq_lattice.get_identifier(3)

    def test_add_sites_successfully(self, initial_lattice):
        """Tests adding new, valid sites to the lattice."""
        # Arrange
        lat = initial_lattice
        assert lat.num_sites == 3

        # Act
        lat.add_sites(identifiers=["D", "E"], coordinates=[[1, 1], [2, 2]])

        # Assert: new sites are appended after the existing ones.
        assert lat.num_sites == 5
        assert lat.get_identifier(4) == "E"
        np.testing.assert_array_equal(lat.get_coordinates(3), np.array([1, 1]))
        assert "E" in lat._ident_to_idx

    def test_remove_sites_successfully(self, initial_lattice):
        """Tests removing existing sites from the lattice."""
        # Arrange
        lat = initial_lattice
        assert lat.num_sites == 3

        # Act
        lat.remove_sites(identifiers=["A", "C"])

        # Assert: remaining sites are re-indexed from 0.
        assert lat.num_sites == 1
        assert lat.get_identifier(0) == "B"  # Site 'B' is now at index 0
        assert "A" not in lat._ident_to_idx
        np.testing.assert_array_equal(lat.get_coordinates(0), np.array([1, 0]))

    def test_add_duplicate_identifier_raises_error(self, initial_lattice):
        """Tests that adding a site with an existing identifier fails."""
        with pytest.raises(ValueError, match="Duplicate identifiers found"):
            initial_lattice.add_sites(identifiers=["A"], coordinates=[[9, 9]])

    def test_remove_nonexistent_identifier_raises_error(self, initial_lattice):
        """Tests that removing a non-existent site fails."""
        with pytest.raises(ValueError, match="Non-existent identifiers provided"):
            initial_lattice.remove_sites(identifiers=["Z"])

    def test_modification_clears_neighbor_cache(self, initial_lattice):
        """
        Tests that add_sites and remove_sites correctly invalidate the
        pre-computed neighbor map.
        """
        # Arrange: Pre-compute neighbors on the initial lattice
        initial_lattice._build_neighbors(max_k=1)
        assert 0 in initial_lattice._neighbor_maps[1]  # Check that neighbors exist

        # Act 1: Add a site
        initial_lattice.add_sites(identifiers=["D"], coordinates=[[5, 5]])

        # Assert 1: The neighbor map should now be empty
        assert not initial_lattice._neighbor_maps

        # Arrange 2: Re-compute neighbors and then remove a site
        initial_lattice._build_neighbors(max_k=1)
        assert 0 in initial_lattice._neighbor_maps[1]

        # Act 2: Remove a site
        initial_lattice.remove_sites(identifiers=["A"])

        # Assert 2: The neighbor map should be empty again
        assert not initial_lattice._neighbor_maps

    def test_modification_clears_distance_matrix_cache(self, initial_lattice):
        """
        Tests that add_sites and remove_sites correctly invalidate the
        cached distance matrix and that the recomputed matrix is correct.
        """
        # Arrange 1: Compute, cache, and perform a meaningful check on the original matrix.
        lat = initial_lattice
        original_matrix = lat.distance_matrix
        assert lat._distance_matrix is not None
        assert original_matrix.shape == (3, 3)
        # Distance from 'A'(idx 0) to 'B'(idx 1) should be 1.0
        np.testing.assert_allclose(original_matrix[0, 1], 1.0)

        # Act 1: Add a site. This should invalidate the cache.
        lat.add_sites(identifiers=["D"], coordinates=[[1, 1]])

        # Assert 1: Check cache is cleared and the new matrix is correct.
        assert lat._distance_matrix is None  # Verify cache invalidation
        new_matrix_added = lat.distance_matrix
        assert new_matrix_added.shape == (4, 4)
        # Distance from 'B'(idx 1, [1,0]) to new site 'D'(idx 3, [1,1]) should be 1.0
        np.testing.assert_allclose(new_matrix_added[1, 3], 1.0)

        # Act 2: Remove a site. This should also invalidate the cache.
        lat.remove_sites(identifiers=["A"])

        # Assert 2: Check cache is cleared again and the final matrix is correct.
        assert lat._distance_matrix is None  # Verify cache invalidation
        final_matrix = lat.distance_matrix
        assert final_matrix.shape == (3, 3)  # Now has 3 sites again
        # After removing 'A', the sites are re-indexed as
        # B -> 0 ([1,0]), C -> 1 ([0,1]), D -> 2 ([1,1]).
        # Distance from new 'B' (idx 0) to new 'D' (idx 2) should be 1.0
        np.testing.assert_allclose(final_matrix[0, 2], 1.0)

    def test_neighbor_finding_returns_sorted_list(self, simple_square_lattice):
        """
        Ensures that the list of neighbors returned by get_neighbors is always sorted.
        This provides a stricter check than set-based comparisons.
        """
        # Arrange
        lattice = simple_square_lattice

        # Act
        # Get the nearest neighbors of the site at index 1 of the 2x2 grid
        # (every site of a 2x2 grid is a corner; there is no central site).
        # Expected neighbors are 0 and 3.
        neighbors = lattice.get_neighbors(1, k=1)

        # Assert
        # We compare directly against a pre-sorted list, not a set.
        # This will fail if the implementation returns [3, 0] instead of [0, 3].
        assert neighbors == [
            0,
            3,
        ], "The neighbor list should be sorted in ascending order."
-
1481
-
1482
- class TestDistanceMatrix:
1483
-
1484
- # This is the upgraded, parameterized test.
1485
@pytest.mark.parametrize(
    "build_k, check_site_identifier, expected_dist_sq",
    [
        # NN shell pre-built (k=1), probing a NNN (k=2) site; a cache that
        # only knows the built shells would get this wrong.
        (1, (1, 1, 0), 2.0),
        # Shells up to k=2 pre-built, probing a k=3 site.
        (2, (2, 0, 0), 4.0),
        # Shells up to k=3 pre-built, probing a k=4 site.
        (3, (2, 1, 0), 5.0),
        # Shells up to k=5 pre-built, probing a higher-order neighbor.
        (5, (3, 1, 0), 10.0),
    ],
)
def test_tilattice_full_pbc_distance_matrix_is_correct_regardless_of_build_k(
    self, build_k, check_site_identifier, expected_dist_sq
):
    """
    Checks that the distance matrix of a fully periodic TILattice is complete
    and correct regardless of how many neighbor shells were built beforehand.

    Each parametrized case pre-builds ``build_k`` shells and then probes a
    site lying in a shell beyond that, which catches caching bugs where the
    stored matrix only holds partial information.

    Parameters:
        build_k: number of neighbor shells to pre-build.
        check_site_identifier: identifier of the site whose distance from the
            origin (0, 0, 0) is checked.
        expected_dist_sq: expected squared distance to that site.
    """
    # A larger, non-square lattice avoids accidental symmetries.
    lat = SquareLattice(size=(7, 9), pbc=True)

    # Pre-building shells is where a faulty caching path could be triggered.
    lat._build_neighbors(max_k=build_k)

    # A correct implementation returns a fully valid matrix here.
    distances = lat.distance_matrix

    source = lat.get_index((0, 0, 0))
    target = lat.get_index(check_site_identifier)
    measured_sq = distances[source, target] ** 2

    np.testing.assert_allclose(
        measured_sq,
        expected_dist_sq,
        err_msg=(
            f"Distance matrix failed when building k={build_k}. "
            f"Checking distance to site {check_site_identifier} (expected sq={expected_dist_sq}) "
            f"but got sq={measured_sq} instead."
        ),
    )
1543
-
1544
def test_tilattice_mixed_bc_distance_matrix_is_correct(self):
    """
    Verifies distance-matrix entries for a 5x5 SquareLattice that is periodic
    along x and open along y (pbc=(True, False)). It covers a wrapped
    x-distance, an unwrapped y-distance, and an off-axis point mixing both.
    """
    lat = SquareLattice(size=(5, 5), pbc=(True, False))

    # Build a couple of shells first so the caching logic is engaged.
    lat._build_neighbors(max_k=2)
    dist_matrix = lat.distance_matrix

    origin_idx = lat.get_index((0, 0, 0))

    # (target identifier, expected distance from the origin)
    cases = (
        # Periodic x: (4, 0) is a single step away across the boundary.
        ((4, 0, 0), 1.0),
        # Open y: (0, 4) is a full 4 steps away, with no wrap-around.
        ((0, 4, 0), 4.0),
        # x offset is min(3, 5 - 3) = 2 and y offset is 3, giving sqrt(2^2 + 3^2).
        ((3, 3, 0), np.sqrt(13)),
    )
    for identifier, expected in cases:
        np.testing.assert_allclose(
            dist_matrix[origin_idx, lat.get_index(identifier)], expected
        )
1577
-
1578
- # --- This list and the following test are now at the correct indentation level ---
1579
- lattice_instances_for_invariant_test = [
1580
- SquareLattice(size=(4, 4), pbc=True),
1581
- SquareLattice(size=(4, 3), pbc=(True, False)), # Mixed BC, non-square
1582
- HoneycombLattice(size=(3, 3), pbc=True),
1583
- TriangularLattice(size=(4, 4), pbc=False),
1584
- CustomizeLattice(
1585
- dimensionality=2,
1586
- identifiers=list(range(4)),
1587
- coordinates=[[0, 0], [1, 1], [0, 1], [1, 0]],
1588
- ),
1589
- ]
1590
-
1591
- @pytest.mark.parametrize("lattice", lattice_instances_for_invariant_test)
1592
- def test_distance_matrix_invariants_for_all_lattice_types(self, lattice):
1593
- """
1594
- Tests that the distance matrix for any lattice type adheres to
1595
- fundamental mathematical properties (invariants): symmetry, zero diagonal,
1596
- and positive off-diagonal elements.
1597
- """
1598
- # Arrange
1599
- n = lattice.num_sites
1600
- if n < 2:
1601
- pytest.skip("Invariant test requires at least 2 sites.")
1602
-
1603
- # Act
1604
- # We call the property directly, without building neighbors first,
1605
- # to test the on-demand computation path.
1606
- matrix = lattice.distance_matrix
1607
-
1608
- # Assert
1609
- # 1. Symmetry: The matrix must be equal to its transpose.
1610
- np.testing.assert_allclose(
1611
- matrix,
1612
- matrix.T,
1613
- err_msg=f"Distance matrix for {type(lattice).__name__} is not symmetric.",
1614
- )
1615
-
1616
- # 2. Zero Diagonal: All diagonal elements must be zero.
1617
- np.testing.assert_allclose(
1618
- np.diag(matrix),
1619
- np.zeros(n),
1620
- err_msg=f"Diagonal of distance matrix for {type(lattice).__name__} is not zero.",
1621
- )
1622
-
1623
- # 3. Positive Off-diagonal: All non-diagonal elements must be > 0.
1624
- # We create a boolean mask for the off-diagonal elements.
1625
- off_diagonal_mask = ~np.eye(n, dtype=bool)
1626
- assert np.all(
1627
- matrix[off_diagonal_mask] > 1e-9
1628
- ), f"Found non-positive off-diagonal elements in distance matrix for {type(lattice).__name__}."
1629
-
1630
-
1631
- # @pytest.mark.slow
1632
- # class TestPerformance:
1633
- # def test_pbc_implementation_is_not_significantly_slower_than_obc(self):
1634
- # """
1635
- # A performance regression test.
1636
- # It ensures that the specialized implementation for fully periodic
1637
- # lattices (pbc=True) is not substantially slower than the general
1638
- # implementation used for open boundaries (pbc=False).
1639
- # This test will FAIL with the current code, exposing the performance bug.
1640
- # """
1641
- # # Arrange: Use a large-enough lattice to make performance differences apparent
1642
- # size = (30, 30)
1643
- # k = 1
1644
-
1645
- # # Act 1: Measure the execution time of the general (OBC) implementation
1646
- # start_time_obc = time.time()
1647
- # _ = SquareLattice(size=size, pbc=False, precompute_neighbors=k)
1648
- # duration_obc = time.time() - start_time_obc
1649
-
1650
- # # Act 2: Measure the execution time of the specialized (PBC) implementation
1651
- # start_time_pbc = time.time()
1652
- # _ = SquareLattice(size=size, pbc=True, precompute_neighbors=k)
1653
- # duration_pbc = time.time() - start_time_pbc
1654
-
1655
- # print(
1656
- # f"\n[Performance] OBC ({size}): {duration_obc:.4f}s | PBC ({size}): {duration_pbc:.4f}s"
1657
- # )
1658
-
1659
- # # Assert: The PBC implementation should not be drastically slower.
1660
- # # We allow it to be up to 3 times slower to account for minor overheads,
1661
- # # but this will catch the current 10x+ regression.
1662
- # # THIS ASSERTION WILL FAIL with the current buggy code.
1663
- # assert duration_pbc < duration_obc * 5, (
1664
- # "The specialized PBC implementation is significantly slower "
1665
- # "than the general-purpose implementation."
1666
- # )