tensorcircuit-nightly 1.3.0.dev20250809__py3-none-any.whl → 1.3.0.dev20250810__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tensorcircuit-nightly might be problematic.

Files changed (37)
  1. tensorcircuit/__init__.py +1 -1
  2. {tensorcircuit_nightly-1.3.0.dev20250809.dist-info → tensorcircuit_nightly-1.3.0.dev20250810.dist-info}/METADATA +1 -1
  3. {tensorcircuit_nightly-1.3.0.dev20250809.dist-info → tensorcircuit_nightly-1.3.0.dev20250810.dist-info}/RECORD +6 -37
  4. {tensorcircuit_nightly-1.3.0.dev20250809.dist-info → tensorcircuit_nightly-1.3.0.dev20250810.dist-info}/top_level.txt +0 -1
  5. tests/__init__.py +0 -0
  6. tests/conftest.py +0 -67
  7. tests/test_backends.py +0 -1156
  8. tests/test_calibrating.py +0 -149
  9. tests/test_channels.py +0 -409
  10. tests/test_circuit.py +0 -1713
  11. tests/test_cloud.py +0 -219
  12. tests/test_compiler.py +0 -147
  13. tests/test_dmcircuit.py +0 -555
  14. tests/test_ensemble.py +0 -72
  15. tests/test_fgs.py +0 -318
  16. tests/test_gates.py +0 -156
  17. tests/test_hamiltonians.py +0 -159
  18. tests/test_interfaces.py +0 -557
  19. tests/test_keras.py +0 -160
  20. tests/test_lattice.py +0 -1750
  21. tests/test_miscs.py +0 -304
  22. tests/test_mpscircuit.py +0 -341
  23. tests/test_noisemodel.py +0 -156
  24. tests/test_qaoa.py +0 -86
  25. tests/test_qem.py +0 -152
  26. tests/test_quantum.py +0 -549
  27. tests/test_quantum_attr.py +0 -42
  28. tests/test_results.py +0 -379
  29. tests/test_shadows.py +0 -160
  30. tests/test_simplify.py +0 -46
  31. tests/test_stabilizer.py +0 -226
  32. tests/test_templates.py +0 -218
  33. tests/test_timeevol.py +0 -641
  34. tests/test_torchnn.py +0 -99
  35. tests/test_van.py +0 -102
  36. {tensorcircuit_nightly-1.3.0.dev20250809.dist-info → tensorcircuit_nightly-1.3.0.dev20250810.dist-info}/WHEEL +0 -0
  37. {tensorcircuit_nightly-1.3.0.dev20250809.dist-info → tensorcircuit_nightly-1.3.0.dev20250810.dist-info}/licenses/LICENSE +0 -0
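Taken together, the file list shows that this release only bumps the version string in tensorcircuit/__init__.py and stops shipping the stray top-level tests package inside the wheel (hence the -1 in top_level.txt and the much shorter RECORD). A minimal sketch to confirm the cleanup after installing the newer wheel; the importlib check below is our illustration, not part of the package:

    # check_wheel.py (hypothetical helper, not shipped with the package)
    import importlib.util

    import tensorcircuit as tc

    print("tensorcircuit version:", tc.__version__)

    # In 1.3.0.dev20250809 the wheel installed a top-level "tests" package
    # (see top_level.txt above), so find_spec("tests") could resolve to it;
    # on 1.3.0.dev20250810 this should print None.
    print(importlib.util.find_spec("tests"))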
tests/test_backends.py DELETED
@@ -1,1156 +0,0 @@
- # pylint: disable=invalid-name
-
- import sys
- import os
- from functools import partial
-
- os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
-
- import numpy as np
- import pytest
- from pytest_lazyfixture import lazy_fixture as lf
- import scipy
- import tensorflow as tf
-
- thisfile = os.path.abspath(__file__)
- modulepath = os.path.dirname(os.path.dirname(thisfile))
-
- sys.path.insert(0, modulepath)
- import tensorcircuit as tc
-
- dtype = np.complex64
-
-
- def universal_vmap():
-     def sum_real(x, y):
-         return tc.backend.real(x + y)
-
-     vop = tc.backend.vmap(sum_real, vectorized_argnums=(0, 1))
-     t = tc.gates.array_to_tensor(np.ones([20, 1]))
-     return vop(t, 2.0 * t)
-
-
- def test_vmap_np():
-     r = universal_vmap()
-     assert r.shape == (20, 1)
-
-
- def test_vmap_jax(jaxb):
-     r = universal_vmap()
-     assert r.shape == (20, 1)
-
-
- def test_vmap_tf(tfb):
-     r = universal_vmap()
-     assert r.numpy()[0, 0] == 3.0
-
-
- def test_vmap_torch(torchb):
-     r = universal_vmap()
-     assert r.numpy()[0, 0] == 3.0
-
-
- def test_grad_torch(torchb):
-     a = tc.backend.ones([2], dtype="float32")
-
-     # @partial(tc.backend.jit, jit_compile=True)
-     @tc.backend.grad
-     def f(x):
-         return tc.backend.sum(x)
-
-     np.testing.assert_allclose(f(a), np.ones([2]), atol=1e-5)
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
- def test_sparse_csr_from_coo(backend):
-     # Create a sparse matrix in COO format
-     values = tc.backend.convert_to_tensor(np.array([1.0, 2.0, 3.0]))
-     values = tc.backend.cast(values, "complex64")
-     indices = tc.backend.convert_to_tensor(np.array([[0, 0], [1, 1], [2, 3]]))
-     indices = tc.backend.cast(indices, "int64")
-     coo_matrix = tc.backend.coo_sparse_matrix(indices, values, shape=[4, 4])
-
-     # Convert COO to CSR
-     csr_matrix = tc.backend.sparse_csr_from_coo(coo_matrix)
-
-     # Check that the result is still recognized as sparse
-     assert tc.backend.is_sparse(csr_matrix) is True
-
-     # Check that the conversion preserves values by comparing dense representations
-     coo_dense = tc.backend.to_dense(coo_matrix)
-     csr_dense = tc.backend.to_dense(csr_matrix)
-     np.testing.assert_allclose(coo_dense, csr_dense, atol=1e-5)
-
-
- def test_sparse_tensor_matmul_monkey_patch(tfb):
-     """
-     Test the monkey-patched __matmul__ method for tf.SparseTensor.
-     This test specifically targets the line:
-     tf.SparseTensor.__matmul__ = sparse_tensor_matmul
-     """
-     # Create a sparse matrix in COO format
-     indices = tf.constant([[0, 0], [1, 1], [2, 3]], dtype=tf.int64)
-     values = tf.constant([1.0, 2.0, 3.0], dtype=tf.complex64)
-     shape = [4, 4]
-     sparse_matrix = tf.SparseTensor(indices=indices, values=values, dense_shape=shape)
-
-     # Test 1: Matrix-vector multiplication with 1D vector
-     vector_1d = tf.constant([1.0, 2.0, 3.0, 4.0], dtype=tf.complex64)
-     result_1d = sparse_matrix @ vector_1d  # Using the monkey-patched @ operator
-
-     expected_1d = tf.constant([1.0, 4.0, 12.0, 0.0], dtype=tf.complex64)
-
-     np.testing.assert_allclose(result_1d, expected_1d, atol=1e-6)
-     vector_1d = tc.backend.reshape(vector_1d, [4, 1])
-     result_1dn = sparse_matrix @ vector_1d  # Using the monkey-patched @ operator
-     expected_1d = tc.backend.reshape(expected_1d, [4, 1])
-
-     np.testing.assert_allclose(result_1dn, expected_1d, atol=1e-6)
-
-     # Test 2: Matrix-matrix multiplication with 2D matrix
-     matrix_2d = tf.constant(
-         [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]], dtype=tf.complex64
-     )
-     result_2d = sparse_matrix @ matrix_2d  # Using the monkey-patched @ operator
-
-     expected_2d = tf.sparse.sparse_dense_matmul(sparse_matrix, matrix_2d)
-
-     np.testing.assert_allclose(result_2d.numpy(), expected_2d.numpy(), atol=1e-6)
-
-     # Test 3: Verify that the operation is consistent with sparse_dense_matmul
-
-     reference_result = tc.backend.sparse_dense_matmul(sparse_matrix, vector_1d)
-     reference_result_squeezed = tc.backend.reshape(reference_result, [-1])
-
-     np.testing.assert_allclose(result_1d, reference_result_squeezed, atol=1e-6)
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("jaxb")])
- def test_backend_jv(backend, highp):
-     def calculate_M(k, x_val):
-         safety_factor = 15
-         M = max(k, int(abs(x_val))) + int(safety_factor * np.sqrt(abs(x_val)))
-         M = max(M, k + 30)
-         return M
-
-     k_values = [5, 20, 50, 200, 500, 3000]
-     x_values = [0.0, 0.1, 1.0, 10.0, 100, 1000, 6000]
-     for k in k_values:
-         for x_val in x_values:
-             M = calculate_M(k, x_val)
-             f_vals = tc.backend.special_jv(k, x_val, M)
-             np.testing.assert_allclose(
-                 f_vals, scipy.special.jv(np.arange(k), x_val), atol=1e-6
-             )
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("jaxb")])
- def test_backend_jaxy_scan(backend):
-     def body_fun(carry, x):
-         counter, decrementor = carry
-
-         # update the state
-         new_counter = counter + 1
-         new_decrementor = decrementor - 1
-         new_carry = (new_counter, new_decrementor)
-
-         y = counter + decrementor
-
-         return new_carry, y
-
-     init_val = (0, 100)
-
-     final_carry, stacked_ys = tc.backend.jaxy_scan(
-         f=body_fun,
-         init=init_val,
-         xs=tc.backend.arange(5),
-     )
-
-     expected_final_carry = (5, 95)
-     expected_stacked_ys = np.array([100, 100, 100, 100, 100])
-
-     assert final_carry == expected_final_carry
-
-     np.testing.assert_array_equal(np.asarray(stacked_ys), expected_stacked_ys)
-
-
- def test_backend_jv_grad(jaxb, highp):
-     def f(x):
-         return tc.backend.sum(tc.backend.special_jv(5, x, 100))
-
-     print(tc.backend.jit(tc.backend.value_and_grad(f))(0.2))
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
- def test_backend_scatter(backend):
-     np.testing.assert_allclose(
-         tc.backend.scatter(
-             tc.array_to_tensor(np.arange(8), dtype="int32"),
-             tc.array_to_tensor(np.array([[1], [4]]), dtype="int32"),
-             tc.array_to_tensor(np.array([0, 0]), dtype="int32"),
-         ),
-         np.array([0, 0, 2, 3, 0, 5, 6, 7]),
-         atol=1e-4,
-     )
-     np.testing.assert_allclose(
-         tc.backend.scatter(
-             tc.array_to_tensor(np.arange(8).reshape([2, 4]), dtype="int32"),
-             tc.array_to_tensor(np.array([[0, 2], [1, 2], [1, 3]]), dtype="int32"),
-             tc.array_to_tensor(np.array([0, 99, 0]), dtype="int32"),
-         ),
-         np.array([[0, 1, 0, 3], [4, 5, 99, 0]]),
-         atol=1e-4,
-     )
-     answer = np.arange(8).reshape([2, 2, 2])
-     answer[0, 1, 0] = 99
-     np.testing.assert_allclose(
-         tc.backend.scatter(
-             tc.array_to_tensor(np.arange(8).reshape([2, 2, 2]), dtype="int32"),
-             tc.array_to_tensor(np.array([[0, 1, 0]]), dtype="int32"),
-             tc.array_to_tensor(np.array([99]), dtype="int32"),
-         ),
-         answer,
-         atol=1e-4,
-     )
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
- def test_backend_methods(backend):
-     # TODO(@refraction-ray): add more methods
-     np.testing.assert_allclose(
-         tc.backend.softmax(tc.array_to_tensor(np.ones([3, 2]), dtype="float32")),
-         np.ones([3, 2]) / 6.0,
-         atol=1e-4,
-     )
-
-     arr = np.random.normal(size=(6, 6))
-
-     np.testing.assert_allclose(
-         tc.backend.adjoint(tc.array_to_tensor(arr + 1.0j * arr)),
-         arr.T - 1.0j * arr.T,
-         atol=1e-4,
-     )
-
-     arr = tc.backend.zeros([5], dtype="float32")
-     np.testing.assert_allclose(
-         tc.backend.sigmoid(arr),
-         tc.backend.ones([5]) * 0.5,
-         atol=1e-4,
-     )
-     ans = np.array([[1, 0.5j], [-0.5j, 1]])
-     ans2 = ans @ ans
-     ansp = tc.backend.sqrtmh(tc.array_to_tensor(ans2))
-     # print(ansp @ ansp, ans @ ans)
-     np.testing.assert_allclose(ansp @ ansp, ans @ ans, atol=1e-4)
-     singularm = np.array([[4.0, 0], [0, -1e-3]])
-     np.testing.assert_allclose(
-         tc.backend.sqrtmh(singularm, psd=True), np.array([[2.0, 0], [0, 0]]), atol=1e-5
-     )
-
-     np.testing.assert_allclose(
-         tc.backend.sum(tc.array_to_tensor(np.arange(4))), 6, atol=1e-4
-     )
-
-     indices = np.array([[1, 2], [0, 1]])
-     ans = np.array([[[0, 1, 0], [0, 0, 1]], [[1, 0, 0], [0, 1, 0]]])
-     np.testing.assert_allclose(tc.backend.one_hot(indices, 3), ans, atol=1e-4)
-
-     a = tc.array_to_tensor(np.array([1, 1, 3, 2, 2, 1]), dtype="int32")
-     np.testing.assert_allclose(tc.backend.unique_with_counts(a)[0].shape[0], 3)
-
-     np.testing.assert_allclose(
-         tc.backend.cumsum(tc.array_to_tensor(np.array([[0.2, 0.2], [0.2, 0.4]]))),
-         np.array([0.2, 0.4, 0.6, 1.0]),
-         atol=1e-4,
-     )
-
-     np.testing.assert_allclose(
-         tc.backend.max(tc.backend.ones([2, 2], "float32")), 1.0, atol=1e-4
-     )
-     np.testing.assert_allclose(
-         tc.backend.min(
-             tc.backend.cast(
-                 tc.backend.convert_to_tensor(np.array([[1.0, 2.0], [2.0, 3.0]])),
-                 "float64",
-             ),
-             axis=1,
-         ),
-         np.array([1.0, 2.0]),
-         atol=1e-4,
-     )  # by default no keepdim
-
-     np.testing.assert_allclose(
-         tc.backend.concat([tc.backend.ones([2, 2]), tc.backend.ones([1, 2])]),
-         tc.backend.ones([3, 2]),
-         atol=1e-5,
-     )
-
-     np.testing.assert_allclose(
-         tc.backend.gather1d(
-             tc.array_to_tensor(np.array([0, 1, 2])),
-             tc.array_to_tensor(np.array([2, 1, 0]), dtype="int32"),
-         ),
-         np.array([2, 1, 0]),
-         atol=1e-5,
-     )
-
-     def sum_(carry, x):
-         return carry + x
-
-     r = tc.backend.scan(sum_, tc.backend.ones([10, 2]), tc.backend.zeros([2]))
-     np.testing.assert_allclose(r, 10 * np.ones([2]), atol=1e-5)
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
- def test_backend_methods_2(backend):
-     np.testing.assert_allclose(tc.backend.mean(tc.backend.ones([10])), 1.0, atol=1e-5)
-     # acos acosh asin asinh atan atan2 atanh cosh (cos) tan tanh sinh (sin)
-     np.testing.assert_allclose(
-         tc.backend.acos(tc.backend.ones([2], dtype="float32")),
-         np.arccos(tc.backend.ones([2])),
-         atol=1e-5,
-     )
-     np.testing.assert_allclose(
-         tc.backend.acosh(tc.backend.ones([2], dtype="float32")),
-         np.arccosh(tc.backend.ones([2])),
-         atol=1e-5,
-     )
-     np.testing.assert_allclose(
-         tc.backend.asin(tc.backend.ones([2], dtype="float32")),
-         np.arcsin(tc.backend.ones([2])),
-         atol=1e-5,
-     )
-     np.testing.assert_allclose(
-         tc.backend.asinh(tc.backend.ones([2], dtype="float32")),
-         np.arcsinh(tc.backend.ones([2])),
-         atol=1e-5,
-     )
-     np.testing.assert_allclose(
-         tc.backend.atan(0.5 * tc.backend.ones([2], dtype="float32")),
-         np.arctan(0.5 * tc.backend.ones([2])),
-         atol=1e-5,
-     )
-     np.testing.assert_allclose(
-         tc.backend.atan2(
-             tc.backend.ones([1], dtype="float32"), tc.backend.ones([1], dtype="float32")
-         ),
-         np.arctan2(
-             tc.backend.ones([1], dtype="float32"), tc.backend.ones([1], dtype="float32")
-         ),
-         atol=1e-5,
-     )
-     np.testing.assert_allclose(
-         tc.backend.atanh(0.5 * tc.backend.ones([2], dtype="float32")),
-         np.arctanh(0.5 * tc.backend.ones([2])),
-         atol=1e-5,
-     )
-     np.testing.assert_allclose(
-         tc.backend.cosh(tc.backend.ones([2], dtype="float32")),
-         np.cosh(tc.backend.ones([2])),
-         atol=1e-5,
-     )
-     np.testing.assert_allclose(
-         tc.backend.tan(tc.backend.ones([2], dtype="float32")),
-         np.tan(tc.backend.ones([2])),
-         atol=1e-5,
-     )
-     np.testing.assert_allclose(
-         tc.backend.tanh(tc.backend.ones([2], dtype="float32")),
-         np.tanh(tc.backend.ones([2])),
-         atol=1e-5,
-     )
-     np.testing.assert_allclose(
-         tc.backend.sinh(0.5 * tc.backend.ones([2], dtype="float32")),
-         np.sinh(0.5 * tc.backend.ones([2])),
-         atol=1e-5,
-     )
-     np.testing.assert_allclose(
-         tc.backend.eigvalsh(tc.backend.ones([2, 2])), np.array([0, 2]), atol=1e-5
-     )
-     np.testing.assert_allclose(
-         tc.backend.left_shift(
-             tc.backend.convert_to_tensor(np.array([4, 3])),
-             tc.backend.convert_to_tensor(np.array([1, 1])),
-         ),
-         np.array([8, 6]),
-     )
-     np.testing.assert_allclose(
-         tc.backend.right_shift(
-             tc.backend.convert_to_tensor(np.array([4, 3])),
-             tc.backend.convert_to_tensor(np.array([1, 1])),
-         ),
-         np.array([2, 1]),
-     )
-     np.testing.assert_allclose(
-         tc.backend.mod(
-             tc.backend.convert_to_tensor(np.array([4, 3])),
-             tc.backend.convert_to_tensor(np.array([2, 2])),
-         ),
-         np.array([0, 1]),
-     )
-     np.testing.assert_allclose(
-         tc.backend.arange(3),
-         np.array([0, 1, 2]),
-     )
-     np.testing.assert_allclose(
-         tc.backend.arange(1, 5, 2),
-         np.array([1, 3]),
-     )
-     assert tc.backend.dtype(tc.backend.ones([])) == "complex64"
-     edges = [-1, 3.3, 9.1, 10.0]
-     values = tc.backend.convert_to_tensor(np.array([0.0, 4.1, 12.0], dtype=np.float32))
-     r = tc.backend.numpy(tc.backend.searchsorted(edges, values))
-     np.testing.assert_allclose(r, np.array([1, 2, 4]))
-     p = tc.backend.convert_to_tensor(
-         np.array(
-             [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.2, 0.4], dtype=np.float32
-         )
-     )
-     r = tc.backend.probability_sample(10000, p, status=np.random.uniform(size=[10000]))
-     _, r = np.unique(r, return_counts=True)
-     np.testing.assert_allclose(
-         r - tc.backend.numpy(p) * 10000.0, np.zeros([10]), atol=200, rtol=1
-     )
-     np.testing.assert_allclose(
-         tc.backend.std(tc.backend.cast(tc.backend.arange(1, 4), "float32")),
-         0.81649658,
-         atol=1e-5,
-     )
-     arr = np.random.normal(size=(6, 6))
-     np.testing.assert_allclose(
-         tc.backend.relu(tc.array_to_tensor(arr, dtype="float32")),
-         np.maximum(arr, 0),
-         atol=1e-4,
-     )
-     np.testing.assert_allclose(
-         tc.backend.det(tc.backend.convert_to_tensor(np.eye(3) * 2)), 8, atol=1e-5
-     )
-
-
- # @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
- # def test_backend_array(backend):
- #     a = tc.backend.array([[0, 1], [1, 0]])
- #     assert tc.interfaces.which_backend(a).name == tc.backend.name
- #     a = tc.backend.array([[0, 1], [1, 0]], dtype=tc.rdtypestr)
- #     assert tc.dtype(a) == "float32"
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
- def test_device_cpu_only(backend):
-     a = tc.backend.ones([])
-     dev_str = tc.backend.device(a)
-     assert dev_str in ["cpu", "gpu:0"]
-     tc.backend.device_move(a, dev_str)
-
-
- @pytest.mark.skipif(
-     len(tf.config.list_physical_devices()) == 1, reason="no GPU detected"
- )
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
- def test_device_cpu_gpu(backend):
-     a = tc.backend.ones([])
-     a1 = tc.backend.device_move(a, "gpu:0")
-     dev_str = tc.backend.device(a1)
-     assert dev_str == "gpu:0"
-     a2 = tc.backend.device_move(a1, "cpu")
-     assert tc.backend.device(a2) == "cpu"
-
-
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
- def test_dlpack(backend):
-     a = tc.backend.ones([2, 2], dtype="float64")
-     cap = tc.backend.to_dlpack(a)
-     a1 = tc.backend.from_dlpack(cap)
-     np.testing.assert_allclose(a, a1, atol=1e-5)
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
- def test_arg_cmp(backend):
-     np.testing.assert_allclose(tc.backend.argmax(tc.backend.ones([3], "float64")), 0)
-     np.testing.assert_allclose(
-         tc.backend.argmax(
-             tc.array_to_tensor(np.array([[1, 2], [3, 4]]), dtype="float64")
-         ),
-         np.array([1, 1]),
-     )
-     np.testing.assert_allclose(
-         tc.backend.argmin(
-             tc.array_to_tensor(np.array([[1, 2], [3, 4]]), dtype="float64"), axis=-1
-         ),
-         np.array([0, 0]),
-     )
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
- def test_tree_map(backend):
-     def f(a, b):
-         return a + b
-
-     r = tc.backend.tree_map(
-         f, {"a": tc.backend.ones([2])}, {"a": 2 * tc.backend.ones([2])}
-     )
-     np.testing.assert_allclose(r["a"], 3 * np.ones([2]), atol=1e-4)
-
-     def _add(a, b):
-         return a + b
-
-     ans = tc.backend.tree_map(
-         _add,
-         {"a": tc.backend.ones([2]), "b": tc.backend.ones([3])},
-         {"a": tc.backend.ones([2]), "b": tc.backend.ones([3])},
-     )
-     np.testing.assert_allclose(ans["a"], 2 * np.ones([2]))
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
- def test_backend_randoms(backend):
-     @partial(tc.backend.jit, static_argnums=0)
-     def random_matrixn(key):
-         tc.backend.set_random_state(key)
-         r1 = tc.backend.implicit_randn(shape=[2, 2], mean=0.5)
-         r2 = tc.backend.implicit_randn(shape=[2, 2], mean=0.5)
-         return r1, r2
-
-     key = 42
-     if tc.backend.name == "tensorflow":
-         key = tf.random.Generator.from_seed(42)
-     r11, r12 = random_matrixn(key)
-     if tc.backend.name == "tensorflow":
-         key = tf.random.Generator.from_seed(42)
-     r21, r22 = random_matrixn(key)
-     np.testing.assert_allclose(r11, r21, atol=1e-4)
-     np.testing.assert_allclose(r12, r22, atol=1e-4)
-     assert not np.allclose(r11, r12, atol=1e-4)
-
-     def random_matrixu(key):
-         tc.backend.set_random_state(key)
-         r1 = tc.backend.implicit_randu(shape=[2, 2], high=2)
-         r2 = tc.backend.implicit_randu(shape=[2, 2], high=1)
-         return r1, r2
-
-     key = 42
-     r31, r32 = random_matrixu(key)
-     np.testing.assert_allclose(r31.shape, [2, 2])
-     assert np.any(r32 > 0)
-     assert not np.allclose(r31, r32, atol=1e-4)
-
-     def random_matrixc(key):
-         tc.backend.set_random_state(key)
-         r1 = tc.backend.implicit_randc(a=[1, 2, 3], shape=(2, 2))
-         r2 = tc.backend.implicit_randc(a=[1, 2, 3], shape=(2, 2), p=[0.1, 0.4, 0.5])
-         return r1, r2
-
-     r41, r42 = random_matrixc(key)
-     np.testing.assert_allclose(r41.shape, [2, 2])
-     assert np.any((r42 > 0) & (r42 < 4))
-
-
- def vqe_energy(inputs, param, n, nlayers):
-     c = tc.Circuit(n, inputs=inputs)
-     paramc = tc.backend.cast(param, "complex64")
-
-     for i in range(n):
-         c.H(i)
-     for j in range(nlayers):
-         for i in range(n - 1):
-             c.ryy(i, i + 1, theta=paramc[2 * j, i])
-             # c.any(
-             #     i,
-             #     i + 1,
-             #     unitary=tc.backend.cos(paramc[2 * j, i]) * iir
-             #     + tc.backend.sin(paramc[2 * j, i]) * 1.0j * yzr,
-             # )
-         for i in range(n):
-             c.rx(i, theta=paramc[2 * j + 1, i])
-     e = 0.0
-     for i in range(n):
-         e += c.expectation((tc.gates.x(), [i]))
-     for i in range(n - 1):  # OBC
-         e += c.expectation((tc.gates.z(), [i]), (tc.gates.z(), [(i + 1) % n]))
-     e = tc.backend.real(e)
-     return e
-
-
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
- def test_vvag(backend):
-     n = 4
-     nlayers = 3
-     inp = tc.backend.ones([2**n]) / 2 ** (n / 2)
-     param = tc.backend.ones([2 * nlayers, n])
-     # inp = tc.backend.cast(inp, "complex64")
-     # param = tc.backend.cast(param, "complex64")
-
-     vqe_energy_p = partial(vqe_energy, n=n, nlayers=nlayers)
-
-     vg = tc.backend.value_and_grad(vqe_energy_p, argnums=(0, 1))
-     v0, (g00, g01) = vg(inp, param)
-
-     batch = 8
-     inps = tc.backend.ones([batch, 2**n]) / 2 ** (n / 2)
-     inps = tc.backend.cast(inps, "complex64")
-
-     pvag = tc.backend.vvag(vqe_energy_p, argnums=(0, 1))
-     v1, (g10, g11) = pvag(inps, param)
-     print(v1.shape, g10.shape, g11.shape)
-     np.testing.assert_allclose(v1[0], v0, atol=1e-4)
-     np.testing.assert_allclose(g10[0], g00, atol=1e-4)
-     np.testing.assert_allclose(g11 / batch, g01, atol=1e-4)
-
-
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
- def test_vvag_dict(backend):
-     def dict_plus(x, y):
-         a = x["a"]
-         return tc.backend.real((a + y)[0])
-
-     dp_vvag = tc.backend.vvag(dict_plus, vectorized_argnums=1, argnums=0)
-     x = {"a": tc.backend.ones([1])}
-     y = tc.backend.ones([20, 1])
-     v, g = dp_vvag(x, y)
-     np.testing.assert_allclose(v.shape, [20])
-     np.testing.assert_allclose(g["a"], 20.0, atol=1e-4)
-
-
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
- def test_vjp(backend):
-     def f(x):
-         return x**2
-
-     inputs = tc.backend.ones([2, 2])
-     v, g = tc.backend.vjp(f, inputs, inputs)
-     np.testing.assert_allclose(v, inputs, atol=1e-5)
-     np.testing.assert_allclose(g, 2 * inputs, atol=1e-5)
-
-     def f2(x, y):
-         return x + y, x - y
-
-     inputs = [tc.backend.ones([2]), tc.backend.ones([2])]
-     v = [2.0 * t for t in inputs]
-     v, g = tc.backend.vjp(f2, inputs, v)
-     np.testing.assert_allclose(v[1], np.zeros([2]), atol=1e-5)
-     np.testing.assert_allclose(g[0], 4 * np.ones([2]), atol=1e-5)
-
-
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
- def test_vjp_complex(backend):
-     def f(x):
-         return tc.backend.conj(x)
-
-     inputs = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
-     v = tc.backend.ones([1], dtype="complex64")
-     v, g = tc.backend.vjp(f, inputs, v)
-     np.testing.assert_allclose(tc.backend.numpy(g), np.ones([1]), atol=1e-5)
-
-     def f2(x):
-         return x**2
-
-     inputs = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
-     v = tc.backend.ones([1], dtype="complex64")  # + 1.0j * tc.backend.ones([1])
-     v, g = tc.backend.vjp(f2, inputs, v)
-     # note how the vjp definition for complex functions differs on the jax backend
-     if tc.backend.name == "jax":
-         np.testing.assert_allclose(tc.backend.numpy(g), 2 + 2j, atol=1e-5)
-     else:
-         np.testing.assert_allclose(tc.backend.numpy(g), 2 - 2j, atol=1e-5)
-
-
- # TODO(@refraction-ray): consistent and unified pytree utils for pytorch backend?
-
-
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
- def test_vjp_pytree(backend):
-     def f3(d):
-         return d["a"] + d["b"], d["a"]
-
-     inputs = {"a": tc.backend.ones([2]), "b": tc.backend.ones([1])}
-     v = (tc.backend.ones([2]), tc.backend.zeros([2]))
-     v, g = tc.backend.vjp(f3, inputs, v)
-     np.testing.assert_allclose(v[0], 2 * np.ones([2]), atol=1e-5)
-     np.testing.assert_allclose(g["a"], np.ones([2]), atol=1e-5)
-
-
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
- def test_jvp(backend):
-     def f(x):
-         return x**2
-
-     inputs = tc.backend.ones([2, 2])
-     v, g = tc.backend.jvp(f, inputs, inputs)
-     np.testing.assert_allclose(v, inputs, atol=1e-5)
-     np.testing.assert_allclose(g, 2 * inputs, atol=1e-5)
-
-     def f2(x, y):
-         return x + y, x - y
-
-     inputs = [tc.backend.ones([2]), tc.backend.ones([2])]
-     v = [2.0 * t for t in inputs]
-     v, g = tc.backend.jvp(f2, inputs, v)
-     np.testing.assert_allclose(v[1], np.zeros([2]), atol=1e-5)
-     np.testing.assert_allclose(g[0], 4 * np.ones([2]), atol=1e-5)
-
-
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
- def test_jvp_complex(backend):
-     def f(x):
-         return tc.backend.conj(x)
-
-     inputs = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
-     v = tc.backend.ones([1], dtype="complex64")
-     v, g = tc.backend.jvp(f, inputs, v)
-     # automatic numpy conversion doesn't work for torch conjugate tensors
-     np.testing.assert_allclose(tc.backend.numpy(g), np.ones([1]), atol=1e-5)
-
-     def f2(x):
-         return x**2
-
-     inputs = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
-     v = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
-     v, g = tc.backend.jvp(f2, inputs, v)
-     np.testing.assert_allclose(tc.backend.numpy(g), 4.0j, atol=1e-5)
-
-
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
- def test_jvp_pytree(backend):
-     def f3(d):
-         return d["a"] + d["b"], d["a"]
-
-     inputs = {"a": tc.backend.ones([2]), "b": tc.backend.ones([1])}
-     v = (tc.backend.ones([2]), tc.backend.zeros([2]))
-     v, g = tc.backend.vjp(f3, inputs, v)
-     np.testing.assert_allclose(v[0], 2 * np.ones([2]), atol=1e-5)
-     np.testing.assert_allclose(g["a"], np.ones([2]), atol=1e-5)
-
-
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
- @pytest.mark.parametrize("mode", ["jacfwd", "jacrev"])
- def test_jac(backend, mode):
-     # makes no sense for the torch backend, which has no real vmap interface
-     backend_jac = getattr(tc.backend, mode)
-
-     def f(x):
-         return x**2
-
-     x = tc.backend.ones([3])
-     jacf = backend_jac(f)
-     np.testing.assert_allclose(jacf(x), 2 * np.eye(3), atol=1e-5)
-
-     def f2(x):
-         return x**2, x
-
-     jacf2 = backend_jac(f2)
-     np.testing.assert_allclose(jacf2(x)[1], np.eye(3), atol=1e-5)
-     np.testing.assert_allclose(jacf2(x)[0], 2 * np.eye(3), atol=1e-5)
-
-     def f3(x, y):
-         return x + y**2
-
-     jacf3 = backend_jac(f3, argnums=(0, 1))
-     jacf3jit = tc.backend.jit(backend_jac(f3, argnums=(0, 1)))
-     np.testing.assert_allclose(jacf3jit(x, x)[1], 2 * np.eye(3), atol=1e-5)
-     np.testing.assert_allclose(jacf3(x, x)[1], 2 * np.eye(3), atol=1e-5)
-
-     def f4(x, y):
-         return x**2, y
-
-     # note the subtle difference in the order of the two tuples between jacrev and jacfwd in the current API
-     # the values happen to be the same here, though
-     jacf4 = backend_jac(f4, argnums=(0, 1))
-     jacf4jit = tc.backend.jit(backend_jac(f4, argnums=(0, 1)))
-     np.testing.assert_allclose(jacf4jit(x, x)[1][1], np.eye(3), atol=1e-5)
-     np.testing.assert_allclose(jacf4jit(x, x)[0][1], np.zeros([3, 3]), atol=1e-5)
-     np.testing.assert_allclose(jacf4(x, x)[1][1], np.eye(3), atol=1e-5)
-     np.testing.assert_allclose(jacf4(x, x)[0][1], np.zeros([3, 3]), atol=1e-5)
-
-
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
- @pytest.mark.parametrize("mode", ["jacfwd", "jacrev"])
- def test_jac_md_input(backend, mode):
-     backend_jac = getattr(tc.backend, mode)
-
-     def f(x):
-         return x**2
-
-     x = tc.backend.ones([2, 3])
-     jacf = backend_jac(f)
-     np.testing.assert_allclose(jacf(x).shape, [2, 3, 2, 3], atol=1e-5)
-
-     def f2(x):
-         return tc.backend.sum(x, axis=0)
-
-     x = tc.backend.ones([2, 3])
-     jacf2 = backend_jac(f2)
-     np.testing.assert_allclose(jacf2(x).shape, [3, 2, 3], atol=1e-5)
-
-
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
- @pytest.mark.parametrize("mode", ["jacfwd", "jacrev"])
- def test_jac_tall(backend, mode):
-     backend_jac = getattr(tc.backend, mode)
-
-     h = tc.backend.ones([5, 3])
-
-     def f(x):
-         x = tc.backend.reshape(x, [-1, 1])
-         return tc.backend.reshape(h @ x, [-1])
-
-     x = tc.backend.ones([3])
-     jacf = backend_jac(f)
-     np.testing.assert_allclose(jacf(x), np.ones([5, 3]), atol=1e-5)
-
-
- @pytest.mark.parametrize("backend", [lf("jaxb"), lf("tfb")])
- def test_vvag_has_aux(backend):
-     def f(x):
-         y = tc.backend.sum(x)
-         return tc.backend.real(y**2), y
-
-     fvvag = tc.backend.vvag(f, has_aux=True)
-     (_, v1), _ = fvvag(tc.backend.ones([10, 2]))
-     np.testing.assert_allclose(v1, 2 * tc.backend.ones([10]))
-
-
- def test_jax_svd(jaxb, highp):
-     def l(A):
-         u, _, v, _ = tc.backend.svd(A)
-         return tc.backend.real(u[0, 0] * v[0, 0])
-
-     def numericald(A):
-         eps = 1e-6
-         DA = np.zeros_like(A)
-         for i in range(A.shape[0]):
-             for j in range(A.shape[1]):
-                 dA = np.zeros_like(A)
-                 dA[i, j] = 1
-                 DA[i, j] = (l(A + eps * dA) - l(A)) / eps - 1.0j * (
-                     l(A + eps * 1.0j * dA) - l(A)
-                 ) / eps
-         return DA
-
-     def analyticald(A):
-         A = tc.backend.convert_to_tensor(A)
-         g = tc.backend.grad(l)
-         return g(A)
-
-     for shape in [(2, 2), (3, 3), (2, 3), (4, 2)]:
-         m = np.random.normal(size=shape).astype(
-             np.complex128
-         ) + 1.0j * np.random.normal(size=shape).astype(np.complex128)
-         print(m)
-         np.testing.assert_allclose(numericald(m), analyticald(m), atol=1e-3)
-
-
- @pytest.mark.parametrize("backend", [lf("jaxb"), lf("tfb"), lf("torchb")])
- def test_qr(backend, highp):
-     def get_random_complex(shape):
-         result = np.random.random(shape) + np.random.random(shape) * 1j
-         return tc.backend.convert_to_tensor(result.astype(dtype))
-
-     np.random.seed(0)
-     A1 = get_random_complex((2, 2))
-     A2 = tc.backend.convert_to_tensor(np.array([[1.0, 0.0], [0.0, 0.0]]).astype(dtype))
-     X = get_random_complex((2, 2))
-
-     def func(A, x):
-         x = tc.backend.cast(x, "complex64")
-         Q, R = tc.backend.qr(A + X * x)
-         return tc.backend.real(tc.backend.sum(tc.backend.matmul(Q, R)))
-
-     def grad(A, x):
-         return tc.backend.grad(func, argnums=1)(A, x)
-
-     for A in [A1, A2]:
-         epsilon = tc.backend.convert_to_tensor(1e-3)
-         n_grad = (func(A, epsilon) - func(A, -epsilon)) / (2 * epsilon)
-         a_grad = grad(A, tc.backend.convert_to_tensor(0.0))
-         np.testing.assert_allclose(n_grad, a_grad, atol=1e-3)
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
- def test_sparse_methods(backend):
-     values = tc.backend.convert_to_tensor(np.array([1.0, 2.0]))
-     values = tc.backend.cast(values, "complex64")
-     indices = tc.backend.convert_to_tensor(np.array([[0, 0], [1, 1]]))
-     indices = tc.backend.cast(indices, "int64")
-     spa = tc.backend.coo_sparse_matrix(indices, values, shape=[4, 4])
-     vec = tc.backend.ones([4, 1])
-     da = np.array(
-         [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], dtype=np.complex64
-     )
-     assert tc.backend.is_sparse(spa) is True
-     assert tc.backend.is_sparse(vec) is False
-     np.testing.assert_allclose(
-         tc.backend.to_dense(spa),
-         da,
-         atol=1e-5,
-     )
-     np.testing.assert_allclose(
-         tc.backend.sparse_dense_matmul(spa, vec),
-         np.array([[1], [2], [0], [0]], dtype=np.complex64),
-         atol=1e-5,
-     )
-     spa_np = tc.backend.numpy(spa)
-     np.testing.assert_allclose(spa_np.todense(), da, atol=1e-6)
-     np.testing.assert_allclose(
-         tc.backend.to_dense(tc.backend.coo_sparse_matrix_from_numpy(spa_np)),
-         da,
-         atol=1e-5,
-     )
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
- def test_backend_randoms_v2(backend):
-     g = tc.backend.get_random_state(42)
-     for t in tc.backend.stateful_randc(g, 3, [3]):
-         assert t >= 0
-         assert t < 3
-     key = tc.backend.get_random_state(42)
-     r = []
-     for _ in range(2):
-         key, subkey = tc.backend.random_split(key)
-         r.append(tc.backend.stateful_randc(subkey, 3, [5]))
-     assert tuple(r[0]) != tuple(r[1])
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
- def test_backend_randoms_v3(backend):
-     tc.backend.set_random_state(42)
-     for _ in range(2):
-         r1 = tc.backend.implicit_randu()
-     key = tc.backend.get_random_state(42)
-     for _ in range(2):
-         key, subkey = tc.backend.random_split(key)
-         r2 = tc.backend.stateful_randu(subkey)
-     np.testing.assert_allclose(r1, r2, atol=1e-5)
-
-     @tc.backend.jit
-     def f(key):
-         tc.backend.set_random_state(key)
-         r = []
-         for _ in range(3):
-             r.append(tc.backend.implicit_randu()[0])
-         return r
-
-     @tc.backend.jit
-     def f2(key):
-         r = []
-         for _ in range(3):
-             key, subkey = tc.backend.random_split(key)
-             r.append(tc.backend.stateful_randu(subkey)[0])
-         return r
-
-     key = tc.backend.get_random_state(43)
-     r = f(key)
-     key = tc.backend.get_random_state(43)
-     r1 = f2(key)
-     np.testing.assert_allclose(r[-1], r1[-1], atol=1e-5)
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
- def test_function_level_set(backend):
-     def f(x):
-         return tc.backend.ones([x])
-
-     f_jax_128 = tc.set_function_backend("jax")(tc.set_function_dtype("complex128")(f))
-     # note the order to enable complex128 on the jax backend
-
-     assert f_jax_128(3).dtype.__str__() == "complex128"
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
- def test_function_level_set_contractor(backend):
-     @tc.set_function_contractor("branch")
-     def f():
-         return tc.contractor
-
-     print(f())
-     print(tc.contractor)
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
- def test_with_level_set(backend):
-     with tc.runtime_backend("jax"):
-         with tc.runtime_dtype("complex128"):
-             with tc.runtime_contractor("branch"):
-                 assert tc.backend.ones([2]).dtype.__str__() == "complex128"
-                 print(tc.contractor)
-             print(tc.contractor)
-     print(tc.backend.name)
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
- def test_with_level_set_return(backend):
-     with tc.runtime_backend("jax") as K:
-         assert K.name == "jax"
-
-
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
- def test_grad_has_aux(backend):
-     def f(x):
-         return tc.backend.real(x**2), x**3
-
-     vg = tc.backend.value_and_grad(f, has_aux=True)
-
-     np.testing.assert_allclose(
-         vg(tc.backend.ones([]))[1], 2 * tc.backend.ones([]), atol=1e-5
-     )
-
-     def f2(x):
-         return tc.backend.real(x**2), (x**3, tc.backend.ones([3]))
-
-     gs = tc.backend.grad(f2, has_aux=True)
-     np.testing.assert_allclose(gs(tc.backend.ones([]))[0], 2.0, atol=1e-5)
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("jaxb"), lf("tfb")])
- def test_solve(backend):
-     A = np.array([[2, 1, 0], [1, 2, 0], [0, 0, 1]], dtype=np.float32)
-     A = tc.backend.convert_to_tensor(A)
-     x = np.ones([3, 1], dtype=np.float32)
-     x = tc.backend.convert_to_tensor(x)
-     b = (A @ x)[:, 0]
-     print(A.shape, b.shape)
-     xp = tc.backend.solve(A, b, assume_a="her")
-     np.testing.assert_allclose(xp, x[:, 0], atol=1e-5)
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
- def test_treeutils(backend):
-     d0 = {"a": np.ones([2]), "b": [tc.backend.zeros([]), tc.backend.ones([1, 1])]}
-     leaves, treedef = tc.backend.tree_flatten(d0)
-     d1 = tc.backend.tree_unflatten(treedef, leaves)
-     d2 = tc.backend.tree_map(lambda x: 2 * x, d1)
-     np.testing.assert_allclose(2 * np.ones([1, 1]), d2["b"][1])
-
-
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
- def test_optimizers(backend):
-     if tc.backend.name == "jax":
-         try:
-             import optax
-         except ImportError:
-             pytest.skip("optax is not installed")
-
-     if tc.backend.name == "pytorch":
-         try:
-             import torch
-         except ImportError:
-             pytest.skip("torch is not installed")
-
-     def f(params, n):
-         c = tc.Circuit(n)
-         c = tc.templates.blocks.example_block(c, params["a"])
-         c = tc.templates.blocks.example_block(c, params["b"])
-         return tc.backend.real(c.expectation([tc.gates.x(), [n // 2]]))
-
-     vgs = tc.backend.jit(tc.backend.value_and_grad(f, argnums=0), static_argnums=1)
-
-     def get_opt():
-         if tc.backend.name == "tensorflow":
-             optimizer1 = tf.keras.optimizers.Adam(5e-2)
-             opt = tc.backend.optimizer(optimizer1)
-         elif tc.backend.name == "jax":
-             optimizer2 = optax.adam(5e-2)
-             opt = tc.backend.optimizer(optimizer2)
-         elif tc.backend.name == "pytorch":
-             optimizer3 = partial(torch.optim.Adam, lr=5e-2)
-             opt = tc.backend.optimizer(optimizer3)
-         else:
-             raise ValueError("%s doesn't support optimizer interface" % tc.backend.name)
-         return opt
-
-     n = 3
-     opt = get_opt()
-
-     params = {
-         "a": tc.backend.ones([4, n], dtype="float32"),
-         "b": tc.backend.ones([4, n], dtype="float32"),
-     }
-
-     for _ in range(20):
-         loss, grads = vgs(params, n)
-         params = opt.update(grads, params)
-         print(loss)
-
-     assert loss < -0.7
-
-     def f2(params, n):
-         c = tc.Circuit(n)
-         c = tc.templates.blocks.example_block(c, params)
-         return tc.backend.real(c.expectation([tc.gates.x(), [n // 2]]))
-
-     vgs2 = tc.backend.jit(tc.backend.value_and_grad(f2, argnums=0), static_argnums=1)
-
-     params = tc.backend.ones([4, n], dtype="float32")
-     opt = get_opt()
-
-     for _ in range(20):
-         loss, grads = vgs2(params, n)
-         params = opt.update(grads, params)
-         print(loss)
-
-     assert loss < -0.7
-
-
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
- def test_hessian(backend):
-     # hessian support is currently fragile, with potential issues especially on the tf backend
-     def f(param):
-         return tc.backend.sum(param**2)
-
-     hf = tc.backend.hessian(f)
-     param = tc.backend.ones([2])
-     np.testing.assert_allclose(hf(param), 2 * tc.backend.eye(2), atol=1e-5)
-
-     param = tc.backend.ones([2, 2])
-     assert list(hf(param).shape) == [2, 2, 2, 2]  # possible tf retracing?
-
-     g = tc.templates.graphs.Line1D(5)
-
-     def circuit_f(param):
-         c = tc.Circuit(5)
-         c = tc.templates.blocks.example_block(c, param, nlayers=1)
-         return tc.templates.measurements.heisenberg_measurements(c, g)
-
-     param = tc.backend.ones([10])
-     hf = tc.backend.hessian(circuit_f)
-     print(hf(param))  # still up to a conjugate on the jax and tf backends
-
-
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
- def test_nested_vmap(backend):
-     def f(x, w):
-         c = tc.Circuit(4)
-         for i in range(4):
-             c.rx(i, theta=x[i])
-             c.ry(i, theta=w[i])
-         return tc.backend.stack([c.expectation_ps(z=[i]) for i in range(4)])
-
-     def fa1(*args):
-         r = tc.backend.vmap(f, vectorized_argnums=1)(*args)
-         return r
-
-     def fa2(*args):
-         r = tc.backend.vmap(fa1, vectorized_argnums=0)(*args)
-         return r
-
-     fa2jit = tc.backend.jit(fa2)
-
-     ya = fa2(tc.backend.ones([3, 4]), tc.backend.ones([7, 4]))
-     yajit = fa2jit(tc.backend.ones([3, 4]), tc.backend.ones([7, 4]))
-
-     def fb1(*args):
-         r = tc.backend.vmap(f, vectorized_argnums=0)(*args)
-         return r
-
-     def fb2(*args):
-         r = tc.backend.vmap(fb1, vectorized_argnums=1)(*args)
-         return r
-
-     fb2jit = tc.backend.jit(fb2)
-
-     yb = fb2(tc.backend.ones([3, 4]), tc.backend.ones([7, 4]))
-     ybjit = fb2jit(tc.backend.ones([3, 4]), tc.backend.ones([7, 4]))
-
-     np.testing.assert_allclose(ya, tc.backend.transpose(yb, [1, 0, 2]), atol=1e-5)
-     np.testing.assert_allclose(ya, yajit, atol=1e-5)
-     np.testing.assert_allclose(yajit, tc.backend.transpose(ybjit, [1, 0, 2]), atol=1e-5)
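
The deleted suite above is only removed from the wheel, not from the project; assuming the upstream source layout with tests/ at the repository root, it can still be run from a checkout. A minimal sketch (requires pytest plus the pytest-lazyfixture plugin imported at the top of the file):

    # run_backend_tests.py (hypothetical helper, run from the repository root)
    import sys

    import pytest

    # equivalent to: pytest -q tests/test_backends.py
    sys.exit(pytest.main(["-q", "tests/test_backends.py"]))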