tensorcircuit-nightly 1.3.0.dev20250728__py3-none-any.whl → 1.4.0.dev20251103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tensorcircuit-nightly might be problematic; see the registry's release page for more details.

Files changed (72)
  1. tensorcircuit/__init__.py +5 -1
  2. tensorcircuit/abstractcircuit.py +4 -0
  3. tensorcircuit/analogcircuit.py +413 -0
  4. tensorcircuit/applications/layers.py +1 -1
  5. tensorcircuit/applications/van.py +1 -1
  6. tensorcircuit/backends/abstract_backend.py +312 -5
  7. tensorcircuit/backends/cupy_backend.py +3 -1
  8. tensorcircuit/backends/jax_backend.py +92 -3
  9. tensorcircuit/backends/jax_ops.py +108 -0
  10. tensorcircuit/backends/numpy_backend.py +49 -3
  11. tensorcircuit/backends/pytorch_backend.py +92 -3
  12. tensorcircuit/backends/tensorflow_backend.py +102 -3
  13. tensorcircuit/basecircuit.py +123 -82
  14. tensorcircuit/circuit.py +67 -57
  15. tensorcircuit/cloud/local.py +1 -1
  16. tensorcircuit/cloud/quafu_provider.py +1 -1
  17. tensorcircuit/cloud/tencent.py +1 -1
  18. tensorcircuit/compiler/simple_compiler.py +2 -2
  19. tensorcircuit/cons.py +1 -0
  20. tensorcircuit/densitymatrix.py +16 -11
  21. tensorcircuit/experimental.py +7 -152
  22. tensorcircuit/fgs.py +5 -6
  23. tensorcircuit/gates.py +66 -22
  24. tensorcircuit/keras.py +3 -3
  25. tensorcircuit/mpscircuit.py +109 -61
  26. tensorcircuit/quantum.py +697 -133
  27. tensorcircuit/quditcircuit.py +733 -0
  28. tensorcircuit/quditgates.py +618 -0
  29. tensorcircuit/results/counts.py +45 -31
  30. tensorcircuit/shadows.py +1 -1
  31. tensorcircuit/simplify.py +3 -1
  32. tensorcircuit/stabilizercircuit.py +4 -2
  33. tensorcircuit/templates/blocks.py +2 -2
  34. tensorcircuit/templates/hamiltonians.py +29 -8
  35. tensorcircuit/templates/lattice.py +676 -335
  36. tensorcircuit/timeevol.py +896 -0
  37. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +50 -25
  38. tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
  39. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
  40. tensorcircuit_nightly-1.3.0.dev20250728.dist-info/RECORD +0 -122
  41. tests/__init__.py +0 -0
  42. tests/conftest.py +0 -67
  43. tests/test_backends.py +0 -1035
  44. tests/test_calibrating.py +0 -149
  45. tests/test_channels.py +0 -409
  46. tests/test_circuit.py +0 -1713
  47. tests/test_cloud.py +0 -219
  48. tests/test_compiler.py +0 -147
  49. tests/test_dmcircuit.py +0 -555
  50. tests/test_ensemble.py +0 -72
  51. tests/test_fgs.py +0 -318
  52. tests/test_gates.py +0 -156
  53. tests/test_hamiltonians.py +0 -159
  54. tests/test_interfaces.py +0 -557
  55. tests/test_keras.py +0 -160
  56. tests/test_lattice.py +0 -1666
  57. tests/test_miscs.py +0 -334
  58. tests/test_mpscircuit.py +0 -341
  59. tests/test_noisemodel.py +0 -156
  60. tests/test_qaoa.py +0 -86
  61. tests/test_qem.py +0 -152
  62. tests/test_quantum.py +0 -549
  63. tests/test_quantum_attr.py +0 -42
  64. tests/test_results.py +0 -379
  65. tests/test_shadows.py +0 -160
  66. tests/test_simplify.py +0 -46
  67. tests/test_stabilizer.py +0 -226
  68. tests/test_templates.py +0 -218
  69. tests/test_torchnn.py +0 -99
  70. tests/test_van.py +0 -102
  71. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +0 -0
  72. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/licenses/LICENSE +0 -0
