tensorcircuit-nightly 1.3.0.dev20250728__py3-none-any.whl → 1.4.0.dev20251103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tensorcircuit-nightly might be problematic. Click here for more details.
- tensorcircuit/__init__.py +5 -1
- tensorcircuit/abstractcircuit.py +4 -0
- tensorcircuit/analogcircuit.py +413 -0
- tensorcircuit/applications/layers.py +1 -1
- tensorcircuit/applications/van.py +1 -1
- tensorcircuit/backends/abstract_backend.py +312 -5
- tensorcircuit/backends/cupy_backend.py +3 -1
- tensorcircuit/backends/jax_backend.py +92 -3
- tensorcircuit/backends/jax_ops.py +108 -0
- tensorcircuit/backends/numpy_backend.py +49 -3
- tensorcircuit/backends/pytorch_backend.py +92 -3
- tensorcircuit/backends/tensorflow_backend.py +102 -3
- tensorcircuit/basecircuit.py +123 -82
- tensorcircuit/circuit.py +67 -57
- tensorcircuit/cloud/local.py +1 -1
- tensorcircuit/cloud/quafu_provider.py +1 -1
- tensorcircuit/cloud/tencent.py +1 -1
- tensorcircuit/compiler/simple_compiler.py +2 -2
- tensorcircuit/cons.py +1 -0
- tensorcircuit/densitymatrix.py +16 -11
- tensorcircuit/experimental.py +7 -152
- tensorcircuit/fgs.py +5 -6
- tensorcircuit/gates.py +66 -22
- tensorcircuit/keras.py +3 -3
- tensorcircuit/mpscircuit.py +109 -61
- tensorcircuit/quantum.py +697 -133
- tensorcircuit/quditcircuit.py +733 -0
- tensorcircuit/quditgates.py +618 -0
- tensorcircuit/results/counts.py +45 -31
- tensorcircuit/shadows.py +1 -1
- tensorcircuit/simplify.py +3 -1
- tensorcircuit/stabilizercircuit.py +4 -2
- tensorcircuit/templates/blocks.py +2 -2
- tensorcircuit/templates/hamiltonians.py +29 -8
- tensorcircuit/templates/lattice.py +676 -335
- tensorcircuit/timeevol.py +896 -0
- {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +50 -25
- tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
- {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
- tensorcircuit_nightly-1.3.0.dev20250728.dist-info/RECORD +0 -122
- tests/__init__.py +0 -0
- tests/conftest.py +0 -67
- tests/test_backends.py +0 -1035
- tests/test_calibrating.py +0 -149
- tests/test_channels.py +0 -409
- tests/test_circuit.py +0 -1713
- tests/test_cloud.py +0 -219
- tests/test_compiler.py +0 -147
- tests/test_dmcircuit.py +0 -555
- tests/test_ensemble.py +0 -72
- tests/test_fgs.py +0 -318
- tests/test_gates.py +0 -156
- tests/test_hamiltonians.py +0 -159
- tests/test_interfaces.py +0 -557
- tests/test_keras.py +0 -160
- tests/test_lattice.py +0 -1666
- tests/test_miscs.py +0 -334
- tests/test_mpscircuit.py +0 -341
- tests/test_noisemodel.py +0 -156
- tests/test_qaoa.py +0 -86
- tests/test_qem.py +0 -152
- tests/test_quantum.py +0 -549
- tests/test_quantum_attr.py +0 -42
- tests/test_results.py +0 -379
- tests/test_shadows.py +0 -160
- tests/test_simplify.py +0 -46
- tests/test_stabilizer.py +0 -226
- tests/test_templates.py +0 -218
- tests/test_torchnn.py +0 -99
- tests/test_van.py +0 -102
- {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +0 -0
- {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/licenses/LICENSE +0 -0
tests/test_lattice.py
DELETED
|
@@ -1,1666 +0,0 @@
|
|
|
1
|
-
from unittest.mock import patch
|
|
2
|
-
import logging
|
|
3
|
-
|
|
4
|
-
# import time
|
|
5
|
-
|
|
6
|
-
import matplotlib
|
|
7
|
-
|
|
8
|
-
matplotlib.use("Agg")
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
import pytest
|
|
12
|
-
import numpy as np
|
|
13
|
-
|
|
14
|
-
from tensorcircuit.templates.lattice import (
|
|
15
|
-
ChainLattice,
|
|
16
|
-
CheckerboardLattice,
|
|
17
|
-
CubicLattice,
|
|
18
|
-
CustomizeLattice,
|
|
19
|
-
DimerizedChainLattice,
|
|
20
|
-
HoneycombLattice,
|
|
21
|
-
KagomeLattice,
|
|
22
|
-
LiebLattice,
|
|
23
|
-
RectangularLattice,
|
|
24
|
-
SquareLattice,
|
|
25
|
-
TriangularLattice,
|
|
26
|
-
)
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
@pytest.fixture
|
|
30
|
-
def simple_square_lattice() -> CustomizeLattice:
|
|
31
|
-
"""
|
|
32
|
-
Provides a simple 2x2 square CustomizeLattice instance for neighbor tests.
|
|
33
|
-
The sites are indexed as follows:
|
|
34
|
-
2--3
|
|
35
|
-
| |
|
|
36
|
-
0--1
|
|
37
|
-
"""
|
|
38
|
-
coords = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
|
|
39
|
-
ids = list(range(len(coords)))
|
|
40
|
-
lattice = CustomizeLattice(dimensionality=2, identifiers=ids, coordinates=coords)
|
|
41
|
-
# Pre-calculate neighbors up to the 2nd shell for use in tests.
|
|
42
|
-
lattice._build_neighbors(max_k=2)
|
|
43
|
-
return lattice
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
@pytest.fixture
|
|
47
|
-
def kagome_lattice_fragment() -> CustomizeLattice:
|
|
48
|
-
"""
|
|
49
|
-
Pytest fixture to provide a standard CustomizeLattice instance.
|
|
50
|
-
This represents the Kagome fragment from the project requirements,
|
|
51
|
-
making it a reusable object for multiple tests.
|
|
52
|
-
"""
|
|
53
|
-
kag_coords = [
|
|
54
|
-
[0.0, 0.0],
|
|
55
|
-
[1.0, 0.0],
|
|
56
|
-
[0.5, np.sqrt(3) / 2], # Triangle 1
|
|
57
|
-
[2, 0],
|
|
58
|
-
[1.5, np.sqrt(3) / 2], # Triangle 2 (shifted basis)
|
|
59
|
-
[1.0, np.sqrt(3)], # Top site
|
|
60
|
-
]
|
|
61
|
-
kag_ids = list(range(len(kag_coords)))
|
|
62
|
-
return CustomizeLattice(
|
|
63
|
-
dimensionality=2, identifiers=kag_ids, coordinates=kag_coords
|
|
64
|
-
)
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
class TestCustomizeLattice:
|
|
68
|
-
"""
|
|
69
|
-
A test class to group all tests related to the CustomizeLattice.
|
|
70
|
-
This helps in organizing the test suite.
|
|
71
|
-
"""
|
|
72
|
-
|
|
73
|
-
def test_initialization_and_properties(self, kagome_lattice_fragment):
|
|
74
|
-
"""
|
|
75
|
-
Test case for successful initialization and verification of basic properties.
|
|
76
|
-
This test function receives the 'kagome_lattice_fragment' fixture as an argument.
|
|
77
|
-
"""
|
|
78
|
-
# Arrange: The fixture has already prepared the 'lattice' object for us.
|
|
79
|
-
lattice = kagome_lattice_fragment
|
|
80
|
-
|
|
81
|
-
# Assert: Check if the object's properties match our expectations.
|
|
82
|
-
assert lattice.dimensionality == 2
|
|
83
|
-
assert lattice.num_sites == 6
|
|
84
|
-
assert len(lattice) == 6 # This also tests the __len__ dunder method
|
|
85
|
-
|
|
86
|
-
# Verify that coordinates are correctly stored as numpy arrays.
|
|
87
|
-
# It's important to use np.testing.assert_array_equal for numpy array comparison.
|
|
88
|
-
expected_coord = np.array([0.5, np.sqrt(3) / 2])
|
|
89
|
-
np.testing.assert_array_equal(lattice.get_coordinates(2), expected_coord)
|
|
90
|
-
|
|
91
|
-
# Verify that the mapping between identifiers and indices is correct.
|
|
92
|
-
assert lattice.get_identifier(4) == 4
|
|
93
|
-
assert lattice.get_index(4) == 4
|
|
94
|
-
|
|
95
|
-
def test_input_validation_mismatched_lengths(self):
|
|
96
|
-
"""
|
|
97
|
-
Tests that a ValueError is raised if identifiers and coordinates
|
|
98
|
-
lists have mismatched lengths.
|
|
99
|
-
"""
|
|
100
|
-
# Arrange: Prepare invalid inputs.
|
|
101
|
-
coords = [[0.0, 0.0], [1.0, 0.0]] # 2 coordinates
|
|
102
|
-
ids = [0, 1, 2] # 3 identifiers
|
|
103
|
-
|
|
104
|
-
# Act & Assert: Use pytest.raises as a context manager to ensure
|
|
105
|
-
# the specified exception is raised within the 'with' block.
|
|
106
|
-
with pytest.raises(
|
|
107
|
-
ValueError,
|
|
108
|
-
match="Identifiers and coordinates lists must have the same length.",
|
|
109
|
-
):
|
|
110
|
-
CustomizeLattice(dimensionality=2, identifiers=ids, coordinates=coords)
|
|
111
|
-
|
|
112
|
-
def test_input_validation_wrong_dimension(self):
|
|
113
|
-
"""
|
|
114
|
-
Tests that a ValueError is raised if a coordinate's dimension
|
|
115
|
-
does not match the lattice's specified dimensionality.
|
|
116
|
-
"""
|
|
117
|
-
# Arrange: Prepare coordinates with mixed dimensions for a 2D lattice.
|
|
118
|
-
coords_wrong_dim = [[0.0, 0.0], [1.0, 0.0, 0.0]] # A mix of 2D and 3D
|
|
119
|
-
ids_ok = [0, 1]
|
|
120
|
-
|
|
121
|
-
# Act & Assert: Check for the specific error message. The 'r' before the string
|
|
122
|
-
# indicates a raw string, which is good practice for regex patterns.
|
|
123
|
-
with pytest.raises(
|
|
124
|
-
ValueError, match=r"Coordinate at index 1 has shape \(3,\), expected \(2,\)"
|
|
125
|
-
):
|
|
126
|
-
CustomizeLattice(
|
|
127
|
-
dimensionality=2, identifiers=ids_ok, coordinates=coords_wrong_dim
|
|
128
|
-
)
|
|
129
|
-
|
|
130
|
-
def test_neighbor_finding(self, simple_square_lattice):
|
|
131
|
-
"""
|
|
132
|
-
Tests the k-th nearest neighbor finding functionality (_build_neighbors
|
|
133
|
-
and get_neighbors).
|
|
134
|
-
"""
|
|
135
|
-
# Arrange: The fixture provides the lattice with pre-built neighbors.
|
|
136
|
-
lattice = simple_square_lattice
|
|
137
|
-
|
|
138
|
-
# --- Assertions for k=1 (Nearest Neighbors) ---
|
|
139
|
-
# We use set() for comparison to ignore the order of neighbors.
|
|
140
|
-
assert set(lattice.get_neighbors(0, k=1)) == {1, 2}
|
|
141
|
-
assert set(lattice.get_neighbors(1, k=1)) == {0, 3}
|
|
142
|
-
assert set(lattice.get_neighbors(2, k=1)) == {0, 3}
|
|
143
|
-
assert set(lattice.get_neighbors(3, k=1)) == {1, 2}
|
|
144
|
-
|
|
145
|
-
# --- Assertions for k=2 (Next-Nearest Neighbors) ---
|
|
146
|
-
# These should be the diagonal sites.
|
|
147
|
-
assert set(lattice.get_neighbors(0, k=2)) == {3}
|
|
148
|
-
assert set(lattice.get_neighbors(1, k=2)) == {2}
|
|
149
|
-
|
|
150
|
-
def test_neighbor_pairs(self, simple_square_lattice):
|
|
151
|
-
"""
|
|
152
|
-
Tests the retrieval of unique neighbor pairs (bonds) using
|
|
153
|
-
get_neighbor_pairs.
|
|
154
|
-
"""
|
|
155
|
-
# Arrange: Use the same fixture.
|
|
156
|
-
lattice = simple_square_lattice
|
|
157
|
-
|
|
158
|
-
# --- Test for k=1 (Nearest Neighbor bonds) ---
|
|
159
|
-
# Act: Get unique nearest neighbor pairs.
|
|
160
|
-
nn_pairs = lattice.get_neighbor_pairs(k=1, unique=True)
|
|
161
|
-
|
|
162
|
-
# Assert: The set of pairs should match the expected bonds.
|
|
163
|
-
# We convert the list of pairs to a set of tuples for order-independent comparison.
|
|
164
|
-
expected_nn_pairs = {(0, 1), (0, 2), (1, 3), (2, 3)}
|
|
165
|
-
assert set(map(tuple, nn_pairs)) == expected_nn_pairs
|
|
166
|
-
|
|
167
|
-
# --- Test for k=2 (Next-Nearest Neighbor bonds) ---
|
|
168
|
-
# Act: Get unique next-nearest neighbor pairs.
|
|
169
|
-
nnn_pairs = lattice.get_neighbor_pairs(k=2, unique=True)
|
|
170
|
-
|
|
171
|
-
# Assert:
|
|
172
|
-
expected_nnn_pairs = {(0, 3), (1, 2)}
|
|
173
|
-
assert set(map(tuple, nnn_pairs)) == expected_nnn_pairs
|
|
174
|
-
|
|
175
|
-
def test_neighbor_pairs_non_unique(self, simple_square_lattice):
|
|
176
|
-
"""
|
|
177
|
-
Tests get_neighbor_pairs with unique=False to ensure all
|
|
178
|
-
directed pairs (bonds) are returned.
|
|
179
|
-
"""
|
|
180
|
-
# Arrange: Use the same 2x2 square lattice fixture.
|
|
181
|
-
# 2--3
|
|
182
|
-
# | |
|
|
183
|
-
# 0--1
|
|
184
|
-
lattice = simple_square_lattice
|
|
185
|
-
|
|
186
|
-
# Act: Get NON-unique nearest neighbor pairs.
|
|
187
|
-
nn_pairs = lattice.get_neighbor_pairs(k=1, unique=False)
|
|
188
|
-
|
|
189
|
-
# Assert:
|
|
190
|
-
# There are 4 bonds, so we expect 4 * 2 = 8 directed pairs.
|
|
191
|
-
assert len(nn_pairs) == 8
|
|
192
|
-
|
|
193
|
-
# Your source code sorts the output, so we can compare against a
|
|
194
|
-
# sorted list for a precise match.
|
|
195
|
-
expected_pairs = sorted(
|
|
196
|
-
[(0, 1), (1, 0), (0, 2), (2, 0), (1, 3), (3, 1), (2, 3), (3, 2)]
|
|
197
|
-
)
|
|
198
|
-
|
|
199
|
-
assert nn_pairs == expected_pairs
|
|
200
|
-
|
|
201
|
-
@patch("matplotlib.pyplot.show")
|
|
202
|
-
def test_show_method_runs_and_calls_plt_show(
|
|
203
|
-
self, mock_show, simple_square_lattice
|
|
204
|
-
):
|
|
205
|
-
"""
|
|
206
|
-
Smoke test for the .show() method.
|
|
207
|
-
It verifies that the method runs without raising an exception and that it
|
|
208
|
-
triggers a call to matplotlib's show() function.
|
|
209
|
-
We use @patch to "mock" the show function, preventing a plot window
|
|
210
|
-
from actually appearing during tests.
|
|
211
|
-
"""
|
|
212
|
-
# Arrange: Get the lattice instance from the fixture
|
|
213
|
-
lattice = simple_square_lattice
|
|
214
|
-
|
|
215
|
-
# Act: Call the .show() method.
|
|
216
|
-
# We wrap it in a try...except block to give a more specific error
|
|
217
|
-
# if the method fails for any reason.
|
|
218
|
-
try:
|
|
219
|
-
lattice.show()
|
|
220
|
-
except Exception as e:
|
|
221
|
-
pytest.fail(f".show() method raised an unexpected exception: {e}")
|
|
222
|
-
|
|
223
|
-
# Assert: Check that our mocked matplotlib.pyplot.show was called exactly once.
|
|
224
|
-
mock_show.assert_called_once()
|
|
225
|
-
|
|
226
|
-
def test_sites_iterator(self, simple_square_lattice):
|
|
227
|
-
"""
|
|
228
|
-
Tests the sites() iterator to ensure it yields all sites correctly.
|
|
229
|
-
"""
|
|
230
|
-
# Arrange
|
|
231
|
-
lattice = simple_square_lattice
|
|
232
|
-
expected_num_sites = 4
|
|
233
|
-
|
|
234
|
-
# Act
|
|
235
|
-
# The sites() method returns an iterator, we convert it to a list to check its length.
|
|
236
|
-
all_sites = list(lattice.sites())
|
|
237
|
-
|
|
238
|
-
# Assert
|
|
239
|
-
assert len(all_sites) == expected_num_sites
|
|
240
|
-
|
|
241
|
-
# For a more thorough check, verify the content of one of the yielded tuples.
|
|
242
|
-
# For the simple_square_lattice fixture, site 3 has identifier 3 and coords [1, 1].
|
|
243
|
-
idx, ident, coords = all_sites[3]
|
|
244
|
-
assert idx == 3
|
|
245
|
-
assert ident == 3
|
|
246
|
-
np.testing.assert_array_equal(coords, np.array([1, 1]))
|
|
247
|
-
|
|
248
|
-
def test_get_site_info_with_identifier(self, simple_square_lattice):
|
|
249
|
-
"""
|
|
250
|
-
Tests the get_site_info() method using a site identifier instead of an index.
|
|
251
|
-
This covers the 'else' branch of the type check in the method.
|
|
252
|
-
"""
|
|
253
|
-
# Arrange
|
|
254
|
-
lattice = simple_square_lattice
|
|
255
|
-
# In this fixture, the identifier for the site at index 2 is also the integer 2.
|
|
256
|
-
identifier_to_test = 2
|
|
257
|
-
expected_index = 2
|
|
258
|
-
expected_coords = np.array([0, 1])
|
|
259
|
-
|
|
260
|
-
# Act
|
|
261
|
-
idx, ident, coords = lattice.get_site_info(identifier_to_test)
|
|
262
|
-
|
|
263
|
-
# Assert
|
|
264
|
-
assert idx == expected_index
|
|
265
|
-
assert ident == identifier_to_test
|
|
266
|
-
np.testing.assert_array_equal(coords, expected_coords)
|
|
267
|
-
|
|
268
|
-
@patch("matplotlib.pyplot.show")
|
|
269
|
-
def test_show_method_with_labels(self, mock_show, simple_square_lattice):
|
|
270
|
-
"""
|
|
271
|
-
Tests that the .show() method runs without error when label-related
|
|
272
|
-
options are enabled. This covers the logic inside the
|
|
273
|
-
'if show_indices or show_identifiers:' block.
|
|
274
|
-
"""
|
|
275
|
-
# Arrange
|
|
276
|
-
lattice = simple_square_lattice
|
|
277
|
-
|
|
278
|
-
# Act & Assert
|
|
279
|
-
try:
|
|
280
|
-
# Call .show() with options to display indices and identifiers.
|
|
281
|
-
lattice.show(show_indices=True, show_identifiers=True)
|
|
282
|
-
except Exception as e:
|
|
283
|
-
pytest.fail(
|
|
284
|
-
f".show() with label options raised an unexpected exception: {e}"
|
|
285
|
-
)
|
|
286
|
-
|
|
287
|
-
# Ensure the plotting function is still called.
|
|
288
|
-
mock_show.assert_called_once()
|
|
289
|
-
|
|
290
|
-
def test_get_neighbors_logs_info_for_uncached_k(
|
|
291
|
-
self, simple_square_lattice, caplog
|
|
292
|
-
):
|
|
293
|
-
"""
|
|
294
|
-
Tests that an INFO message is logged when get_neighbors is called for a 'k'
|
|
295
|
-
that has not been pre-calculated, triggering on-demand computation.
|
|
296
|
-
"""
|
|
297
|
-
# Arrange
|
|
298
|
-
lattice = simple_square_lattice # This fixture builds neighbors up to k=2
|
|
299
|
-
k_to_test = 99 # A value that is clearly not cached
|
|
300
|
-
caplog.set_level(logging.INFO) # Ensure INFO logs are captured
|
|
301
|
-
|
|
302
|
-
# Act
|
|
303
|
-
# This will now trigger the on-demand computation
|
|
304
|
-
_ = lattice.get_neighbors(0, k=k_to_test)
|
|
305
|
-
|
|
306
|
-
# Assert
|
|
307
|
-
# Check that the correct INFO message about on-demand building was logged.
|
|
308
|
-
expected_log = (
|
|
309
|
-
f"Neighbors for k={k_to_test} not pre-computed. "
|
|
310
|
-
f"Building now up to max_k={k_to_test}."
|
|
311
|
-
)
|
|
312
|
-
assert expected_log in caplog.text
|
|
313
|
-
|
|
314
|
-
@patch("matplotlib.pyplot.show")
|
|
315
|
-
def test_show_prints_warning_for_uncached_bonds(
|
|
316
|
-
self, mock_show, simple_square_lattice, caplog
|
|
317
|
-
):
|
|
318
|
-
"""
|
|
319
|
-
Tests that a warning is printed when .show() is asked to draw a bond layer 'k'
|
|
320
|
-
that has not been pre-calculated.
|
|
321
|
-
"""
|
|
322
|
-
# Arrange
|
|
323
|
-
lattice = simple_square_lattice # This fixture builds neighbors up to k=2
|
|
324
|
-
k_to_test = 99 # A value that is clearly not cached
|
|
325
|
-
|
|
326
|
-
# Act
|
|
327
|
-
lattice.show(show_bonds_k=k_to_test)
|
|
328
|
-
|
|
329
|
-
# Assert
|
|
330
|
-
assert (
|
|
331
|
-
f"Cannot draw bonds. k={k_to_test} neighbors have not been calculated"
|
|
332
|
-
in caplog.text
|
|
333
|
-
)
|
|
334
|
-
|
|
335
|
-
@patch("matplotlib.pyplot.show")
|
|
336
|
-
def test_show_method_for_3d_lattice(self, mock_show):
|
|
337
|
-
"""
|
|
338
|
-
Tests that the .show() method can handle a 3D lattice without
|
|
339
|
-
crashing. This covers the 'if self.dimensionality == 3:' branches.
|
|
340
|
-
"""
|
|
341
|
-
# Arrange: Create a simple 2-site lattice in 3D space.
|
|
342
|
-
coords_3d = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
|
|
343
|
-
ids_3d = [0, 1]
|
|
344
|
-
lattice_3d = CustomizeLattice(
|
|
345
|
-
dimensionality=3, identifiers=ids_3d, coordinates=coords_3d
|
|
346
|
-
)
|
|
347
|
-
|
|
348
|
-
# Assert basic property
|
|
349
|
-
assert lattice_3d.dimensionality == 3
|
|
350
|
-
|
|
351
|
-
# Act & Assert
|
|
352
|
-
# We just need to ensure that calling .show() on a 3D object
|
|
353
|
-
# executes the 3D plotting logic without raising an exception.
|
|
354
|
-
try:
|
|
355
|
-
lattice_3d.show(show_indices=True, show_bonds_k=None)
|
|
356
|
-
except Exception as e:
|
|
357
|
-
pytest.fail(f".show() for 3D lattice raised an unexpected exception: {e}")
|
|
358
|
-
|
|
359
|
-
# Verify that the plotting pipeline was completed.
|
|
360
|
-
mock_show.assert_called_once()
|
|
361
|
-
|
|
362
|
-
@patch("matplotlib.pyplot.subplots")
|
|
363
|
-
def test_show_method_actually_draws_2d_labels(
|
|
364
|
-
self, mock_subplots, simple_square_lattice
|
|
365
|
-
):
|
|
366
|
-
"""
|
|
367
|
-
Tests if ax.text is actually called for a 2D lattice when labels are enabled.
|
|
368
|
-
"""
|
|
369
|
-
# Arrange:
|
|
370
|
-
# 1. Prepare mock Figure and Axes objects that `matplotlib.pyplot.subplots` will return.
|
|
371
|
-
# This allows us to inspect calls to the `ax` object.
|
|
372
|
-
mock_fig = matplotlib.figure.Figure()
|
|
373
|
-
mock_ax = matplotlib.axes.Axes(mock_fig, [0.0, 0.0, 1.0, 1.0])
|
|
374
|
-
mock_subplots.return_value = (mock_fig, mock_ax)
|
|
375
|
-
|
|
376
|
-
# 2. Mock the text method on our mock Axes object to monitor its calls.
|
|
377
|
-
with patch.object(mock_ax, "text") as mock_text_method:
|
|
378
|
-
lattice = simple_square_lattice
|
|
379
|
-
|
|
380
|
-
# Act:
|
|
381
|
-
# Call the show method. It will now operate on our mock_ax object.
|
|
382
|
-
lattice.show(show_indices=True)
|
|
383
|
-
|
|
384
|
-
# Assert:
|
|
385
|
-
# Check if the ax.text method was called. For a 4-site lattice, it should be called 4 times.
|
|
386
|
-
assert mock_text_method.call_count == lattice.num_sites
|
|
387
|
-
|
|
388
|
-
def test_custom_irregular_geometry_neighbors(self):
|
|
389
|
-
"""
|
|
390
|
-
Tests neighbor finding on a more complex, non-grid-like custom geometry
|
|
391
|
-
to stress-test the distance shell and KDTree logic.
|
|
392
|
-
"""
|
|
393
|
-
# Arrange: A "star-shaped" lattice with a central point,
|
|
394
|
-
# an inner shell, and an outer shell.
|
|
395
|
-
coords = [
|
|
396
|
-
[0.0, 0.0], # Site 0: Center
|
|
397
|
-
[1.0, 0.0],
|
|
398
|
-
[0.0, 1.0],
|
|
399
|
-
[-1.0, 0.0],
|
|
400
|
-
[0.0, -1.0], # Sites 1-4: Inner shell (dist=1)
|
|
401
|
-
[2.0, 0.0],
|
|
402
|
-
[0.0, 2.0],
|
|
403
|
-
[-2.0, 0.0],
|
|
404
|
-
[0.0, -2.0], # Sites 5-8: Outer shell (dist=2)
|
|
405
|
-
]
|
|
406
|
-
ids = list(range(len(coords)))
|
|
407
|
-
lattice = CustomizeLattice(
|
|
408
|
-
dimensionality=2, identifiers=ids, coordinates=coords
|
|
409
|
-
)
|
|
410
|
-
lattice._build_neighbors(max_k=3)
|
|
411
|
-
|
|
412
|
-
# Assert 1: Neighbors of the central point (0) should be the distinct shells.
|
|
413
|
-
assert set(lattice.get_neighbors(0, k=1)) == {1, 2, 3, 4}
|
|
414
|
-
# The shell at dist=2.0 (d_sq=4.0) is the 3rd global shell, so we check k=3.
|
|
415
|
-
assert set(lattice.get_neighbors(0, k=3)) == {5, 6, 7, 8}
|
|
416
|
-
|
|
417
|
-
assert lattice.get_neighbors(0, k=2) == []
|
|
418
|
-
|
|
419
|
-
# Assert 2: Neighbors of a point on the inner shell, e.g., site 1 ([1.0, 0.0]).
|
|
420
|
-
# Its nearest neighbors (k=1) are the center (0) and the closest point on the outer shell (5).
|
|
421
|
-
# Both are at distance 1.0.
|
|
422
|
-
assert set(lattice.get_neighbors(1, k=1)) == {0, 5}
|
|
423
|
-
|
|
424
|
-
# Its next-nearest neighbors (k=2) are the other two points on the inner shell (2 and 4),
|
|
425
|
-
# both at distance sqrt(2).
|
|
426
|
-
assert set(lattice.get_neighbors(1, k=2)) == {2, 4}
|
|
427
|
-
|
|
428
|
-
def test_customizelattice_max_k_precomputation_and_ondemand(self):
|
|
429
|
-
"""
|
|
430
|
-
A robust test to verify `precompute_neighbors` (max_k) for CustomizeLattice.
|
|
431
|
-
This test is designed to FAIL on the buggy code.
|
|
432
|
-
"""
|
|
433
|
-
coords = [
|
|
434
|
-
[0.0, 0.0],
|
|
435
|
-
[1.0, 0.0],
|
|
436
|
-
[0.0, 1.0],
|
|
437
|
-
[-1.0, 0.0],
|
|
438
|
-
[0.0, -1.0],
|
|
439
|
-
[1.0, 1.0],
|
|
440
|
-
[-1.0, 1.0],
|
|
441
|
-
[-1.0, -1.0],
|
|
442
|
-
[1.0, -1.0],
|
|
443
|
-
[2.0, 0.0],
|
|
444
|
-
[0.0, 2.0],
|
|
445
|
-
[-2.0, 0.0],
|
|
446
|
-
[0.0, -2.0],
|
|
447
|
-
]
|
|
448
|
-
ids = list(range(len(coords)))
|
|
449
|
-
k_precompute = 2
|
|
450
|
-
|
|
451
|
-
lattice = CustomizeLattice(
|
|
452
|
-
dimensionality=2,
|
|
453
|
-
identifiers=ids,
|
|
454
|
-
coordinates=coords,
|
|
455
|
-
precompute_neighbors=k_precompute,
|
|
456
|
-
)
|
|
457
|
-
|
|
458
|
-
computed_shells = sorted(list(lattice._neighbor_maps.keys()))
|
|
459
|
-
expected_shells = list(range(1, k_precompute + 1))
|
|
460
|
-
|
|
461
|
-
assert computed_shells == expected_shells, (
|
|
462
|
-
f"TEST FAILED for CustomizeLattice with k={k_precompute}. "
|
|
463
|
-
f"Expected shells {expected_shells}, but found {computed_shells}."
|
|
464
|
-
)
|
|
465
|
-
|
|
466
|
-
k_ondemand = 3
|
|
467
|
-
_ = lattice.get_neighbors(0, k=k_ondemand)
|
|
468
|
-
|
|
469
|
-
computed_shells_after = sorted(list(lattice._neighbor_maps.keys()))
|
|
470
|
-
expected_shells_after = list(range(1, k_ondemand + 1))
|
|
471
|
-
|
|
472
|
-
assert computed_shells_after == expected_shells_after, (
|
|
473
|
-
f"ON-DEMAND TEST FAILED for CustomizeLattice. "
|
|
474
|
-
f"Expected shells {expected_shells_after} after demanding k={k_ondemand}, "
|
|
475
|
-
f"but found {computed_shells_after}."
|
|
476
|
-
)
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
@pytest.fixture
|
|
480
|
-
def obc_square_lattice() -> SquareLattice:
|
|
481
|
-
"""Provides a 3x3 SquareLattice with Open Boundary Conditions."""
|
|
482
|
-
return SquareLattice(size=(3, 3), pbc=False)
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
@pytest.fixture
|
|
486
|
-
def pbc_square_lattice() -> SquareLattice:
|
|
487
|
-
"""Provides a 3x3 SquareLattice with Periodic Boundary Conditions."""
|
|
488
|
-
return SquareLattice(size=(3, 3), pbc=True)
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
class TestSquareLattice:
|
|
492
|
-
"""
|
|
493
|
-
Groups all tests for the SquareLattice class, which implicitly tests
|
|
494
|
-
the core functionality of its parent, TILattice.
|
|
495
|
-
"""
|
|
496
|
-
|
|
497
|
-
def test_initialization_and_properties(self, obc_square_lattice):
|
|
498
|
-
"""
|
|
499
|
-
Tests the basic properties of a SquareLattice instance.
|
|
500
|
-
"""
|
|
501
|
-
lattice = obc_square_lattice
|
|
502
|
-
assert lattice.dimensionality == 2
|
|
503
|
-
assert lattice.num_sites == 9 # A 3x3 lattice should have 9 sites.
|
|
504
|
-
assert len(lattice) == 9
|
|
505
|
-
|
|
506
|
-
def test_site_info_and_identifiers(self, obc_square_lattice):
|
|
507
|
-
"""
|
|
508
|
-
Tests that site information (coordinates, identifiers) is correct.
|
|
509
|
-
"""
|
|
510
|
-
lattice = obc_square_lattice
|
|
511
|
-
center_idx = lattice.get_index((1, 1, 0))
|
|
512
|
-
assert center_idx == 4
|
|
513
|
-
|
|
514
|
-
_, ident, coords = lattice.get_site_info(center_idx)
|
|
515
|
-
assert ident == (1, 1, 0)
|
|
516
|
-
np.testing.assert_array_equal(coords, np.array([1.0, 1.0]))
|
|
517
|
-
|
|
518
|
-
corner_idx = 0
|
|
519
|
-
_, ident, coords = lattice.get_site_info(corner_idx)
|
|
520
|
-
assert ident == (0, 0, 0)
|
|
521
|
-
np.testing.assert_array_equal(coords, np.array([0.0, 0.0]))
|
|
522
|
-
|
|
523
|
-
def test_neighbors_with_open_boundaries(self, obc_square_lattice):
|
|
524
|
-
"""
|
|
525
|
-
Tests neighbor finding with Open Boundary Conditions (OBC) using specific
|
|
526
|
-
neighbor identities.
|
|
527
|
-
"""
|
|
528
|
-
lattice = obc_square_lattice
|
|
529
|
-
# Site indices for a 3x3 grid (row-major order):
|
|
530
|
-
# 0 1 2
|
|
531
|
-
# 3 4 5
|
|
532
|
-
# 6 7 8
|
|
533
|
-
center_idx = 4 # (1, 1, 0)
|
|
534
|
-
corner_idx = 0 # (0, 0, 0)
|
|
535
|
-
edge_idx = 3 # (1, 0, 0)
|
|
536
|
-
|
|
537
|
-
# Assert center site (4) has neighbors 1, 3, 5, 7
|
|
538
|
-
assert set(lattice.get_neighbors(center_idx, k=1)) == {1, 3, 5, 7}
|
|
539
|
-
# Assert corner site (0) has neighbors 1, 3
|
|
540
|
-
assert set(lattice.get_neighbors(corner_idx, k=1)) == {1, 3}
|
|
541
|
-
# Assert edge site (3) has neighbors 0, 4, 6
|
|
542
|
-
assert set(lattice.get_neighbors(edge_idx, k=1)) == {0, 4, 6}
|
|
543
|
-
|
|
544
|
-
def test_neighbors_with_periodic_boundaries(self, pbc_square_lattice):
|
|
545
|
-
"""
|
|
546
|
-
Tests neighbor finding with Periodic Boundary Conditions (PBC).
|
|
547
|
-
"""
|
|
548
|
-
lattice = pbc_square_lattice
|
|
549
|
-
corner_idx = lattice.get_index((0, 0, 0))
|
|
550
|
-
|
|
551
|
-
neighbors = lattice.get_neighbors(corner_idx, k=1)
|
|
552
|
-
neighbor_idents = {lattice.get_identifier(i) for i in neighbors}
|
|
553
|
-
expected_neighbor_idents = {(1, 0, 0), (0, 1, 0), (2, 0, 0), (0, 2, 0)}
|
|
554
|
-
assert neighbor_idents == expected_neighbor_idents
|
|
555
|
-
|
|
556
|
-
nnn_neighbors = lattice.get_neighbors(corner_idx, k=2)
|
|
557
|
-
nnn_neighbor_idents = {lattice.get_identifier(i) for i in nnn_neighbors}
|
|
558
|
-
expected_nnn_idents = {(1, 1, 0), (2, 1, 0), (1, 2, 0), (2, 2, 0)}
|
|
559
|
-
assert nnn_neighbor_idents == expected_nnn_idents
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
# --- Tests for HoneycombLattice ---
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
@pytest.fixture
|
|
566
|
-
def pbc_honeycomb_lattice() -> HoneycombLattice:
|
|
567
|
-
"""Provides a 2x2 HoneycombLattice with Periodic Boundary Conditions."""
|
|
568
|
-
return HoneycombLattice(size=(2, 2), pbc=True)
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
class TestHoneycombLattice:
|
|
572
|
-
"""
|
|
573
|
-
Tests the HoneycombLattice class, focusing on its two-site basis.
|
|
574
|
-
"""
|
|
575
|
-
|
|
576
|
-
def test_initialization_and_properties(self, pbc_honeycomb_lattice):
|
|
577
|
-
"""
|
|
578
|
-
Tests that the total number of sites is correct for a composite lattice.
|
|
579
|
-
"""
|
|
580
|
-
lattice = pbc_honeycomb_lattice
|
|
581
|
-
assert lattice.num_sites == 8
|
|
582
|
-
assert lattice.num_basis == 2
|
|
583
|
-
|
|
584
|
-
def test_honeycomb_neighbors(self, pbc_honeycomb_lattice):
|
|
585
|
-
"""
|
|
586
|
-
Tests that every site in a honeycomb lattice has 3 nearest neighbors.
|
|
587
|
-
"""
|
|
588
|
-
lattice = pbc_honeycomb_lattice
|
|
589
|
-
site_a_idx = lattice.get_index((0, 0, 0))
|
|
590
|
-
assert len(lattice.get_neighbors(site_a_idx, k=1)) == 3
|
|
591
|
-
|
|
592
|
-
site_b_idx = lattice.get_index((0, 0, 1))
|
|
593
|
-
assert len(lattice.get_neighbors(site_b_idx, k=1)) == 3
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
# --- Tests for TriangularLattice ---
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
@pytest.fixture
|
|
600
|
-
def pbc_triangular_lattice() -> TriangularLattice:
|
|
601
|
-
"""
|
|
602
|
-
Provides a 3x3 TriangularLattice with Periodic Boundary Conditions.
|
|
603
|
-
A 3x3 size is used to ensure all 6 nearest neighbors are unique sites.
|
|
604
|
-
"""
|
|
605
|
-
return TriangularLattice(size=(3, 3), pbc=True)
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
class TestTriangularLattice:
|
|
609
|
-
"""
|
|
610
|
-
Tests the TriangularLattice class, focusing on its coordination number.
|
|
611
|
-
"""
|
|
612
|
-
|
|
613
|
-
def test_initialization_and_properties(self, pbc_triangular_lattice):
|
|
614
|
-
"""
|
|
615
|
-
Tests the basic properties of the triangular lattice.
|
|
616
|
-
"""
|
|
617
|
-
lattice = pbc_triangular_lattice
|
|
618
|
-
assert lattice.num_sites == 9 # 3 * 3 = 9 sites for a 3x3 grid
|
|
619
|
-
|
|
620
|
-
def test_triangular_neighbors(self, pbc_triangular_lattice):
|
|
621
|
-
"""
|
|
622
|
-
Tests that every site in a triangular lattice has 6 nearest neighbors.
|
|
623
|
-
"""
|
|
624
|
-
lattice = pbc_triangular_lattice
|
|
625
|
-
site_idx = 0
|
|
626
|
-
assert len(lattice.get_neighbors(site_idx, k=1)) == 6
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
# --- Tests for New TILattice Implementations ---
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
class TestRectangularLattice:
|
|
633
|
-
"""Tests for the 2D RectangularLattice."""
|
|
634
|
-
|
|
635
|
-
def test_rectangular_properties_and_neighbors(self):
|
|
636
|
-
"""Tests neighbor counts for an OBC rectangular lattice."""
|
|
637
|
-
lattice = RectangularLattice(size=(3, 4), pbc=False)
|
|
638
|
-
assert lattice.num_sites == 12
|
|
639
|
-
assert lattice.dimensionality == 2
|
|
640
|
-
|
|
641
|
-
# Test neighbor counts for different site types
|
|
642
|
-
center_idx = lattice.get_index((1, 1, 0))
|
|
643
|
-
corner_idx = lattice.get_index((0, 0, 0))
|
|
644
|
-
edge_idx = lattice.get_index((0, 1, 0))
|
|
645
|
-
|
|
646
|
-
assert len(lattice.get_neighbors(center_idx, k=1)) == 4
|
|
647
|
-
assert len(lattice.get_neighbors(corner_idx, k=1)) == 2
|
|
648
|
-
assert len(lattice.get_neighbors(edge_idx, k=1)) == 3
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
class TestTILatticeEdgeCases:
    """
    Exercises TILattice subclasses in less common configurations:
    a short open chain, a non-square 2D lattice, and plotting in 1D.
    """

    @pytest.fixture
    def obc_1d_chain(self) -> ChainLattice:
        """Build a 5-site chain with open boundaries: 0--1--2--3--4."""
        return ChainLattice(size=(5,), pbc=False)

    def test_1d_chain_properties_and_neighbors(self, obc_1d_chain):
        """Checks site count, dimensionality and k=1 neighbors of the open chain."""
        chain = obc_1d_chain

        assert chain.num_sites == 5
        assert chain.dimensionality == 1

        # The endpoint (site 0) is expected to touch site 1 only.
        first_site = chain.get_index((0, 0))
        assert chain.get_neighbors(first_site, k=1) == [1]

        # The interior site 2 is expected to touch sites 1 and 3.
        inner_site = chain.get_index((2, 0))
        inner_neighbors = chain.get_neighbors(inner_site, k=1)
        assert len(inner_neighbors) == 2
        assert set(inner_neighbors) == {1, 3}

    @pytest.fixture
    def nonsquare_lattice(self) -> SquareLattice:
        """Build an open 2x3 lattice, used to test indexing with unequal sides."""
        return SquareLattice(size=(2, 3), pbc=False)

    def test_nonsquare_lattice_indexing(self, nonsquare_lattice):
        """
        Checks identifier -> index -> coordinates on the 2x3 lattice.

        The test expects row-by-row indexing:
            (0,0) (0,1) (0,2) -> indices 0, 1, 2
            (1,0) (1,1) (1,2) -> indices 3, 4, 5
        """
        grid = nonsquare_lattice

        assert grid.num_sites == 6  # 2 * 3 = 6

        # Probe the site in the last row and last column.
        last_ident = (1, 2, 0)

        # Identifier -> index.
        site = grid.get_index(last_ident)
        assert site == 5

        # Index -> coordinates (third element of the site-info tuple).
        site_info = grid.get_site_info(site)
        np.testing.assert_array_equal(site_info[2], np.array([1.0, 2.0]))

    @patch("matplotlib.pyplot.show")
    def test_show_method_for_1d_lattice(self, mock_show, obc_1d_chain):
        """
        Calls .show() on the 1D chain with pyplot.show patched out; the call
        must not raise and pyplot.show must be reached exactly once.
        """
        chain = obc_1d_chain
        assert chain.num_sites == 5

        # Any exception from the 1D plotting path is turned into a test failure.
        try:
            chain.show(show_indices=True)
        except Exception as e:
            pytest.fail(f".show() for 1D lattice raised an unexpected exception: {e}")

        # Reaching pyplot.show means the plotting code ran to the end.
        mock_show.assert_called_once()
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
# --- Tests for API Robustness / Negative Cases ---
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
class TestApiRobustness:
    """
    Checks how the lattice classes react to invalid or degenerate input,
    so that failures are explicit and predictable.
    """

    def test_access_with_out_of_bounds_index(self, simple_square_lattice):
        """
        An index outside the valid range (0 to num_sites-1) must raise
        IndexError from every index-based accessor.
        """
        # Fixture is defined elsewhere; it is described as having 4 sites.
        lattice = simple_square_lattice
        invalid_index = 999

        with pytest.raises(IndexError):
            lattice.get_coordinates(invalid_index)

        with pytest.raises(IndexError):
            lattice.get_identifier(invalid_index)

        with pytest.raises(IndexError):
            lattice.get_site_info(invalid_index)

    def test_empty_lattice_handles_gracefully(self, caplog):
        """
        A zero-site CustomizeLattice must initialize, and .show() and
        ._build_neighbors() must run on it without raising.
        """
        empty_lattice = CustomizeLattice(
            dimensionality=2, identifiers=[], coordinates=[]
        )

        assert empty_lattice.num_sites == 0
        assert len(empty_lattice) == 0

        # .show() is expected to log (not raise) for an empty lattice.
        caplog.set_level(logging.INFO)
        empty_lattice.show()
        assert "Lattice is empty, nothing to show." in caplog.text

        # Neighbor building on zero sites must succeed and yield no pairs.
        empty_lattice._build_neighbors()
        assert empty_lattice.get_neighbor_pairs(k=1) == []

    def test_single_site_lattice_handles_gracefully(self):
        """A one-site lattice must report no neighbors for its only site."""
        single_site_lattice = CustomizeLattice(
            dimensionality=2, identifiers=[0], coordinates=[[0.0, 0.0]]
        )

        assert single_site_lattice.num_sites == 1

        single_site_lattice._build_neighbors(max_k=1)

        assert single_site_lattice.get_neighbors(0, k=1) == []

    def test_access_with_non_existent_identifier(self, simple_square_lattice):
        """
        Looking up an identifier that is not in the lattice must raise
        ValueError with a "not found in the lattice" message.
        """
        lattice = simple_square_lattice
        invalid_identifier = "non_existent_site"

        # `match` additionally pins the wording of the error message.
        with pytest.raises(ValueError, match="not found in the lattice"):
            lattice.get_index(invalid_identifier)

        with pytest.raises(ValueError, match="not found in the lattice"):
            lattice.get_site_info(invalid_identifier)

    def test_show_warning_for_unsupported_dimension(self, caplog):
        """
        .show() on a lattice whose dimensionality has no plotting support
        (here 4D) must emit a log message saying so.
        """
        lattice_4d = CustomizeLattice(
            dimensionality=4, identifiers=[0], coordinates=[[0, 0, 0, 0]]
        )

        lattice_4d.show()

        # The message is read from captured log records (caplog), not stdout.
        assert "show() is not implemented for 4D lattices." in caplog.text

    def test_disconnected_lattice_neighbor_finding(self):
        """
        Neighbor finding on a lattice made of two far-apart 2x2 squares must
        never connect a site of one square to a site of the other.
        """
        # Component 1: indices 0-3, Component 2: indices 4-7 (shifted by 100 in x).
        coords = [
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],  # Square 1
            [100.0, 0.0],
            [101.0, 0.0],
            [100.0, 1.0],
            [101.0, 1.0],  # Square 2
        ]
        ids = list(range(len(coords)))
        lattice = CustomizeLattice(
            dimensionality=2, identifiers=ids, coordinates=coords
        )
        lattice._build_neighbors(max_k=1)  # Explicitly build neighbors

        # --- Test 1: get_neighbors() ---
        neighbors_of_site_0 = lattice.get_neighbors(0, k=1)
        assert set(neighbors_of_site_0) == {1, 2}

        # --- Test 2: get_neighbor_pairs() ---
        all_bonds = lattice.get_neighbor_pairs(k=1, unique=True)

        # Guard against a vacuous pass: with no bonds the loop below would
        # check nothing at all.
        assert all_bonds, "Expected at least one nearest-neighbor bond."

        for i, j in all_bonds:
            # A bond is valid only if both endpoints lie in the same index range.
            is_in_first_component = i < 4 and j < 4
            is_in_second_component = i >= 4 and j >= 4

            assert is_in_first_component or is_in_second_component, (
                f"Found an invalid bond { (i,j) } that incorrectly connects "
                "two separate components of the lattice."
            )

    def test_lattice_with_duplicate_coordinates(self):
        """
        Pathological case: sites 'A' and 'B' share the coordinates (0, 0) and
        site 'C' sits at (1, 0). 'A' and 'B' must stay distinct sites, must not
        be neighbors of each other, and must both be neighbors of 'C'.
        """
        ids = ["A", "B", "C"]
        coords = [[0.0, 0.0], [0.0, 0.0], [1.0, 0.0]]

        lattice = CustomizeLattice(
            dimensionality=2, identifiers=ids, coordinates=coords
        )
        lattice._build_neighbors(max_k=1)  # Build nearest neighbors

        idx_A = lattice.get_index("A")
        idx_B = lattice.get_index("B")
        idx_C = lattice.get_index("C")

        neighbors_A = lattice.get_neighbors(idx_A, k=1)
        neighbors_B = lattice.get_neighbors(idx_B, k=1)

        # 1. 'A' and 'B' are at distance 0, so neither may list the other.
        assert (
            idx_B not in neighbors_A
        ), "Overlapping sites should not be their own neighbors."
        assert (
            idx_A not in neighbors_B
        ), "Overlapping sites should not be their own neighbors."

        # 2. Each overlapping site must still find 'C'.
        assert neighbors_A == [
            idx_C
        ], "Site 'A' failed to find its correct neighbor 'C'."
        assert neighbors_B == [
            idx_C
        ], "Site 'B' failed to find its correct neighbor 'C'."

        # 3. 'C' must list both overlapping sites.
        neighbors_C = lattice.get_neighbors(idx_C, k=1)
        assert set(neighbors_C) == {
            idx_A,
            idx_B,
        }, "Site 'C' failed to find both overlapping neighbors."

    def test_neighbor_shells_with_tiny_separation(self):
        """
        Checks how the `tol` argument of _build_neighbors decides whether two
        almost equal distances form one shell or two.

        A 1D lattice has sites at 0, d1 and d1 + 1e-8. Seen from site 0:
        with tol=1e-7 (larger than the separation) both other sites must land
        in the k=1 shell; with tol=1e-9 (smaller than the separation) they must
        land in separate k=1 and k=2 shells.
        """
        d1 = 1.0
        epsilon = 1e-8  # A tiny separation
        d2 = d1 + epsilon

        # Site 0 is the origin, site 1 is at d1, site 2 is at d2.
        ids = [0, 1, 2]
        coords = [[0.0], [d1], [d2]]

        # Tolerance LARGER than the separation: shells are expected to merge.
        lattice_merged = CustomizeLattice(
            dimensionality=1, identifiers=ids, coordinates=coords
        )
        lattice_merged._build_neighbors(max_k=2, tol=1e-7)

        # Tolerance SMALLER than the separation: shells are expected to split.
        lattice_distinct = CustomizeLattice(
            dimensionality=1, identifiers=ids, coordinates=coords
        )
        lattice_distinct._build_neighbors(max_k=2, tol=1e-9)

        # Merged case: sites 1 and 2 share the k=1 shell, and k=2 is empty.
        merged_neighbors_k1 = lattice_merged.get_neighbors(0, k=1)
        assert set(merged_neighbors_k1) == {
            1,
            2,
        }, "Shells were not merged with a large tolerance."
        merged_neighbors_k2 = lattice_merged.get_neighbors(0, k=2)
        assert (
            merged_neighbors_k2 == []
        ), "A k=2 shell should not exist when shells are merged."

        # Distinct case: site 1 alone in k=1, site 2 alone in k=2.
        distinct_neighbors_k1 = lattice_distinct.get_neighbors(0, k=1)
        assert distinct_neighbors_k1 == [
            1
        ], "k=1 shell is incorrect with a small tolerance."
        distinct_neighbors_k2 = lattice_distinct.get_neighbors(0, k=2)
        assert distinct_neighbors_k2 == [
            2
        ], "k=2 shell is incorrect with a small tolerance."
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
class TestTILattice:
    """
    Tests for the Translationally Invariant Lattice (TILattice) constructor
    arguments and neighbor pre-computation, via subclasses such as SquareLattice.
    """

    def test_init_with_mismatched_shapes_raises_error(self):
        """
        A `size` tuple whose length differs from the lattice dimensionality
        must be rejected with an AssertionError.
        """
        # A 3-element size is passed to the 2D SquareLattice.
        # NOTE(review): the check presumably lives in the TILattice base
        # class - confirm against the implementation.
        with pytest.raises(AssertionError, match="Size tuple length mismatch"):
            SquareLattice(size=(2, 2, 2))

    def test_init_with_tuple_pbc(self):
        """
        A per-axis tuple passed as `pbc` must be accepted and exposed
        unchanged through the public `pbc` attribute.
        """
        pbc_tuple = (True, False)

        lattice = SquareLattice(size=(3, 3), pbc=pbc_tuple)

        assert lattice.pbc == pbc_tuple

    @pytest.mark.parametrize(
        "LatticeClass, init_args, k_precompute",
        [
            (HoneycombLattice, {"size": (4, 5), "pbc": True}, 1),
            (SquareLattice, {"size": (5, 5), "pbc": True}, 2),
            (SquareLattice, {"size": (5, 5), "pbc": False}, 1),
            (KagomeLattice, {"size": (3, 3), "pbc": True}, 1),
        ],
    )
    def test_tilattice_max_k_precomputation_and_ondemand(
        self, LatticeClass, init_args, k_precompute
    ):
        """
        Verifies `precompute_neighbors` across several TILattice types.

        After construction, `_neighbor_maps` must hold exactly the shells
        1..k_precompute. After asking for shell k_precompute + 1 via
        get_neighbors, it must hold exactly 1..k_precompute + 1.
        """
        lattice = LatticeClass(**init_args, precompute_neighbors=k_precompute)

        # Iterating the mapping yields its keys, i.e. the computed shell numbers.
        computed_shells = sorted(lattice._neighbor_maps)
        expected_shells = list(range(1, k_precompute + 1))

        assert computed_shells == expected_shells, (
            f"TEST FAILED for {LatticeClass.__name__} with k={k_precompute}. "
            f"Expected shells {expected_shells}, but found {computed_shells}."
        )

        k_ondemand = k_precompute + 1

        # The return value is irrelevant; the call is made for its effect
        # on `_neighbor_maps`.
        _ = lattice.get_neighbors(0, k=k_ondemand)

        computed_shells_after = sorted(lattice._neighbor_maps)
        expected_shells_after = list(range(1, k_ondemand + 1))

        assert computed_shells_after == expected_shells_after, (
            f"ON-DEMAND TEST FAILED for {LatticeClass.__name__}. "
            f"Expected shells {expected_shells_after} after demanding k={k_ondemand}, "
            f"but found {computed_shells_after}."
        )
|
|
1078
|
-
|
|
1079
|
-
|
|
1080
|
-
class TestLongRangeNeighborFinding:
    """
    Tests neighbor finding on larger lattices and for longer-range
    interactions (large k), plus bond styling and mixed boundary conditions.
    """

    @pytest.fixture(scope="class")
    def large_pbc_square_lattice(self) -> SquareLattice:
        """
        Provides a single 7x9 SquareLattice with PBC for all tests in this class.
        scope="class" means it is created once and shared, so state added by
        one test (e.g. built neighbor maps) is visible to the others.
        """
        # Unequal side lengths are chosen to catch bugs tied to uniform dimensions.
        return SquareLattice(size=(7, 9), pbc=True)

    def test_neighbor_shell_structure_on_large_lattice(self, large_pbc_square_lattice):
        """
        Checks the coordination number of the first seven neighbor shells of
        site 0 on the periodic square lattice.

        Shell distances squared and coordination numbers for a 2D square lattice:
        - k=1: dist_sq=1  (e.g., (1,0)) -> 4 neighbors
        - k=2: dist_sq=2  (e.g., (1,1)) -> 4 neighbors
        - k=3: dist_sq=4  (e.g., (2,0)) -> 4 neighbors
        - k=4: dist_sq=5  (e.g., (2,1)) -> 8 neighbors
        - k=5: dist_sq=8  (e.g., (2,2)) -> 4 neighbors
        - k=6: dist_sq=9  (e.g., (3,0)) -> 4 neighbors
        - k=7: dist_sq=10 (e.g., (3,1)) -> 8 neighbors
        """
        lattice = large_pbc_square_lattice
        # With PBC every site is equivalent, so index 0 is as good as any.
        site_idx = 0

        expected_coordinations = {1: 4, 2: 4, 3: 4, 4: 8, 5: 4, 6: 4, 7: 8}

        for k, expected_count in expected_coordinations.items():
            neighbors = lattice.get_neighbors(site_idx, k=k)
            assert (
                len(neighbors) == expected_count
            ), f"Failed for k={k}. Expected {expected_count}, got {len(neighbors)}"

    def test_requesting_k_beyond_max_possible_shell(self, large_pbc_square_lattice):
        """
        Requesting a shell `k` beyond the last shell that exists in the finite
        lattice must return an empty list instead of raising.
        """
        lattice = large_pbc_square_lattice
        site_idx = 0

        # 1. Discover how many shells exist by building with a very large
        #    max_k (white-box: reads the private `_neighbor_maps`).
        lattice._build_neighbors(max_k=100)
        max_k_found = len(lattice._neighbor_maps)

        # 2. The last existing shell must not be empty.
        last_shell_neighbors = lattice.get_neighbors(site_idx, k=max_k_found)
        assert len(last_shell_neighbors) > 0

        # 3. One shell further must come back empty - the core of the test.
        non_existent_shell_neighbors = lattice.get_neighbors(
            site_idx, k=max_k_found + 1
        )
        assert non_existent_shell_neighbors == []

    @patch("matplotlib.pyplot.subplots")
    def test_show_method_with_custom_bond_kwargs(
        self, mock_subplots, simple_square_lattice
    ):
        """
        .show() must merge `bond_kwargs` over its default bond style and pass
        the merged keywords to every Axes.plot call that draws a bond.
        """
        # 1. Hand .show() a real Figure/Axes pair through the patched subplots().
        mock_fig = matplotlib.figure.Figure()
        mock_ax = matplotlib.axes.Axes(mock_fig, [0.0, 0.0, 1.0, 1.0])
        mock_subplots.return_value = (mock_fig, mock_ax)

        # 2. Custom styles and the merged result the test expects.
        lattice = simple_square_lattice
        custom_bond_kwargs = {"color": "red", "linestyle": ":", "linewidth": 2}

        expected_plot_kwargs = {
            "color": "red",  # Overridden
            "linestyle": ":",  # Overridden
            "linewidth": 2,  # A new key
            "alpha": 0.6,  # From default
            "zorder": 1,  # From default
        }

        # 3. Record calls to `plot` on the Axes object.
        with patch.object(mock_ax, "plot") as mock_plot_method:
            lattice.show(show_bonds_k=1, bond_kwargs=custom_bond_kwargs)

            # For a 2x2 square, there are 4 NN bonds, hence 4 plot calls.
            assert mock_plot_method.call_count == 4

            # Each recorded call is (positional_args, keyword_args). Checking
            # every call avoids depending on which one `call_args` (the most
            # recent call) happens to be.
            for plot_call in mock_plot_method.call_args_list:
                actual_kwargs = plot_call[1]
                assert actual_kwargs == expected_plot_kwargs

    def test_mixed_boundary_conditions(self):
        """
        Neighbor finding with mixed boundaries (periodic in x, open in y) must
        wrap only along x, and get_neighbors must return sorted indices.
        """
        lattice = SquareLattice(size=(3, 3), pbc=(True, False))

        # --- Corner site of the open boundary: (0, 0, 0) ---
        corner_site_idx = lattice.get_index((0, 0, 0))

        corner_neighbors = lattice.get_neighbors(corner_site_idx, k=1)

        expected_indices = sorted(
            [
                lattice.get_index((1, 0, 0)),  # Right neighbor
                lattice.get_index((2, 0, 0)),  # "Left" neighbor (wraps around)
                lattice.get_index((0, 1, 0)),  # "Up" neighbor
            ]
        )

        # List equality (not set equality) also checks the ordering.
        assert (
            corner_neighbors == expected_indices
        ), "Failed for corner site with mixed BC."

        # --- Middle site on the open edge: (1, 0, 0) ---
        edge_site_idx = lattice.get_index((1, 0, 0))

        edge_neighbors = lattice.get_neighbors(edge_site_idx, k=1)

        expected_edge_indices = sorted(
            [
                lattice.get_index((0, 0, 0)),  # Left neighbor
                lattice.get_index((2, 0, 0)),  # Right neighbor
                lattice.get_index((1, 1, 0)),  # "Up" neighbor
            ]
        )
        assert (
            edge_neighbors == expected_edge_indices
        ), "Failed for edge site with mixed BC."
|
|
1241
|
-
|
|
1242
|
-
|
|
1243
|
-
class TestAllTILattices:
    """
    A parameterized test class that verifies basic properties and k=1
    coordination numbers for the implemented TILattice subclasses, so the
    same checks are not duplicated per lattice type.
    """

    # --- Test data ---
    # Format:
    # (
    #     LatticeClass,          # The lattice class to test
    #     {"size": ..., ...},    # Arguments for the constructor
    #     expected_num_sites,    # Expected total number of sites
    #     expected_num_basis,    # Expected number of sites in the basis
    #     {site_repr: count}     # Dict of {representative_site: neighbor_count}
    # )
    # For `site_repr`:
    #   - For simple lattices (basis=1), it's the integer index of the site.
    #   - For composite lattices (basis>1), it's the *basis index* to test.
    lattice_test_cases = [
        # 1D Lattices
        (ChainLattice, {"size": (5,), "pbc": True}, 5, 1, {0: 2, 2: 2}),
        (ChainLattice, {"size": (5,), "pbc": False}, 5, 1, {0: 1, 2: 2}),
        (DimerizedChainLattice, {"size": (3,), "pbc": True}, 6, 2, {0: 2, 1: 2}),
        # 2D Lattices
        (
            RectangularLattice,
            {"size": (3, 4), "pbc": False},
            12,
            1,
            {5: 4, 0: 2, 4: 3},
        ),  # center, corner, edge
        (HoneycombLattice, {"size": (2, 2), "pbc": True}, 8, 2, {0: 3, 1: 3}),
        (TriangularLattice, {"size": (3, 3), "pbc": True}, 9, 1, {0: 6}),
        (CheckerboardLattice, {"size": (2, 2), "pbc": True}, 8, 2, {0: 4, 1: 4}),
        (KagomeLattice, {"size": (2, 2), "pbc": True}, 12, 3, {0: 4, 1: 4, 2: 4}),
        (LiebLattice, {"size": (2, 2), "pbc": True}, 12, 3, {0: 4, 1: 2, 2: 2}),
        # 3D Lattices
        (CubicLattice, {"size": (3, 3, 3), "pbc": True}, 27, 1, {0: 6, 13: 6}),
    ]

    @pytest.mark.parametrize(
        "LatticeClass, init_args, num_sites, num_basis, coordination_numbers",
        lattice_test_cases,
    )
    def test_lattice_properties_and_coordination(
        self,
        LatticeClass,
        init_args,
        num_sites,
        num_basis,
        coordination_numbers,
    ):
        """
        Builds `LatticeClass(**init_args)` and checks its site count, basis
        size, dimensionality and the number of k=1 neighbors of each
        representative site listed in `coordination_numbers`.
        """
        # --- Arrange ---
        lattice = LatticeClass(**init_args)

        # --- Assert: Basic properties ---
        assert lattice.num_sites == num_sites
        assert lattice.num_basis == num_basis
        assert lattice.dimensionality == len(init_args["size"])

        # --- Assert: Coordination numbers (nearest neighbors, k=1) ---
        for site_repr, expected_count in coordination_numbers.items():
            if lattice.num_basis > 1:
                # Composite lattice: site_repr is a basis index; resolve it
                # to the site index inside the first unit cell.
                uc_coord = (0,) * lattice.dimensionality
                test_site_idx = lattice.get_index(uc_coord + (site_repr,))
            else:
                # Simple lattice: site_repr already is the absolute site index.
                test_site_idx = site_repr

            neighbors = lattice.get_neighbors(test_site_idx, k=1)
            assert len(neighbors) == expected_count

            # Check the instance: `isinstance(LatticeClass, ChainLattice)`
            # would test the class object itself and is always False, so
            # this branch would never run.
            if isinstance(lattice, ChainLattice) and not init_args.get("pbc"):
                if test_site_idx == 0:
                    assert 1 in neighbors
|
|
1325
|
-
|
|
1326
|
-
|
|
1327
|
-
class TestCustomizeLatticeDynamic:
|
|
1328
|
-
"""Tests the dynamic modification capabilities of CustomizeLattice."""
|
|
1329
|
-
|
|
1330
|
-
@pytest.fixture
def initial_lattice(self) -> CustomizeLattice:
    """Build the 2D three-site lattice ("A", "B", "C") that the modification tests start from."""
    site_names = ["A", "B", "C"]
    site_positions = [[0, 0], [1, 0], [0, 1]]
    return CustomizeLattice(
        dimensionality=2,
        identifiers=site_names,
        coordinates=site_positions,
    )
|
|
1338
|
-
|
|
1339
|
-
def test_from_lattice_conversion(self):
    """Converts an open 2x2 SquareLattice with CustomizeLattice.from_lattice and compares the two."""
    source = SquareLattice(size=(2, 2), pbc=False)

    converted = CustomizeLattice.from_lattice(source)

    # The result must be a CustomizeLattice with the same overall shape.
    assert isinstance(converted, CustomizeLattice)
    assert converted.num_sites == source.num_sites
    assert converted.dimensionality == source.dimensionality

    # Spot-check site 3: coordinates and identifier must carry over.
    probe = 3
    np.testing.assert_array_equal(
        converted.get_coordinates(probe), source.get_coordinates(probe)
    )
    assert converted.get_identifier(probe) == source.get_identifier(probe)
|
|
1356
|
-
|
|
1357
|
-
def test_add_sites_successfully(self, initial_lattice):
    """Adds two new sites ("D", "E") and checks count, lookup tables and coordinates."""
    lattice = initial_lattice
    assert lattice.num_sites == 3

    new_names = ["D", "E"]
    new_positions = [[1, 1], [2, 2]]
    lattice.add_sites(identifiers=new_names, coordinates=new_positions)

    # Three original sites plus two new ones; the new ones get indices 3 and 4.
    assert lattice.num_sites == 5
    assert lattice.get_identifier(4) == "E"
    np.testing.assert_array_equal(lattice.get_coordinates(3), np.array([1, 1]))
    assert "E" in lattice._ident_to_idx
|
|
1371
|
-
|
|
1372
|
-
def test_remove_sites_successfully(self, initial_lattice):
    """Removes sites "A" and "C" and checks that only "B" remains, re-indexed to 0."""
    lattice = initial_lattice
    assert lattice.num_sites == 3

    doomed = ["A", "C"]
    lattice.remove_sites(identifiers=doomed)

    assert lattice.num_sites == 1
    # Site 'B' is expected to have moved to index 0 with its coordinates intact.
    assert lattice.get_identifier(0) == "B"
    assert "A" not in lattice._ident_to_idx
    np.testing.assert_array_equal(lattice.get_coordinates(0), np.array([1, 0]))
|
|
1386
|
-
|
|
1387
|
-
def test_add_duplicate_identifier_raises_error(self, initial_lattice):
|
|
1388
|
-
"""Tests that adding a site with an existing identifier fails."""
|
|
1389
|
-
with pytest.raises(ValueError, match="Duplicate identifiers found"):
|
|
1390
|
-
initial_lattice.add_sites(identifiers=["A"], coordinates=[[9, 9]])
|
|
1391
|
-
|
|
1392
|
-
def test_remove_nonexistent_identifier_raises_error(self, initial_lattice):
|
|
1393
|
-
"""Tests that removing a non-existent site fails."""
|
|
1394
|
-
with pytest.raises(ValueError, match="Non-existent identifiers provided"):
|
|
1395
|
-
initial_lattice.remove_sites(identifiers=["Z"])
|
|
1396
|
-
|
|
1397
|
-
def test_modification_clears_neighbor_cache(self, initial_lattice):
|
|
1398
|
-
"""
|
|
1399
|
-
Tests that add_sites and remove_sites correctly invalidate the
|
|
1400
|
-
pre-computed neighbor map.
|
|
1401
|
-
"""
|
|
1402
|
-
# Arrange: Pre-compute neighbors on the initial lattice
|
|
1403
|
-
initial_lattice._build_neighbors(max_k=1)
|
|
1404
|
-
assert 0 in initial_lattice._neighbor_maps[1] # Check that neighbors exist
|
|
1405
|
-
|
|
1406
|
-
# Act 1: Add a site
|
|
1407
|
-
initial_lattice.add_sites(identifiers=["D"], coordinates=[[5, 5]])
|
|
1408
|
-
|
|
1409
|
-
# Assert 1: The neighbor map should now be empty
|
|
1410
|
-
assert not initial_lattice._neighbor_maps
|
|
1411
|
-
|
|
1412
|
-
# Arrange 2: Re-compute neighbors and then remove a site
|
|
1413
|
-
initial_lattice._build_neighbors(max_k=1)
|
|
1414
|
-
assert 0 in initial_lattice._neighbor_maps[1]
|
|
1415
|
-
|
|
1416
|
-
# Act 2: Remove a site
|
|
1417
|
-
initial_lattice.remove_sites(identifiers=["A"])
|
|
1418
|
-
|
|
1419
|
-
# Assert 2: The neighbor map should be empty again
|
|
1420
|
-
assert not initial_lattice._neighbor_maps
|
|
1421
|
-
|
|
1422
|
-
def test_modification_clears_distance_matrix_cache(self, initial_lattice):
|
|
1423
|
-
"""
|
|
1424
|
-
Tests that add_sites and remove_sites correctly invalidate the
|
|
1425
|
-
cached distance matrix and that the recomputed matrix is correct.
|
|
1426
|
-
"""
|
|
1427
|
-
# Arrange 1: Compute, cache, and perform a meaningful check on the original matrix.
|
|
1428
|
-
lat = initial_lattice
|
|
1429
|
-
original_matrix = lat.distance_matrix
|
|
1430
|
-
assert lat._distance_matrix is not None
|
|
1431
|
-
assert original_matrix.shape == (3, 3)
|
|
1432
|
-
# Meaningful check: distance from 'A'(idx 0) to 'B'(idx 1) should be 1.0
|
|
1433
|
-
np.testing.assert_allclose(original_matrix[0, 1], 1.0)
|
|
1434
|
-
|
|
1435
|
-
# Act 1: Add a site. This should invalidate the cache.
|
|
1436
|
-
lat.add_sites(identifiers=["D"], coordinates=[[1, 1]])
|
|
1437
|
-
|
|
1438
|
-
# Assert 1: Check cache is cleared and the new matrix is correct.
|
|
1439
|
-
assert lat._distance_matrix is None # Verify cache invalidation
|
|
1440
|
-
new_matrix_added = lat.distance_matrix
|
|
1441
|
-
assert new_matrix_added.shape == (4, 4)
|
|
1442
|
-
# Meaningful check: distance from 'B'(idx 1) to new site 'D'(idx 3) should be 1.0
|
|
1443
|
-
# Coords: B=[1,0], D=[1,1]
|
|
1444
|
-
np.testing.assert_allclose(new_matrix_added[1, 3], 1.0)
|
|
1445
|
-
|
|
1446
|
-
# Act 2: Remove a site. This should also invalidate the cache.
|
|
1447
|
-
lat.remove_sites(identifiers=["A"])
|
|
1448
|
-
|
|
1449
|
-
# Assert 2: Check cache is cleared again and the final matrix is correct.
|
|
1450
|
-
assert lat._distance_matrix is None # Verify cache invalidation
|
|
1451
|
-
final_matrix = lat.distance_matrix
|
|
1452
|
-
assert final_matrix.shape == (3, 3) # Now has 3 sites again
|
|
1453
|
-
# Meaningful check: After removing 'A', the sites are B, C, D.
|
|
1454
|
-
# 'B' is now at index 0 (coords [1,0])
|
|
1455
|
-
# 'C' is now at index 1 (coords [0,1])
|
|
1456
|
-
# 'D' is now at index 2 (coords [1,1])
|
|
1457
|
-
# Distance from new 'B' (idx 0) to new 'D' (idx 2) should be 1.0
|
|
1458
|
-
np.testing.assert_allclose(final_matrix[0, 2], 1.0)
|
|
1459
|
-
|
|
1460
|
-
def test_neighbor_finding_returns_sorted_list(self, simple_square_lattice):
|
|
1461
|
-
"""
|
|
1462
|
-
Ensures that the list of neighbors returned by get_neighbors is always sorted.
|
|
1463
|
-
This provides a stricter check than set-based comparisons.
|
|
1464
|
-
"""
|
|
1465
|
-
# Arrange
|
|
1466
|
-
lattice = simple_square_lattice
|
|
1467
|
-
|
|
1468
|
-
# Act
|
|
1469
|
-
# Get neighbors for the central site (index 1 in a 2x2 grid)
|
|
1470
|
-
# Expected neighbors are 0, 3.
|
|
1471
|
-
neighbors = lattice.get_neighbors(1, k=1)
|
|
1472
|
-
|
|
1473
|
-
# Assert
|
|
1474
|
-
# We compare directly against a pre-sorted list, not a set.
|
|
1475
|
-
# This will fail if the implementation returns [3, 0] instead of [0, 3].
|
|
1476
|
-
assert neighbors == [
|
|
1477
|
-
0,
|
|
1478
|
-
3,
|
|
1479
|
-
], "The neighbor list should be sorted in ascending order."
|
|
1480
|
-
|
|
1481
|
-
|
|
1482
|
-
class TestDistanceMatrix:
|
|
1483
|
-
|
|
1484
|
-
# This is the upgraded, parameterized test.
|
|
1485
|
-
@pytest.mark.parametrize(
|
|
1486
|
-
# We define test scenarios as tuples:
|
|
1487
|
-
# (build_k, check_site_identifier, expected_dist_sq)
|
|
1488
|
-
# build_k: The number of neighbor shells to pre-build.
|
|
1489
|
-
# check_site_identifier: The identifier of a site whose distance from the origin we will check.
|
|
1490
|
-
# expected_dist_sq: The expected squared distance to that site.
|
|
1491
|
-
"build_k, check_site_identifier, expected_dist_sq",
|
|
1492
|
-
[
|
|
1493
|
-
# Scenario 1: The most common case. Build only NN (k=1), but check a NNN (k=2) distance.
|
|
1494
|
-
# A buggy cache would fail this.
|
|
1495
|
-
(1, (1, 1, 0), 2.0),
|
|
1496
|
-
# Scenario 2: Build up to k=2, but check a k=3 distance.
|
|
1497
|
-
(2, (2, 0, 0), 4.0),
|
|
1498
|
-
# Scenario 3: Build up to k=3, but check a k=4 distance.
|
|
1499
|
-
(3, (2, 1, 0), 5.0),
|
|
1500
|
-
# Scenario 4: A more complex, higher-order neighbor.
|
|
1501
|
-
(5, (3, 1, 0), 10.0),
|
|
1502
|
-
],
|
|
1503
|
-
)
|
|
1504
|
-
def test_tilattice_full_pbc_distance_matrix_is_correct_regardless_of_build_k(
|
|
1505
|
-
self, build_k, check_site_identifier, expected_dist_sq
|
|
1506
|
-
):
|
|
1507
|
-
"""
|
|
1508
|
-
Tests that the distance matrix for a fully periodic TILattice is
|
|
1509
|
-
always fully correct, no matter how many neighbor shells were pre-calculated.
|
|
1510
|
-
This is a high-strength test designed to catch subtle caching bugs where
|
|
1511
|
-
the cached matrix might only contain partial information.
|
|
1512
|
-
"""
|
|
1513
|
-
# Arrange
|
|
1514
|
-
# Using a larger, non-square lattice to avoid accidental symmetries
|
|
1515
|
-
lat = SquareLattice(size=(7, 9), pbc=True)
|
|
1516
|
-
|
|
1517
|
-
# Act
|
|
1518
|
-
# Step 1: Pre-build neighbors. This is where a faulty caching
|
|
1519
|
-
# mechanism in the source code might be triggered.
|
|
1520
|
-
lat._build_neighbors(max_k=build_k)
|
|
1521
|
-
|
|
1522
|
-
# Step 2: Access the distance_matrix property. A correct implementation
|
|
1523
|
-
# will return a fully valid matrix.
|
|
1524
|
-
dist_matrix = lat.distance_matrix
|
|
1525
|
-
|
|
1526
|
-
# Assert
|
|
1527
|
-
# Find the indices for the sites we want to check.
|
|
1528
|
-
origin_idx = lat.get_index((0, 0, 0))
|
|
1529
|
-
check_site_idx = lat.get_index(check_site_identifier)
|
|
1530
|
-
|
|
1531
|
-
# The core assertion: check the distance.
|
|
1532
|
-
actual_dist_sq = dist_matrix[origin_idx, check_site_idx] ** 2
|
|
1533
|
-
|
|
1534
|
-
error_message = (
|
|
1535
|
-
f"Distance matrix failed when building k={build_k}. "
|
|
1536
|
-
f"Checking distance to site {check_site_identifier} (expected sq={expected_dist_sq}) "
|
|
1537
|
-
f"but got sq={actual_dist_sq} instead."
|
|
1538
|
-
)
|
|
1539
|
-
|
|
1540
|
-
np.testing.assert_allclose(
|
|
1541
|
-
actual_dist_sq, expected_dist_sq, err_msg=error_message
|
|
1542
|
-
)
|
|
1543
|
-
|
|
1544
|
-
def test_tilattice_mixed_bc_distance_matrix_is_correct(self):
|
|
1545
|
-
"""
|
|
1546
|
-
Tests that the distance matrix is correctly calculated for a TILattice
|
|
1547
|
-
with mixed boundary conditions (e.g., periodic in x, open in y).
|
|
1548
|
-
"""
|
|
1549
|
-
# Arrange
|
|
1550
|
-
# pbc=(True, False) means periodic along x-axis, open along y-axis.
|
|
1551
|
-
lat = SquareLattice(size=(5, 5), pbc=(True, False))
|
|
1552
|
-
|
|
1553
|
-
# Pre-build neighbors to engage the caching logic.
|
|
1554
|
-
lat._build_neighbors(max_k=2)
|
|
1555
|
-
dist_matrix = lat.distance_matrix
|
|
1556
|
-
|
|
1557
|
-
# Assert
|
|
1558
|
-
origin_idx = lat.get_index((0, 0, 0))
|
|
1559
|
-
|
|
1560
|
-
# 1. Test a distance affected by the periodic boundary (x-direction)
|
|
1561
|
-
# The distance between (0,0) and (4,0) should be 1.0 due to PBC wrap-around.
|
|
1562
|
-
pbc_neighbor_idx = lat.get_index((4, 0, 0))
|
|
1563
|
-
np.testing.assert_allclose(dist_matrix[origin_idx, pbc_neighbor_idx], 1.0)
|
|
1564
|
-
|
|
1565
|
-
# 2. Test a distance affected by the open boundary (y-direction)
|
|
1566
|
-
# The distance between (0,0) and (0,4) should be 4.0 as there's no wrap-around.
|
|
1567
|
-
obc_neighbor_idx = lat.get_index((0, 4, 0))
|
|
1568
|
-
np.testing.assert_allclose(dist_matrix[origin_idx, obc_neighbor_idx], 4.0)
|
|
1569
|
-
|
|
1570
|
-
# 3. Test a general, off-axis point.
|
|
1571
|
-
# Distance from (0,0) to (3,3) with x-pbc. The x-distance is min(3, 5-3=2) = 2.
|
|
1572
|
-
# The y-distance is 3. So total distance is sqrt(2^2 + 3^2) = sqrt(13).
|
|
1573
|
-
general_neighbor_idx = lat.get_index((3, 3, 0))
|
|
1574
|
-
np.testing.assert_allclose(
|
|
1575
|
-
dist_matrix[origin_idx, general_neighbor_idx], np.sqrt(13)
|
|
1576
|
-
)
|
|
1577
|
-
|
|
1578
|
-
# --- This list and the following test are now at the correct indentation level ---
|
|
1579
|
-
lattice_instances_for_invariant_test = [
|
|
1580
|
-
SquareLattice(size=(4, 4), pbc=True),
|
|
1581
|
-
SquareLattice(size=(4, 3), pbc=(True, False)), # Mixed BC, non-square
|
|
1582
|
-
HoneycombLattice(size=(3, 3), pbc=True),
|
|
1583
|
-
TriangularLattice(size=(4, 4), pbc=False),
|
|
1584
|
-
CustomizeLattice(
|
|
1585
|
-
dimensionality=2,
|
|
1586
|
-
identifiers=list(range(4)),
|
|
1587
|
-
coordinates=[[0, 0], [1, 1], [0, 1], [1, 0]],
|
|
1588
|
-
),
|
|
1589
|
-
]
|
|
1590
|
-
|
|
1591
|
-
@pytest.mark.parametrize("lattice", lattice_instances_for_invariant_test)
|
|
1592
|
-
def test_distance_matrix_invariants_for_all_lattice_types(self, lattice):
|
|
1593
|
-
"""
|
|
1594
|
-
Tests that the distance matrix for any lattice type adheres to
|
|
1595
|
-
fundamental mathematical properties (invariants): symmetry, zero diagonal,
|
|
1596
|
-
and positive off-diagonal elements.
|
|
1597
|
-
"""
|
|
1598
|
-
# Arrange
|
|
1599
|
-
n = lattice.num_sites
|
|
1600
|
-
if n < 2:
|
|
1601
|
-
pytest.skip("Invariant test requires at least 2 sites.")
|
|
1602
|
-
|
|
1603
|
-
# Act
|
|
1604
|
-
# We call the property directly, without building neighbors first,
|
|
1605
|
-
# to test the on-demand computation path.
|
|
1606
|
-
matrix = lattice.distance_matrix
|
|
1607
|
-
|
|
1608
|
-
# Assert
|
|
1609
|
-
# 1. Symmetry: The matrix must be equal to its transpose.
|
|
1610
|
-
np.testing.assert_allclose(
|
|
1611
|
-
matrix,
|
|
1612
|
-
matrix.T,
|
|
1613
|
-
err_msg=f"Distance matrix for {type(lattice).__name__} is not symmetric.",
|
|
1614
|
-
)
|
|
1615
|
-
|
|
1616
|
-
# 2. Zero Diagonal: All diagonal elements must be zero.
|
|
1617
|
-
np.testing.assert_allclose(
|
|
1618
|
-
np.diag(matrix),
|
|
1619
|
-
np.zeros(n),
|
|
1620
|
-
err_msg=f"Diagonal of distance matrix for {type(lattice).__name__} is not zero.",
|
|
1621
|
-
)
|
|
1622
|
-
|
|
1623
|
-
# 3. Positive Off-diagonal: All non-diagonal elements must be > 0.
|
|
1624
|
-
# We create a boolean mask for the off-diagonal elements.
|
|
1625
|
-
off_diagonal_mask = ~np.eye(n, dtype=bool)
|
|
1626
|
-
assert np.all(
|
|
1627
|
-
matrix[off_diagonal_mask] > 1e-9
|
|
1628
|
-
), f"Found non-positive off-diagonal elements in distance matrix for {type(lattice).__name__}."
|
|
1629
|
-
|
|
1630
|
-
|
|
1631
|
-
# @pytest.mark.slow
|
|
1632
|
-
# class TestPerformance:
|
|
1633
|
-
# def test_pbc_implementation_is_not_significantly_slower_than_obc(self):
|
|
1634
|
-
# """
|
|
1635
|
-
# A performance regression test.
|
|
1636
|
-
# It ensures that the specialized implementation for fully periodic
|
|
1637
|
-
# lattices (pbc=True) is not substantially slower than the general
|
|
1638
|
-
# implementation used for open boundaries (pbc=False).
|
|
1639
|
-
# This test will FAIL with the current code, exposing the performance bug.
|
|
1640
|
-
# """
|
|
1641
|
-
# # Arrange: Use a large-enough lattice to make performance differences apparent
|
|
1642
|
-
# size = (30, 30)
|
|
1643
|
-
# k = 1
|
|
1644
|
-
|
|
1645
|
-
# # Act 1: Measure the execution time of the general (OBC) implementation
|
|
1646
|
-
# start_time_obc = time.time()
|
|
1647
|
-
# _ = SquareLattice(size=size, pbc=False, precompute_neighbors=k)
|
|
1648
|
-
# duration_obc = time.time() - start_time_obc
|
|
1649
|
-
|
|
1650
|
-
# # Act 2: Measure the execution time of the specialized (PBC) implementation
|
|
1651
|
-
# start_time_pbc = time.time()
|
|
1652
|
-
# _ = SquareLattice(size=size, pbc=True, precompute_neighbors=k)
|
|
1653
|
-
# duration_pbc = time.time() - start_time_pbc
|
|
1654
|
-
|
|
1655
|
-
# print(
|
|
1656
|
-
# f"\n[Performance] OBC ({size}): {duration_obc:.4f}s | PBC ({size}): {duration_pbc:.4f}s"
|
|
1657
|
-
# )
|
|
1658
|
-
|
|
1659
|
-
# # Assert: The PBC implementation should not be drastically slower.
|
|
1660
|
-
# # We allow it to be up to 3 times slower to account for minor overheads,
|
|
1661
|
-
# # but this will catch the current 10x+ regression.
|
|
1662
|
-
# # THIS ASSERTION WILL FAIL with the current buggy code.
|
|
1663
|
-
# assert duration_pbc < duration_obc * 5, (
|
|
1664
|
-
# "The specialized PBC implementation is significantly slower "
|
|
1665
|
-
# "than the general-purpose implementation."
|
|
1666
|
-
# )
|