tests/test_backends.py DELETED
@@ -1,1035 +0,0 @@
1
- # pylint: disable=invalid-name
2
-
3
- import sys
4
- import os
5
- from functools import partial
6
-
7
- os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
8
-
9
- import numpy as np
10
- import pytest
11
- from pytest_lazyfixture import lazy_fixture as lf
12
- import tensorflow as tf
13
-
14
- thisfile = os.path.abspath(__file__)
15
- modulepath = os.path.dirname(os.path.dirname(thisfile))
16
-
17
- sys.path.insert(0, modulepath)
18
- import tensorcircuit as tc
19
-
20
- dtype = np.complex64
21
-
22
-
23
- def universal_vmap():
24
- def sum_real(x, y):
25
- return tc.backend.real(x + y)
26
-
27
- vop = tc.backend.vmap(sum_real, vectorized_argnums=(0, 1))
28
- t = tc.gates.array_to_tensor(np.ones([20, 1]))
29
- return vop(t, 2.0 * t)
30
-
31
-
32
- def test_vmap_np():
33
- r = universal_vmap()
34
- assert r.shape == (20, 1)
35
-
36
-
37
- def test_vmap_jax(jaxb):
38
- r = universal_vmap()
39
- assert r.shape == (20, 1)
40
-
41
-
42
- def test_vmap_tf(tfb):
43
- r = universal_vmap()
44
- assert r.numpy()[0, 0] == 3.0
45
-
46
-
47
- def test_vmap_torch(torchb):
48
- r = universal_vmap()
49
- assert r.numpy()[0, 0] == 3.0
50
-
51
-
52
- def test_grad_torch(torchb):
53
- a = tc.backend.ones([2], dtype="float32")
54
-
55
- # @partial(tc.backend.jit, jit_compile=True)
56
- @tc.backend.grad
57
- def f(x):
58
- return tc.backend.sum(x)
59
-
60
- np.testing.assert_allclose(f(a), np.ones([2]), atol=1e-5)
61
-
62
-
63
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
64
- def test_backend_scatter(backend):
65
- np.testing.assert_allclose(
66
- tc.backend.scatter(
67
- tc.array_to_tensor(np.arange(8), dtype="int32"),
68
- tc.array_to_tensor(np.array([[1], [4]]), dtype="int32"),
69
- tc.array_to_tensor(np.array([0, 0]), dtype="int32"),
70
- ),
71
- np.array([0, 0, 2, 3, 0, 5, 6, 7]),
72
- atol=1e-4,
73
- )
74
- np.testing.assert_allclose(
75
- tc.backend.scatter(
76
- tc.array_to_tensor(np.arange(8).reshape([2, 4]), dtype="int32"),
77
- tc.array_to_tensor(np.array([[0, 2], [1, 2], [1, 3]]), dtype="int32"),
78
- tc.array_to_tensor(np.array([0, 99, 0]), dtype="int32"),
79
- ),
80
- np.array([[0, 1, 0, 3], [4, 5, 99, 0]]),
81
- atol=1e-4,
82
- )
83
- answer = np.arange(8).reshape([2, 2, 2])
84
- answer[0, 1, 0] = 99
85
- np.testing.assert_allclose(
86
- tc.backend.scatter(
87
- tc.array_to_tensor(np.arange(8).reshape([2, 2, 2]), dtype="int32"),
88
- tc.array_to_tensor(np.array([[0, 1, 0]]), dtype="int32"),
89
- tc.array_to_tensor(np.array([99]), dtype="int32"),
90
- ),
91
- answer,
92
- atol=1e-4,
93
- )
94
-
95
-
96
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
97
- def test_backend_methods(backend):
98
- # TODO(@refraction-ray): add more methods
99
- np.testing.assert_allclose(
100
- tc.backend.softmax(tc.array_to_tensor(np.ones([3, 2]), dtype="float32")),
101
- np.ones([3, 2]) / 6.0,
102
- atol=1e-4,
103
- )
104
-
105
- arr = np.random.normal(size=(6, 6))
106
-
107
- np.testing.assert_allclose(
108
- tc.backend.adjoint(tc.array_to_tensor(arr + 1.0j * arr)),
109
- arr.T - 1.0j * arr.T,
110
- atol=1e-4,
111
- )
112
-
113
- arr = tc.backend.zeros([5], dtype="float32")
114
- np.testing.assert_allclose(
115
- tc.backend.sigmoid(arr),
116
- tc.backend.ones([5]) * 0.5,
117
- atol=1e-4,
118
- )
119
- ans = np.array([[1, 0.5j], [-0.5j, 1]])
120
- ans2 = ans @ ans
121
- ansp = tc.backend.sqrtmh(tc.array_to_tensor(ans2))
122
- # print(ansp @ ansp, ans @ ans)
123
- np.testing.assert_allclose(ansp @ ansp, ans @ ans, atol=1e-4)
124
- singularm = np.array([[4.0, 0], [0, -1e-3]])
125
- np.testing.assert_allclose(
126
- tc.backend.sqrtmh(singularm, psd=True), np.array([[2.0, 0], [0, 0]]), atol=1e-5
127
- )
128
-
129
- np.testing.assert_allclose(
130
- tc.backend.sum(tc.array_to_tensor(np.arange(4))), 6, atol=1e-4
131
- )
132
-
133
- indices = np.array([[1, 2], [0, 1]])
134
- ans = np.array([[[0, 1, 0], [0, 0, 1]], [[1, 0, 0], [0, 1, 0]]])
135
- np.testing.assert_allclose(tc.backend.one_hot(indices, 3), ans, atol=1e-4)
136
-
137
- a = tc.array_to_tensor(np.array([1, 1, 3, 2, 2, 1]), dtype="int32")
138
- np.testing.assert_allclose(tc.backend.unique_with_counts(a)[0].shape[0], 3)
139
-
140
- np.testing.assert_allclose(
141
- tc.backend.cumsum(tc.array_to_tensor(np.array([[0.2, 0.2], [0.2, 0.4]]))),
142
- np.array([0.2, 0.4, 0.6, 1.0]),
143
- atol=1e-4,
144
- )
145
-
146
- np.testing.assert_allclose(
147
- tc.backend.max(tc.backend.ones([2, 2], "float32")), 1.0, atol=1e-4
148
- )
149
- np.testing.assert_allclose(
150
- tc.backend.min(
151
- tc.backend.cast(
152
- tc.backend.convert_to_tensor(np.array([[1.0, 2.0], [2.0, 3.0]])),
153
- "float64",
154
- ),
155
- axis=1,
156
- ),
157
- np.array([1.0, 2.0]),
158
- atol=1e-4,
159
- ) # by default no keepdim
160
-
161
- np.testing.assert_allclose(
162
- tc.backend.concat([tc.backend.ones([2, 2]), tc.backend.ones([1, 2])]),
163
- tc.backend.ones([3, 2]),
164
- atol=1e-5,
165
- )
166
-
167
- np.testing.assert_allclose(
168
- tc.backend.gather1d(
169
- tc.array_to_tensor(np.array([0, 1, 2])),
170
- tc.array_to_tensor(np.array([2, 1, 0]), dtype="int32"),
171
- ),
172
- np.array([2, 1, 0]),
173
- atol=1e-5,
174
- )
175
-
176
- def sum_(carry, x):
177
- return carry + x
178
-
179
- r = tc.backend.scan(sum_, tc.backend.ones([10, 2]), tc.backend.zeros([2]))
180
- np.testing.assert_allclose(r, 10 * np.ones([2]), atol=1e-5)
181
-
182
-
183
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
184
- def test_backend_methods_2(backend):
185
- np.testing.assert_allclose(tc.backend.mean(tc.backend.ones([10])), 1.0, atol=1e-5)
186
- # acos acosh asin asinh atan atan2 atanh cosh (cos) tan tanh sinh (sin)
187
- np.testing.assert_allclose(
188
- tc.backend.acos(tc.backend.ones([2], dtype="float32")),
189
- np.arccos(tc.backend.ones([2])),
190
- atol=1e-5,
191
- )
192
- np.testing.assert_allclose(
193
- tc.backend.acosh(tc.backend.ones([2], dtype="float32")),
194
- np.arccosh(tc.backend.ones([2])),
195
- atol=1e-5,
196
- )
197
- np.testing.assert_allclose(
198
- tc.backend.asin(tc.backend.ones([2], dtype="float32")),
199
- np.arcsin(tc.backend.ones([2])),
200
- atol=1e-5,
201
- )
202
- np.testing.assert_allclose(
203
- tc.backend.asinh(tc.backend.ones([2], dtype="float32")),
204
- np.arcsinh(tc.backend.ones([2])),
205
- atol=1e-5,
206
- )
207
- np.testing.assert_allclose(
208
- tc.backend.atan(0.5 * tc.backend.ones([2], dtype="float32")),
209
- np.arctan(0.5 * tc.backend.ones([2])),
210
- atol=1e-5,
211
- )
212
- np.testing.assert_allclose(
213
- tc.backend.atan2(
214
- tc.backend.ones([1], dtype="float32"), tc.backend.ones([1], dtype="float32")
215
- ),
216
- np.arctan2(
217
- tc.backend.ones([1], dtype="float32"), tc.backend.ones([1], dtype="float32")
218
- ),
219
- atol=1e-5,
220
- )
221
- np.testing.assert_allclose(
222
- tc.backend.atanh(0.5 * tc.backend.ones([2], dtype="float32")),
223
- np.arctanh(0.5 * tc.backend.ones([2])),
224
- atol=1e-5,
225
- )
226
- np.testing.assert_allclose(
227
- tc.backend.cosh(tc.backend.ones([2], dtype="float32")),
228
- np.cosh(tc.backend.ones([2])),
229
- atol=1e-5,
230
- )
231
- np.testing.assert_allclose(
232
- tc.backend.tan(tc.backend.ones([2], dtype="float32")),
233
- np.tan(tc.backend.ones([2])),
234
- atol=1e-5,
235
- )
236
- np.testing.assert_allclose(
237
- tc.backend.tanh(tc.backend.ones([2], dtype="float32")),
238
- np.tanh(tc.backend.ones([2])),
239
- atol=1e-5,
240
- )
241
- np.testing.assert_allclose(
242
- tc.backend.sinh(0.5 * tc.backend.ones([2], dtype="float32")),
243
- np.sinh(0.5 * tc.backend.ones([2])),
244
- atol=1e-5,
245
- )
246
- np.testing.assert_allclose(
247
- tc.backend.eigvalsh(tc.backend.ones([2, 2])), np.array([0, 2]), atol=1e-5
248
- )
249
- np.testing.assert_allclose(
250
- tc.backend.left_shift(
251
- tc.backend.convert_to_tensor(np.array([4, 3])),
252
- tc.backend.convert_to_tensor(np.array([1, 1])),
253
- ),
254
- np.array([8, 6]),
255
- )
256
- np.testing.assert_allclose(
257
- tc.backend.right_shift(
258
- tc.backend.convert_to_tensor(np.array([4, 3])),
259
- tc.backend.convert_to_tensor(np.array([1, 1])),
260
- ),
261
- np.array([2, 1]),
262
- )
263
- np.testing.assert_allclose(
264
- tc.backend.mod(
265
- tc.backend.convert_to_tensor(np.array([4, 3])),
266
- tc.backend.convert_to_tensor(np.array([2, 2])),
267
- ),
268
- np.array([0, 1]),
269
- )
270
- np.testing.assert_allclose(
271
- tc.backend.arange(3),
272
- np.array([0, 1, 2]),
273
- )
274
- np.testing.assert_allclose(
275
- tc.backend.arange(1, 5, 2),
276
- np.array([1, 3]),
277
- )
278
- assert tc.backend.dtype(tc.backend.ones([])) == "complex64"
279
- edges = [-1, 3.3, 9.1, 10.0]
280
- values = tc.backend.convert_to_tensor(np.array([0.0, 4.1, 12.0], dtype=np.float32))
281
- r = tc.backend.numpy(tc.backend.searchsorted(edges, values))
282
- np.testing.assert_allclose(r, np.array([1, 2, 4]))
283
- p = tc.backend.convert_to_tensor(
284
- np.array(
285
- [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.2, 0.4], dtype=np.float32
286
- )
287
- )
288
- r = tc.backend.probability_sample(10000, p, status=np.random.uniform(size=[10000]))
289
- _, r = np.unique(r, return_counts=True)
290
- np.testing.assert_allclose(
291
- r - tc.backend.numpy(p) * 10000.0, np.zeros([10]), atol=200, rtol=1
292
- )
293
- np.testing.assert_allclose(
294
- tc.backend.std(tc.backend.cast(tc.backend.arange(1, 4), "float32")),
295
- 0.81649658,
296
- atol=1e-5,
297
- )
298
- arr = np.random.normal(size=(6, 6))
299
- np.testing.assert_allclose(
300
- tc.backend.relu(tc.array_to_tensor(arr, dtype="float32")),
301
- np.maximum(arr, 0),
302
- atol=1e-4,
303
- )
304
- np.testing.assert_allclose(
305
- tc.backend.det(tc.backend.convert_to_tensor(np.eye(3) * 2)), 8, atol=1e-5
306
- )
307
-
308
-
309
- # @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
310
- # def test_backend_array(backend):
311
- # a = tc.backend.array([[0, 1], [1, 0]])
312
- # assert tc.interfaces.which_backend(a).name == tc.backend.name
313
- # a = tc.backend.array([[0, 1], [1, 0]], dtype=tc.rdtypestr)
314
- # assert tc.dtype(a) == "float32"
315
-
316
-
317
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
318
- def test_device_cpu_only(backend):
319
- a = tc.backend.ones([])
320
- dev_str = tc.backend.device(a)
321
- assert dev_str in ["cpu", "gpu:0"]
322
- tc.backend.device_move(a, dev_str)
323
-
324
-
325
- @pytest.mark.skipif(
326
- len(tf.config.list_physical_devices()) == 1, reason="no GPU detected"
327
- )
328
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
329
- def test_device_cpu_gpu(backend):
330
- a = tc.backend.ones([])
331
- a1 = tc.backend.device_move(a, "gpu:0")
332
- dev_str = tc.backend.device(a1)
333
- assert dev_str == "gpu:0"
334
- a2 = tc.backend.device_move(a1, "cpu")
335
- assert tc.backend.device(a2) == "cpu"
336
-
337
-
338
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
339
- def test_dlpack(backend):
340
- a = tc.backend.ones([2, 2], dtype="float64")
341
- cap = tc.backend.to_dlpack(a)
342
- a1 = tc.backend.from_dlpack(cap)
343
- np.testing.assert_allclose(a, a1, atol=1e-5)
344
-
345
-
346
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
347
- def test_arg_cmp(backend):
348
- np.testing.assert_allclose(tc.backend.argmax(tc.backend.ones([3], "float64")), 0)
349
- np.testing.assert_allclose(
350
- tc.backend.argmax(
351
- tc.array_to_tensor(np.array([[1, 2], [3, 4]]), dtype="float64")
352
- ),
353
- np.array([1, 1]),
354
- )
355
- np.testing.assert_allclose(
356
- tc.backend.argmin(
357
- tc.array_to_tensor(np.array([[1, 2], [3, 4]]), dtype="float64"), axis=-1
358
- ),
359
- np.array([0, 0]),
360
- )
361
-
362
-
363
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
364
- def test_tree_map(backend):
365
- def f(a, b):
366
- return a + b
367
-
368
- r = tc.backend.tree_map(
369
- f, {"a": tc.backend.ones([2])}, {"a": 2 * tc.backend.ones([2])}
370
- )
371
- np.testing.assert_allclose(r["a"], 3 * np.ones([2]), atol=1e-4)
372
-
373
- def _add(a, b):
374
- return a + b
375
-
376
- ans = tc.backend.tree_map(
377
- _add,
378
- {"a": tc.backend.ones([2]), "b": tc.backend.ones([3])},
379
- {"a": tc.backend.ones([2]), "b": tc.backend.ones([3])},
380
- )
381
- np.testing.assert_allclose(ans["a"], 2 * np.ones([2]))
382
-
383
-
384
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
385
- def test_backend_randoms(backend):
386
- @partial(tc.backend.jit, static_argnums=0)
387
- def random_matrixn(key):
388
- tc.backend.set_random_state(key)
389
- r1 = tc.backend.implicit_randn(shape=[2, 2], mean=0.5)
390
- r2 = tc.backend.implicit_randn(shape=[2, 2], mean=0.5)
391
- return r1, r2
392
-
393
- key = 42
394
- if tc.backend.name == "tensorflow":
395
- key = tf.random.Generator.from_seed(42)
396
- r11, r12 = random_matrixn(key)
397
- if tc.backend.name == "tensorflow":
398
- key = tf.random.Generator.from_seed(42)
399
- r21, r22 = random_matrixn(key)
400
- np.testing.assert_allclose(r11, r21, atol=1e-4)
401
- np.testing.assert_allclose(r12, r22, atol=1e-4)
402
- assert not np.allclose(r11, r12, atol=1e-4)
403
-
404
- def random_matrixu(key):
405
- tc.backend.set_random_state(key)
406
- r1 = tc.backend.implicit_randu(shape=[2, 2], high=2)
407
- r2 = tc.backend.implicit_randu(shape=[2, 2], high=1)
408
- return r1, r2
409
-
410
- key = 42
411
- r31, r32 = random_matrixu(key)
412
- np.testing.assert_allclose(r31.shape, [2, 2])
413
- assert np.any(r32 > 0)
414
- assert not np.allclose(r31, r32, atol=1e-4)
415
-
416
- def random_matrixc(key):
417
- tc.backend.set_random_state(key)
418
- r1 = tc.backend.implicit_randc(a=[1, 2, 3], shape=(2, 2))
419
- r2 = tc.backend.implicit_randc(a=[1, 2, 3], shape=(2, 2), p=[0.1, 0.4, 0.5])
420
- return r1, r2
421
-
422
- r41, r42 = random_matrixc(key)
423
- np.testing.assert_allclose(r41.shape, [2, 2])
424
- assert np.any((r42 > 0) & (r42 < 4))
425
-
426
-
427
- def vqe_energy(inputs, param, n, nlayers):
428
- c = tc.Circuit(n, inputs=inputs)
429
- paramc = tc.backend.cast(param, "complex64")
430
-
431
- for i in range(n):
432
- c.H(i)
433
- for j in range(nlayers):
434
- for i in range(n - 1):
435
- c.ryy(i, i + 1, theta=paramc[2 * j, i])
436
- # c.any(
437
- # i,
438
- # i + 1,
439
- # unitary=tc.backend.cos(paramc[2 * j, i]) * iir
440
- # + tc.backend.sin(paramc[2 * j, i]) * 1.0j * yzr,
441
- # )
442
- for i in range(n):
443
- c.rx(i, theta=paramc[2 * j + 1, i])
444
- e = 0.0
445
- for i in range(n):
446
- e += c.expectation((tc.gates.x(), [i]))
447
- for i in range(n - 1): # OBC
448
- e += c.expectation((tc.gates.z(), [i]), (tc.gates.z(), [(i + 1) % n]))
449
- e = tc.backend.real(e)
450
- return e
451
-
452
-
453
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
454
- def test_vvag(backend):
455
- n = 4
456
- nlayers = 3
457
- inp = tc.backend.ones([2**n]) / 2 ** (n / 2)
458
- param = tc.backend.ones([2 * nlayers, n])
459
- # inp = tc.backend.cast(inp, "complex64")
460
- # param = tc.backend.cast(param, "complex64")
461
-
462
- vqe_energy_p = partial(vqe_energy, n=n, nlayers=nlayers)
463
-
464
- vg = tc.backend.value_and_grad(vqe_energy_p, argnums=(0, 1))
465
- v0, (g00, g01) = vg(inp, param)
466
-
467
- batch = 8
468
- inps = tc.backend.ones([batch, 2**n]) / 2 ** (n / 2)
469
- inps = tc.backend.cast(inps, "complex64")
470
-
471
- pvag = tc.backend.vvag(vqe_energy_p, argnums=(0, 1))
472
- v1, (g10, g11) = pvag(inps, param)
473
- print(v1.shape, g10.shape, g11.shape)
474
- np.testing.assert_allclose(v1[0], v0, atol=1e-4)
475
- np.testing.assert_allclose(g10[0], g00, atol=1e-4)
476
- np.testing.assert_allclose(g11 / batch, g01, atol=1e-4)
477
-
478
-
479
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
480
- def test_vvag_dict(backend):
481
- def dict_plus(x, y):
482
- a = x["a"]
483
- return tc.backend.real((a + y)[0])
484
-
485
- dp_vvag = tc.backend.vvag(dict_plus, vectorized_argnums=1, argnums=0)
486
- x = {"a": tc.backend.ones([1])}
487
- y = tc.backend.ones([20, 1])
488
- v, g = dp_vvag(x, y)
489
- np.testing.assert_allclose(v.shape, [20])
490
- np.testing.assert_allclose(g["a"], 20.0, atol=1e-4)
491
-
492
-
493
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
494
- def test_vjp(backend):
495
- def f(x):
496
- return x**2
497
-
498
- inputs = tc.backend.ones([2, 2])
499
- v, g = tc.backend.vjp(f, inputs, inputs)
500
- np.testing.assert_allclose(v, inputs, atol=1e-5)
501
- np.testing.assert_allclose(g, 2 * inputs, atol=1e-5)
502
-
503
- def f2(x, y):
504
- return x + y, x - y
505
-
506
- inputs = [tc.backend.ones([2]), tc.backend.ones([2])]
507
- v = [2.0 * t for t in inputs]
508
- v, g = tc.backend.vjp(f2, inputs, v)
509
- np.testing.assert_allclose(v[1], np.zeros([2]), atol=1e-5)
510
- np.testing.assert_allclose(g[0], 4 * np.ones([2]), atol=1e-5)
511
-
512
-
513
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
514
- def test_vjp_complex(backend):
515
- def f(x):
516
- return tc.backend.conj(x)
517
-
518
- inputs = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
519
- v = tc.backend.ones([1], dtype="complex64")
520
- v, g = tc.backend.vjp(f, inputs, v)
521
- np.testing.assert_allclose(tc.backend.numpy(g), np.ones([1]), atol=1e-5)
522
-
523
- def f2(x):
524
- return x**2
525
-
526
- inputs = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
527
- v = tc.backend.ones([1], dtype="complex64") # + 1.0j * tc.backend.ones([1])
528
- v, g = tc.backend.vjp(f2, inputs, v)
529
- # note how vjp definition on complex function is different in jax backend
530
- if tc.backend.name == "jax":
531
- np.testing.assert_allclose(tc.backend.numpy(g), 2 + 2j, atol=1e-5)
532
- else:
533
- np.testing.assert_allclose(tc.backend.numpy(g), 2 - 2j, atol=1e-5)
534
-
535
-
536
- # TODO(@refraction-ray): consistent and unified pytree utils for pytorch backend?
537
-
538
-
539
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
540
- def test_vjp_pytree(backend):
541
- def f3(d):
542
- return d["a"] + d["b"], d["a"]
543
-
544
- inputs = {"a": tc.backend.ones([2]), "b": tc.backend.ones([1])}
545
- v = (tc.backend.ones([2]), tc.backend.zeros([2]))
546
- v, g = tc.backend.vjp(f3, inputs, v)
547
- np.testing.assert_allclose(v[0], 2 * np.ones([2]), atol=1e-5)
548
- np.testing.assert_allclose(g["a"], np.ones([2]), atol=1e-5)
549
-
550
-
551
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
552
- def test_jvp(backend):
553
- def f(x):
554
- return x**2
555
-
556
- inputs = tc.backend.ones([2, 2])
557
- v, g = tc.backend.jvp(f, inputs, inputs)
558
- np.testing.assert_allclose(v, inputs, atol=1e-5)
559
- np.testing.assert_allclose(g, 2 * inputs, atol=1e-5)
560
-
561
- def f2(x, y):
562
- return x + y, x - y
563
-
564
- inputs = [tc.backend.ones([2]), tc.backend.ones([2])]
565
- v = [2.0 * t for t in inputs]
566
- v, g = tc.backend.jvp(f2, inputs, v)
567
- np.testing.assert_allclose(v[1], np.zeros([2]), atol=1e-5)
568
- np.testing.assert_allclose(g[0], 4 * np.ones([2]), atol=1e-5)
569
-
570
-
571
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
572
- def test_jvp_complex(backend):
573
- def f(x):
574
- return tc.backend.conj(x)
575
-
576
- inputs = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
577
- v = tc.backend.ones([1], dtype="complex64")
578
- v, g = tc.backend.jvp(f, inputs, v)
579
- # numpy auto numpy doesn't work for torch conjugate tensor
580
- np.testing.assert_allclose(tc.backend.numpy(g), np.ones([1]), atol=1e-5)
581
-
582
- def f2(x):
583
- return x**2
584
-
585
- inputs = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
586
- v = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
587
- v, g = tc.backend.jvp(f2, inputs, v)
588
- np.testing.assert_allclose(tc.backend.numpy(g), 4.0j, atol=1e-5)
589
-
590
-
591
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
592
- def test_jvp_pytree(backend):
593
- def f3(d):
594
- return d["a"] + d["b"], d["a"]
595
-
596
- inputs = {"a": tc.backend.ones([2]), "b": tc.backend.ones([1])}
597
- v = (tc.backend.ones([2]), tc.backend.zeros([2]))
598
- v, g = tc.backend.vjp(f3, inputs, v)
599
- np.testing.assert_allclose(v[0], 2 * np.ones([2]), atol=1e-5)
600
- np.testing.assert_allclose(g["a"], np.ones([2]), atol=1e-5)
601
-
602
-
603
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
604
- @pytest.mark.parametrize("mode", ["jacfwd", "jacrev"])
605
- def test_jac(backend, mode):
606
- # make no sense for torch backend when you have no real vmap interface
607
- backend_jac = getattr(tc.backend, mode)
608
-
609
- def f(x):
610
- return x**2
611
-
612
- x = tc.backend.ones([3])
613
- jacf = backend_jac(f)
614
- np.testing.assert_allclose(jacf(x), 2 * np.eye(3), atol=1e-5)
615
-
616
- def f2(x):
617
- return x**2, x
618
-
619
- jacf2 = backend_jac(f2)
620
- np.testing.assert_allclose(jacf2(x)[1], np.eye(3), atol=1e-5)
621
- np.testing.assert_allclose(jacf2(x)[0], 2 * np.eye(3), atol=1e-5)
622
-
623
- def f3(x, y):
624
- return x + y**2
625
-
626
- jacf3 = backend_jac(f3, argnums=(0, 1))
627
- jacf3jit = tc.backend.jit(backend_jac(f3, argnums=(0, 1)))
628
- np.testing.assert_allclose(jacf3jit(x, x)[1], 2 * np.eye(3), atol=1e-5)
629
- np.testing.assert_allclose(jacf3(x, x)[1], 2 * np.eye(3), atol=1e-5)
630
-
631
- def f4(x, y):
632
- return x**2, y
633
-
634
- # note the subtle difference of two tuples order in jacrev and jacfwd for current API
635
- # the value happen to be the same here, though
636
- jacf4 = backend_jac(f4, argnums=(0, 1))
637
- jacf4jit = tc.backend.jit(backend_jac(f4, argnums=(0, 1)))
638
- np.testing.assert_allclose(jacf4jit(x, x)[1][1], np.eye(3), atol=1e-5)
639
- np.testing.assert_allclose(jacf4jit(x, x)[0][1], np.zeros([3, 3]), atol=1e-5)
640
- np.testing.assert_allclose(jacf4(x, x)[1][1], np.eye(3), atol=1e-5)
641
- np.testing.assert_allclose(jacf4(x, x)[0][1], np.zeros([3, 3]), atol=1e-5)
642
-
643
-
644
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
645
- @pytest.mark.parametrize("mode", ["jacfwd", "jacrev"])
646
- def test_jac_md_input(backend, mode):
647
- backend_jac = getattr(tc.backend, mode)
648
-
649
- def f(x):
650
- return x**2
651
-
652
- x = tc.backend.ones([2, 3])
653
- jacf = backend_jac(f)
654
- np.testing.assert_allclose(jacf(x).shape, [2, 3, 2, 3], atol=1e-5)
655
-
656
- def f2(x):
657
- return tc.backend.sum(x, axis=0)
658
-
659
- x = tc.backend.ones([2, 3])
660
- jacf2 = backend_jac(f2)
661
- np.testing.assert_allclose(jacf2(x).shape, [3, 2, 3], atol=1e-5)
662
-
663
-
664
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
665
- @pytest.mark.parametrize("mode", ["jacfwd", "jacrev"])
666
- def test_jac_tall(backend, mode):
667
- backend_jac = getattr(tc.backend, mode)
668
-
669
- h = tc.backend.ones([5, 3])
670
-
671
- def f(x):
672
- x = tc.backend.reshape(x, [-1, 1])
673
- return tc.backend.reshape(h @ x, [-1])
674
-
675
- x = tc.backend.ones([3])
676
- jacf = backend_jac(f)
677
- np.testing.assert_allclose(jacf(x), np.ones([5, 3]), atol=1e-5)
678
-
679
-
680
- @pytest.mark.parametrize("backend", [lf("jaxb"), lf("tfb")])
681
- def test_vvag_has_aux(backend):
682
- def f(x):
683
- y = tc.backend.sum(x)
684
- return tc.backend.real(y**2), y
685
-
686
- fvvag = tc.backend.vvag(f, has_aux=True)
687
- (_, v1), _ = fvvag(tc.backend.ones([10, 2]))
688
- np.testing.assert_allclose(v1, 2 * tc.backend.ones([10]))
689
-
690
-
691
- def test_jax_svd(jaxb, highp):
692
- def l(A):
693
- u, _, v, _ = tc.backend.svd(A)
694
- return tc.backend.real(u[0, 0] * v[0, 0])
695
-
696
- def numericald(A):
697
- eps = 1e-6
698
- DA = np.zeros_like(A)
699
- for i in range(A.shape[0]):
700
- for j in range(A.shape[1]):
701
- dA = np.zeros_like(A)
702
- dA[i, j] = 1
703
- DA[i, j] = (l(A + eps * dA) - l(A)) / eps - 1.0j * (
704
- l(A + eps * 1.0j * dA) - l(A)
705
- ) / eps
706
- return DA
707
-
708
- def analyticald(A):
709
- A = tc.backend.convert_to_tensor(A)
710
- g = tc.backend.grad(l)
711
- return g(A)
712
-
713
- for shape in [(2, 2), (3, 3), (2, 3), (4, 2)]:
714
- m = np.random.normal(size=shape).astype(
715
- np.complex128
716
- ) + 1.0j * np.random.normal(size=shape).astype(np.complex128)
717
- print(m)
718
- np.testing.assert_allclose(numericald(m), analyticald(m), atol=1e-3)
719
-
720
-
721
- @pytest.mark.parametrize("backend", [lf("jaxb"), lf("tfb"), lf("torchb")])
722
- def test_qr(backend, highp):
723
- def get_random_complex(shape):
724
- result = np.random.random(shape) + np.random.random(shape) * 1j
725
- return tc.backend.convert_to_tensor(result.astype(dtype))
726
-
727
- np.random.seed(0)
728
- A1 = get_random_complex((2, 2))
729
- A2 = tc.backend.convert_to_tensor(np.array([[1.0, 0.0], [0.0, 0.0]]).astype(dtype))
730
- X = get_random_complex((2, 2))
731
-
732
- def func(A, x):
733
- x = tc.backend.cast(x, "complex64")
734
- Q, R = tc.backend.qr(A + X * x)
735
- return tc.backend.real(tc.backend.sum(tc.backend.matmul(Q, R)))
736
-
737
- def grad(A, x):
738
- return tc.backend.grad(func, argnums=1)(A, x)
739
-
740
- for A in [A1, A2]:
741
- epsilon = tc.backend.convert_to_tensor(1e-3)
742
- n_grad = (func(A, epsilon) - func(A, -epsilon)) / (2 * epsilon)
743
- a_grad = grad(A, tc.backend.convert_to_tensor(0.0))
744
- np.testing.assert_allclose(n_grad, a_grad, atol=1e-3)
745
-
746
-
747
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
748
- def test_sparse_methods(backend):
749
- values = tc.backend.convert_to_tensor(np.array([1.0, 2.0]))
750
- values = tc.backend.cast(values, "complex64")
751
- indices = tc.backend.convert_to_tensor(np.array([[0, 0], [1, 1]]))
752
- indices = tc.backend.cast(indices, "int64")
753
- spa = tc.backend.coo_sparse_matrix(indices, values, shape=[4, 4])
754
- vec = tc.backend.ones([4, 1])
755
- da = np.array(
756
- [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], dtype=np.complex64
757
- )
758
- assert tc.backend.is_sparse(spa) is True
759
- assert tc.backend.is_sparse(vec) is False
760
- np.testing.assert_allclose(
761
- tc.backend.to_dense(spa),
762
- da,
763
- atol=1e-5,
764
- )
765
- np.testing.assert_allclose(
766
- tc.backend.sparse_dense_matmul(spa, vec),
767
- np.array([[1], [2], [0], [0]], dtype=np.complex64),
768
- atol=1e-5,
769
- )
770
- spa_np = tc.backend.numpy(spa)
771
- np.testing.assert_allclose(spa_np.todense(), da, atol=1e-6)
772
- np.testing.assert_allclose(
773
- tc.backend.to_dense(tc.backend.coo_sparse_matrix_from_numpy(spa_np)),
774
- da,
775
- atol=1e-5,
776
- )
777
-
778
-
779
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
780
- def test_backend_randoms_v2(backend):
781
- g = tc.backend.get_random_state(42)
782
- for t in tc.backend.stateful_randc(g, 3, [3]):
783
- assert t >= 0
784
- assert t < 3
785
- key = tc.backend.get_random_state(42)
786
- r = []
787
- for _ in range(2):
788
- key, subkey = tc.backend.random_split(key)
789
- r.append(tc.backend.stateful_randc(subkey, 3, [5]))
790
- assert tuple(r[0]) != tuple(r[1])
791
-
792
-
793
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
794
- def test_backend_randoms_v3(backend):
795
- tc.backend.set_random_state(42)
796
- for _ in range(2):
797
- r1 = tc.backend.implicit_randu()
798
- key = tc.backend.get_random_state(42)
799
- for _ in range(2):
800
- key, subkey = tc.backend.random_split(key)
801
- r2 = tc.backend.stateful_randu(subkey)
802
- np.testing.assert_allclose(r1, r2, atol=1e-5)
803
-
804
- @tc.backend.jit
805
- def f(key):
806
- tc.backend.set_random_state(key)
807
- r = []
808
- for _ in range(3):
809
- r.append(tc.backend.implicit_randu()[0])
810
- return r
811
-
812
- @tc.backend.jit
813
- def f2(key):
814
- r = []
815
- for _ in range(3):
816
- key, subkey = tc.backend.random_split(key)
817
- r.append(tc.backend.stateful_randu(subkey)[0])
818
- return r
819
-
820
- key = tc.backend.get_random_state(43)
821
- r = f(key)
822
- key = tc.backend.get_random_state(43)
823
- r1 = f2(key)
824
- np.testing.assert_allclose(r[-1], r1[-1], atol=1e-5)
825
-
826
-
827
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
828
- def test_function_level_set(backend):
829
- def f(x):
830
- return tc.backend.ones([x])
831
-
832
- f_jax_128 = tc.set_function_backend("jax")(tc.set_function_dtype("complex128")(f))
833
- # note the order to enable complex 128 in jax backend
834
-
835
- assert f_jax_128(3).dtype.__str__() == "complex128"
836
-
837
-
838
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
839
- def test_function_level_set_contractor(backend):
840
- @tc.set_function_contractor("branch")
841
- def f():
842
- return tc.contractor
843
-
844
- print(f())
845
- print(tc.contractor)
846
-
847
-
848
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
849
- def test_with_level_set(backend):
850
- with tc.runtime_backend("jax"):
851
- with tc.runtime_dtype("complex128"):
852
- with tc.runtime_contractor("branch"):
853
- assert tc.backend.ones([2]).dtype.__str__() == "complex128"
854
- print(tc.contractor)
855
- print(tc.contractor)
856
- print(tc.backend.name)
857
-
858
-
859
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
860
- def test_with_level_set_return(backend):
861
- with tc.runtime_backend("jax") as K:
862
- assert K.name == "jax"
863
-
864
-
865
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
866
- def test_grad_has_aux(backend):
867
- def f(x):
868
- return tc.backend.real(x**2), x**3
869
-
870
- vg = tc.backend.value_and_grad(f, has_aux=True)
871
-
872
- np.testing.assert_allclose(
873
- vg(tc.backend.ones([]))[1], 2 * tc.backend.ones([]), atol=1e-5
874
- )
875
-
876
- def f2(x):
877
- return tc.backend.real(x**2), (x**3, tc.backend.ones([3]))
878
-
879
- gs = tc.backend.grad(f2, has_aux=True)
880
- np.testing.assert_allclose(gs(tc.backend.ones([]))[0], 2.0, atol=1e-5)
881
-
882
-
883
- @pytest.mark.parametrize("backend", [lf("npb"), lf("jaxb"), lf("tfb")])
884
- def test_solve(backend):
885
- A = np.array([[2, 1, 0], [1, 2, 0], [0, 0, 1]], dtype=np.float32)
886
- A = tc.backend.convert_to_tensor(A)
887
- x = np.ones([3, 1], dtype=np.float32)
888
- x = tc.backend.convert_to_tensor(x)
889
- b = (A @ x)[:, 0]
890
- print(A.shape, b.shape)
891
- xp = tc.backend.solve(A, b, assume_a="her")
892
- np.testing.assert_allclose(xp, x[:, 0], atol=1e-5)
893
-
894
-
895
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
896
- def test_treeutils(backend):
897
- d0 = {"a": np.ones([2]), "b": [tc.backend.zeros([]), tc.backend.ones([1, 1])]}
898
- leaves, treedef = tc.backend.tree_flatten(d0)
899
- d1 = tc.backend.tree_unflatten(treedef, leaves)
900
- d2 = tc.backend.tree_map(lambda x: 2 * x, d1)
901
- np.testing.assert_allclose(2 * np.ones([1, 1]), d2["b"][1])
902
-
903
-
904
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
905
- def test_optimizers(backend):
906
- if tc.backend.name == "jax":
907
- try:
908
- import optax
909
- except ImportError:
910
- pytest.skip("optax is not installed")
911
-
912
- if tc.backend.name == "pytorch":
913
- try:
914
- import torch
915
- except ImportError:
916
- pytest.skip("torch is not installed")
917
-
918
- def f(params, n):
919
- c = tc.Circuit(n)
920
- c = tc.templates.blocks.example_block(c, params["a"])
921
- c = tc.templates.blocks.example_block(c, params["b"])
922
- return tc.backend.real(c.expectation([tc.gates.x(), [n // 2]]))
923
-
924
- vgs = tc.backend.jit(tc.backend.value_and_grad(f, argnums=0), static_argnums=1)
925
-
926
- def get_opt():
927
- if tc.backend.name == "tensorflow":
928
- optimizer1 = tf.keras.optimizers.Adam(5e-2)
929
- opt = tc.backend.optimizer(optimizer1)
930
- elif tc.backend.name == "jax":
931
- optimizer2 = optax.adam(5e-2)
932
- opt = tc.backend.optimizer(optimizer2)
933
- elif tc.backend.name == "pytorch":
934
- optimizer3 = partial(torch.optim.Adam, lr=5e-2)
935
- opt = tc.backend.optimizer(optimizer3)
936
- else:
937
- raise ValueError("%s doesn't support optimizer interface" % tc.backend.name)
938
- return opt
939
-
940
- n = 3
941
- opt = get_opt()
942
-
943
- params = {
944
- "a": tc.backend.ones([4, n], dtype="float32"),
945
- "b": tc.backend.ones([4, n], dtype="float32"),
946
- }
947
-
948
- for _ in range(20):
949
- loss, grads = vgs(params, n)
950
- params = opt.update(grads, params)
951
- print(loss)
952
-
953
- assert loss < -0.7
954
-
955
- def f2(params, n):
956
- c = tc.Circuit(n)
957
- c = tc.templates.blocks.example_block(c, params)
958
- return tc.backend.real(c.expectation([tc.gates.x(), [n // 2]]))
959
-
960
- vgs2 = tc.backend.jit(tc.backend.value_and_grad(f2, argnums=0), static_argnums=1)
961
-
962
- params = tc.backend.ones([4, n], dtype="float32")
963
- opt = get_opt()
964
-
965
- for _ in range(20):
966
- loss, grads = vgs2(params, n)
967
- params = opt.update(grads, params)
968
- print(loss)
969
-
970
- assert loss < -0.7
971
-
972
-
973
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
974
- def test_hessian(backend):
975
- # hessian support is now very fragile and especially has potential issues on tf backend
976
- def f(param):
977
- return tc.backend.sum(param**2)
978
-
979
- hf = tc.backend.hessian(f)
980
- param = tc.backend.ones([2])
981
- np.testing.assert_allclose(hf(param), 2 * tc.backend.eye(2), atol=1e-5)
982
-
983
- param = tc.backend.ones([2, 2])
984
- assert list(hf(param).shape) == [2, 2, 2, 2] # possible tf retracing?
985
-
986
- g = tc.templates.graphs.Line1D(5)
987
-
988
- def circuit_f(param):
989
- c = tc.Circuit(5)
990
- c = tc.templates.blocks.example_block(c, param, nlayers=1)
991
- return tc.templates.measurements.heisenberg_measurements(c, g)
992
-
993
- param = tc.backend.ones([10])
994
- hf = tc.backend.hessian(circuit_f)
995
- print(hf(param)) # still upto a conjugate for jax and tf backend.
996
-
997
-
998
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
999
- def test_nested_vmap(backend):
1000
- def f(x, w):
1001
- c = tc.Circuit(4)
1002
- for i in range(4):
1003
- c.rx(i, theta=x[i])
1004
- c.ry(i, theta=w[i])
1005
- return tc.backend.stack([c.expectation_ps(z=[i]) for i in range(4)])
1006
-
1007
- def fa1(*args):
1008
- r = tc.backend.vmap(f, vectorized_argnums=1)(*args)
1009
- return r
1010
-
1011
- def fa2(*args):
1012
- r = tc.backend.vmap(fa1, vectorized_argnums=0)(*args)
1013
- return r
1014
-
1015
- fa2jit = tc.backend.jit(fa2)
1016
-
1017
- ya = fa2(tc.backend.ones([3, 4]), tc.backend.ones([7, 4]))
1018
- yajit = fa2jit(tc.backend.ones([3, 4]), tc.backend.ones([7, 4]))
1019
-
1020
- def fb1(*args):
1021
- r = tc.backend.vmap(f, vectorized_argnums=0)(*args)
1022
- return r
1023
-
1024
- def fb2(*args):
1025
- r = tc.backend.vmap(fb1, vectorized_argnums=1)(*args)
1026
- return r
1027
-
1028
- fb2jit = tc.backend.jit(fb2)
1029
-
1030
- yb = fb2(tc.backend.ones([3, 4]), tc.backend.ones([7, 4]))
1031
- ybjit = fb2jit(tc.backend.ones([3, 4]), tc.backend.ones([7, 4]))
1032
-
1033
- np.testing.assert_allclose(ya, tc.backend.transpose(yb, [1, 0, 2]), atol=1e-5)
1034
- np.testing.assert_allclose(ya, yajit, atol=1e-5)
1035
- np.testing.assert_allclose(yajit, tc.backend.transpose(ybjit, [1, 0, 2]), atol=1e-5)