tensorcircuit-nightly 1.0.2.dev20250108__py3-none-any.whl → 1.4.0.dev20251103__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of tensorcircuit-nightly might be problematic.
Files changed (76)
  1. tensorcircuit/__init__.py +18 -2
  2. tensorcircuit/about.py +46 -0
  3. tensorcircuit/abstractcircuit.py +4 -0
  4. tensorcircuit/analogcircuit.py +413 -0
  5. tensorcircuit/applications/layers.py +1 -1
  6. tensorcircuit/applications/van.py +1 -1
  7. tensorcircuit/backends/abstract_backend.py +320 -7
  8. tensorcircuit/backends/cupy_backend.py +3 -1
  9. tensorcircuit/backends/jax_backend.py +102 -4
  10. tensorcircuit/backends/jax_ops.py +110 -1
  11. tensorcircuit/backends/numpy_backend.py +49 -3
  12. tensorcircuit/backends/pytorch_backend.py +92 -3
  13. tensorcircuit/backends/tensorflow_backend.py +102 -3
  14. tensorcircuit/basecircuit.py +157 -98
  15. tensorcircuit/circuit.py +115 -57
  16. tensorcircuit/cloud/local.py +1 -1
  17. tensorcircuit/cloud/quafu_provider.py +1 -1
  18. tensorcircuit/cloud/tencent.py +1 -1
  19. tensorcircuit/compiler/simple_compiler.py +2 -2
  20. tensorcircuit/cons.py +142 -21
  21. tensorcircuit/densitymatrix.py +43 -14
  22. tensorcircuit/experimental.py +387 -129
  23. tensorcircuit/fgs.py +282 -81
  24. tensorcircuit/gates.py +66 -22
  25. tensorcircuit/interfaces/__init__.py +1 -3
  26. tensorcircuit/interfaces/jax.py +189 -0
  27. tensorcircuit/keras.py +3 -3
  28. tensorcircuit/mpscircuit.py +154 -65
  29. tensorcircuit/quantum.py +868 -152
  30. tensorcircuit/quditcircuit.py +733 -0
  31. tensorcircuit/quditgates.py +618 -0
  32. tensorcircuit/results/counts.py +147 -20
  33. tensorcircuit/results/readout_mitigation.py +4 -1
  34. tensorcircuit/shadows.py +1 -1
  35. tensorcircuit/simplify.py +3 -1
  36. tensorcircuit/stabilizercircuit.py +479 -0
  37. tensorcircuit/templates/__init__.py +2 -0
  38. tensorcircuit/templates/blocks.py +2 -2
  39. tensorcircuit/templates/hamiltonians.py +174 -0
  40. tensorcircuit/templates/lattice.py +1789 -0
  41. tensorcircuit/timeevol.py +896 -0
  42. tensorcircuit/translation.py +10 -3
  43. tensorcircuit/utils.py +7 -0
  44. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +73 -23
  45. tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
  46. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +1 -1
  47. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
  48. tensorcircuit_nightly-1.0.2.dev20250108.dist-info/RECORD +0 -115
  49. tests/__init__.py +0 -0
  50. tests/conftest.py +0 -67
  51. tests/test_backends.py +0 -1031
  52. tests/test_calibrating.py +0 -149
  53. tests/test_channels.py +0 -365
  54. tests/test_circuit.py +0 -1699
  55. tests/test_cloud.py +0 -219
  56. tests/test_compiler.py +0 -147
  57. tests/test_dmcircuit.py +0 -555
  58. tests/test_ensemble.py +0 -72
  59. tests/test_fgs.py +0 -310
  60. tests/test_gates.py +0 -156
  61. tests/test_interfaces.py +0 -429
  62. tests/test_keras.py +0 -160
  63. tests/test_miscs.py +0 -277
  64. tests/test_mpscircuit.py +0 -341
  65. tests/test_noisemodel.py +0 -156
  66. tests/test_qaoa.py +0 -86
  67. tests/test_qem.py +0 -152
  68. tests/test_quantum.py +0 -526
  69. tests/test_quantum_attr.py +0 -42
  70. tests/test_results.py +0 -347
  71. tests/test_shadows.py +0 -160
  72. tests/test_simplify.py +0 -46
  73. tests/test_templates.py +0 -218
  74. tests/test_torchnn.py +0 -99
  75. tests/test_van.py +0 -102
  76. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info/licenses}/LICENSE +0 -0
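
On the library side, the headline additions are new modules for qudit circuits (quditcircuit.py, quditgates.py), stabilizer-circuit simulation (stabilizercircuit.py), analog time evolution (analogcircuit.py, timeevol.py), lattice and Hamiltonian templates, and a JAX interface; on the packaging side, the bundled tests/ tree no longer ships in the wheel, which accounts for the test deletions below and the -1 entry in top_level.txt. The deleted tests/test_backends.py remains a useful reference for the backend-agnostic tc.backend API. Below is a minimal sketch in its style, using only calls that appear in the deleted file; tc.set_backend is an assumption standing in for the suite's pytest fixtures (npb/tfb/jaxb/torchb) that switched backends per test.

# A minimal sketch in the style of the removed tests; tc.set_backend
# (a standard tensorcircuit call) stands in for the deleted pytest
# fixtures that selected the backend per test.
import numpy as np
import tensorcircuit as tc

tc.set_backend("numpy")  # "tensorflow", "jax", or "pytorch" work the same way

def sum_real(x, y):
    # backend-agnostic op: dispatches to the active backend
    return tc.backend.real(x + y)

# vectorize over both arguments, as universal_vmap() does in the deleted file
vop = tc.backend.vmap(sum_real, vectorized_argnums=(0, 1))
t = tc.gates.array_to_tensor(np.ones([20, 1]))
r = vop(t, 2.0 * t)
assert r.shape == (20, 1)  # one output row per vectorized input row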
tests/test_backends.py DELETED
@@ -1,1031 +0,0 @@
-# pylint: disable=invalid-name
-
-import sys
-import os
-from functools import partial
-
-os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
-
-import numpy as np
-import pytest
-from pytest_lazyfixture import lazy_fixture as lf
-import tensorflow as tf
-
-thisfile = os.path.abspath(__file__)
-modulepath = os.path.dirname(os.path.dirname(thisfile))
-
-sys.path.insert(0, modulepath)
-import tensorcircuit as tc
-
-dtype = np.complex64
-
-
-def universal_vmap():
-    def sum_real(x, y):
-        return tc.backend.real(x + y)
-
-    vop = tc.backend.vmap(sum_real, vectorized_argnums=(0, 1))
-    t = tc.gates.array_to_tensor(np.ones([20, 1]))
-    return vop(t, 2.0 * t)
-
-
-def test_vmap_np():
-    r = universal_vmap()
-    assert r.shape == (20, 1)
-
-
-def test_vmap_jax(jaxb):
-    r = universal_vmap()
-    assert r.shape == (20, 1)
-
-
-def test_vmap_tf(tfb):
-    r = universal_vmap()
-    assert r.numpy()[0, 0] == 3.0
-
-
-def test_vmap_torch(torchb):
-    r = universal_vmap()
-    assert r.numpy()[0, 0] == 3.0
-
-
-def test_grad_torch(torchb):
-    a = tc.backend.ones([2], dtype="float32")
-
-    # @partial(tc.backend.jit, jit_compile=True)
-    @tc.backend.grad
-    def f(x):
-        return tc.backend.sum(x)
-
-    np.testing.assert_allclose(f(a), np.ones([2]), atol=1e-5)
-
-
-@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
-def test_backend_scatter(backend):
-    np.testing.assert_allclose(
-        tc.backend.scatter(
-            tc.array_to_tensor(np.arange(8), dtype="int32"),
-            tc.array_to_tensor(np.array([[1], [4]]), dtype="int32"),
-            tc.array_to_tensor(np.array([0, 0]), dtype="int32"),
-        ),
-        np.array([0, 0, 2, 3, 0, 5, 6, 7]),
-        atol=1e-4,
-    )
-    np.testing.assert_allclose(
-        tc.backend.scatter(
-            tc.array_to_tensor(np.arange(8).reshape([2, 4]), dtype="int32"),
-            tc.array_to_tensor(np.array([[0, 2], [1, 2], [1, 3]]), dtype="int32"),
-            tc.array_to_tensor(np.array([0, 99, 0]), dtype="int32"),
-        ),
-        np.array([[0, 1, 0, 3], [4, 5, 99, 0]]),
-        atol=1e-4,
-    )
-    answer = np.arange(8).reshape([2, 2, 2])
-    answer[0, 1, 0] = 99
-    np.testing.assert_allclose(
-        tc.backend.scatter(
-            tc.array_to_tensor(np.arange(8).reshape([2, 2, 2]), dtype="int32"),
-            tc.array_to_tensor(np.array([[0, 1, 0]]), dtype="int32"),
-            tc.array_to_tensor(np.array([99]), dtype="int32"),
-        ),
-        answer,
-        atol=1e-4,
-    )
-
-
-@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
-def test_backend_methods(backend):
-    # TODO(@refraction-ray): add more methods
-    np.testing.assert_allclose(
-        tc.backend.softmax(tc.array_to_tensor(np.ones([3, 2]), dtype="float32")),
-        np.ones([3, 2]) / 6.0,
-        atol=1e-4,
-    )
-
-    arr = np.random.normal(size=(6, 6))
-
-    np.testing.assert_allclose(
-        tc.backend.adjoint(tc.array_to_tensor(arr + 1.0j * arr)),
-        arr.T - 1.0j * arr.T,
-        atol=1e-4,
-    )
-
-    arr = tc.backend.zeros([5], dtype="float32")
-    np.testing.assert_allclose(
-        tc.backend.sigmoid(arr),
-        tc.backend.ones([5]) * 0.5,
-        atol=1e-4,
-    )
-    ans = np.array([[1, 0.5j], [-0.5j, 1]])
-    ans2 = ans @ ans
-    ansp = tc.backend.sqrtmh(tc.array_to_tensor(ans2))
-    print(ansp @ ansp, ans @ ans)
-    np.testing.assert_allclose(ansp @ ansp, ans @ ans, atol=1e-4)
-
-    np.testing.assert_allclose(
-        tc.backend.sum(tc.array_to_tensor(np.arange(4))), 6, atol=1e-4
-    )
-
-    indices = np.array([[1, 2], [0, 1]])
-    ans = np.array([[[0, 1, 0], [0, 0, 1]], [[1, 0, 0], [0, 1, 0]]])
-    np.testing.assert_allclose(tc.backend.one_hot(indices, 3), ans, atol=1e-4)
-
-    a = tc.array_to_tensor(np.array([1, 1, 3, 2, 2, 1]), dtype="int32")
-    np.testing.assert_allclose(tc.backend.unique_with_counts(a)[0].shape[0], 3)
-
-    np.testing.assert_allclose(
-        tc.backend.cumsum(tc.array_to_tensor(np.array([[0.2, 0.2], [0.2, 0.4]]))),
-        np.array([0.2, 0.4, 0.6, 1.0]),
-        atol=1e-4,
-    )
-
-    np.testing.assert_allclose(
-        tc.backend.max(tc.backend.ones([2, 2], "float32")), 1.0, atol=1e-4
-    )
-    np.testing.assert_allclose(
-        tc.backend.min(
-            tc.backend.cast(
-                tc.backend.convert_to_tensor(np.array([[1.0, 2.0], [2.0, 3.0]])),
-                "float64",
-            ),
-            axis=1,
-        ),
-        np.array([1.0, 2.0]),
-        atol=1e-4,
-    )  # by default no keepdim
-
-    np.testing.assert_allclose(
-        tc.backend.concat([tc.backend.ones([2, 2]), tc.backend.ones([1, 2])]),
-        tc.backend.ones([3, 2]),
-        atol=1e-5,
-    )
-
-    np.testing.assert_allclose(
-        tc.backend.gather1d(
-            tc.array_to_tensor(np.array([0, 1, 2])),
-            tc.array_to_tensor(np.array([2, 1, 0]), dtype="int32"),
-        ),
-        np.array([2, 1, 0]),
-        atol=1e-5,
-    )
-
-    def sum_(carry, x):
-        return carry + x
-
-    r = tc.backend.scan(sum_, tc.backend.ones([10, 2]), tc.backend.zeros([2]))
-    np.testing.assert_allclose(r, 10 * np.ones([2]), atol=1e-5)
-
-
-@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
-def test_backend_methods_2(backend):
-    np.testing.assert_allclose(tc.backend.mean(tc.backend.ones([10])), 1.0, atol=1e-5)
-    # acos acosh asin asinh atan atan2 atanh cosh (cos) tan tanh sinh (sin)
-    np.testing.assert_allclose(
-        tc.backend.acos(tc.backend.ones([2], dtype="float32")),
-        np.arccos(tc.backend.ones([2])),
-        atol=1e-5,
-    )
-    np.testing.assert_allclose(
-        tc.backend.acosh(tc.backend.ones([2], dtype="float32")),
-        np.arccosh(tc.backend.ones([2])),
-        atol=1e-5,
-    )
-    np.testing.assert_allclose(
-        tc.backend.asin(tc.backend.ones([2], dtype="float32")),
-        np.arcsin(tc.backend.ones([2])),
-        atol=1e-5,
-    )
-    np.testing.assert_allclose(
-        tc.backend.asinh(tc.backend.ones([2], dtype="float32")),
-        np.arcsinh(tc.backend.ones([2])),
-        atol=1e-5,
-    )
-    np.testing.assert_allclose(
-        tc.backend.atan(0.5 * tc.backend.ones([2], dtype="float32")),
-        np.arctan(0.5 * tc.backend.ones([2])),
-        atol=1e-5,
-    )
-    np.testing.assert_allclose(
-        tc.backend.atan2(
-            tc.backend.ones([1], dtype="float32"), tc.backend.ones([1], dtype="float32")
-        ),
-        np.arctan2(
-            tc.backend.ones([1], dtype="float32"), tc.backend.ones([1], dtype="float32")
-        ),
-        atol=1e-5,
-    )
-    np.testing.assert_allclose(
-        tc.backend.atanh(0.5 * tc.backend.ones([2], dtype="float32")),
-        np.arctanh(0.5 * tc.backend.ones([2])),
-        atol=1e-5,
-    )
-    np.testing.assert_allclose(
-        tc.backend.cosh(tc.backend.ones([2], dtype="float32")),
-        np.cosh(tc.backend.ones([2])),
-        atol=1e-5,
-    )
-    np.testing.assert_allclose(
-        tc.backend.tan(tc.backend.ones([2], dtype="float32")),
-        np.tan(tc.backend.ones([2])),
-        atol=1e-5,
-    )
-    np.testing.assert_allclose(
-        tc.backend.tanh(tc.backend.ones([2], dtype="float32")),
-        np.tanh(tc.backend.ones([2])),
-        atol=1e-5,
-    )
-    np.testing.assert_allclose(
-        tc.backend.sinh(0.5 * tc.backend.ones([2], dtype="float32")),
-        np.sinh(0.5 * tc.backend.ones([2])),
-        atol=1e-5,
-    )
-    np.testing.assert_allclose(
-        tc.backend.eigvalsh(tc.backend.ones([2, 2])), np.array([0, 2]), atol=1e-5
-    )
-    np.testing.assert_allclose(
-        tc.backend.left_shift(
-            tc.backend.convert_to_tensor(np.array([4, 3])),
-            tc.backend.convert_to_tensor(np.array([1, 1])),
-        ),
-        np.array([8, 6]),
-    )
-    np.testing.assert_allclose(
-        tc.backend.right_shift(
-            tc.backend.convert_to_tensor(np.array([4, 3])),
-            tc.backend.convert_to_tensor(np.array([1, 1])),
-        ),
-        np.array([2, 1]),
-    )
-    np.testing.assert_allclose(
-        tc.backend.mod(
-            tc.backend.convert_to_tensor(np.array([4, 3])),
-            tc.backend.convert_to_tensor(np.array([2, 2])),
-        ),
-        np.array([0, 1]),
-    )
-    np.testing.assert_allclose(
-        tc.backend.arange(3),
-        np.array([0, 1, 2]),
-    )
-    np.testing.assert_allclose(
-        tc.backend.arange(1, 5, 2),
-        np.array([1, 3]),
-    )
-    assert tc.backend.dtype(tc.backend.ones([])) == "complex64"
-    edges = [-1, 3.3, 9.1, 10.0]
-    values = tc.backend.convert_to_tensor(np.array([0.0, 4.1, 12.0], dtype=np.float32))
-    r = tc.backend.numpy(tc.backend.searchsorted(edges, values))
-    np.testing.assert_allclose(r, np.array([1, 2, 4]))
-    p = tc.backend.convert_to_tensor(
-        np.array(
-            [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.2, 0.4], dtype=np.float32
-        )
-    )
-    r = tc.backend.probability_sample(10000, p, status=np.random.uniform(size=[10000]))
-    _, r = np.unique(r, return_counts=True)
-    np.testing.assert_allclose(
-        r - tc.backend.numpy(p) * 10000.0, np.zeros([10]), atol=200, rtol=1
-    )
-    np.testing.assert_allclose(
-        tc.backend.std(tc.backend.cast(tc.backend.arange(1, 4), "float32")),
-        0.81649658,
-        atol=1e-5,
-    )
-    arr = np.random.normal(size=(6, 6))
-    np.testing.assert_allclose(
-        tc.backend.relu(tc.array_to_tensor(arr, dtype="float32")),
-        np.maximum(arr, 0),
-        atol=1e-4,
-    )
-    np.testing.assert_allclose(
-        tc.backend.det(tc.backend.convert_to_tensor(np.eye(3) * 2)), 8, atol=1e-5
-    )
-
-
-# @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
-# def test_backend_array(backend):
-#     a = tc.backend.array([[0, 1], [1, 0]])
-#     assert tc.interfaces.which_backend(a).name == tc.backend.name
-#     a = tc.backend.array([[0, 1], [1, 0]], dtype=tc.rdtypestr)
-#     assert tc.dtype(a) == "float32"
-
-
-@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
-def test_device_cpu_only(backend):
-    a = tc.backend.ones([])
-    dev_str = tc.backend.device(a)
-    assert dev_str in ["cpu", "gpu:0"]
-    tc.backend.device_move(a, dev_str)
-
-
-@pytest.mark.skipif(
-    len(tf.config.list_physical_devices()) == 1, reason="no GPU detected"
-)
-@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
-def test_device_cpu_gpu(backend):
-    a = tc.backend.ones([])
-    a1 = tc.backend.device_move(a, "gpu:0")
-    dev_str = tc.backend.device(a1)
-    assert dev_str == "gpu:0"
-    a2 = tc.backend.device_move(a1, "cpu")
-    assert tc.backend.device(a2) == "cpu"
-
-
-@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
-def test_dlpack(backend):
-    a = tc.backend.ones([2, 2], dtype="float64")
-    cap = tc.backend.to_dlpack(a)
-    a1 = tc.backend.from_dlpack(cap)
-    np.testing.assert_allclose(a, a1, atol=1e-5)
-
-
-@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
-def test_arg_cmp(backend):
-    np.testing.assert_allclose(tc.backend.argmax(tc.backend.ones([3], "float64")), 0)
-    np.testing.assert_allclose(
-        tc.backend.argmax(
-            tc.array_to_tensor(np.array([[1, 2], [3, 4]]), dtype="float64")
-        ),
-        np.array([1, 1]),
-    )
-    np.testing.assert_allclose(
-        tc.backend.argmin(
-            tc.array_to_tensor(np.array([[1, 2], [3, 4]]), dtype="float64"), axis=-1
-        ),
-        np.array([0, 0]),
-    )
-
-
-@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
-def test_tree_map(backend):
-    def f(a, b):
-        return a + b
-
-    r = tc.backend.tree_map(
-        f, {"a": tc.backend.ones([2])}, {"a": 2 * tc.backend.ones([2])}
-    )
-    np.testing.assert_allclose(r["a"], 3 * np.ones([2]), atol=1e-4)
-
-    def _add(a, b):
-        return a + b
-
-    ans = tc.backend.tree_map(
-        _add,
-        {"a": tc.backend.ones([2]), "b": tc.backend.ones([3])},
-        {"a": tc.backend.ones([2]), "b": tc.backend.ones([3])},
-    )
-    np.testing.assert_allclose(ans["a"], 2 * np.ones([2]))
-
-
-@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
-def test_backend_randoms(backend):
-    @partial(tc.backend.jit, static_argnums=0)
-    def random_matrixn(key):
-        tc.backend.set_random_state(key)
-        r1 = tc.backend.implicit_randn(shape=[2, 2], mean=0.5)
-        r2 = tc.backend.implicit_randn(shape=[2, 2], mean=0.5)
-        return r1, r2
-
-    key = 42
-    if tc.backend.name == "tensorflow":
-        key = tf.random.Generator.from_seed(42)
-    r11, r12 = random_matrixn(key)
-    if tc.backend.name == "tensorflow":
-        key = tf.random.Generator.from_seed(42)
-    r21, r22 = random_matrixn(key)
-    np.testing.assert_allclose(r11, r21, atol=1e-4)
-    np.testing.assert_allclose(r12, r22, atol=1e-4)
-    assert not np.allclose(r11, r12, atol=1e-4)
-
-    def random_matrixu(key):
-        tc.backend.set_random_state(key)
-        r1 = tc.backend.implicit_randu(shape=[2, 2], high=2)
-        r2 = tc.backend.implicit_randu(shape=[2, 2], high=1)
-        return r1, r2
-
-    key = 42
-    r31, r32 = random_matrixu(key)
-    np.testing.assert_allclose(r31.shape, [2, 2])
-    assert np.any(r32 > 0)
-    assert not np.allclose(r31, r32, atol=1e-4)
-
-    def random_matrixc(key):
-        tc.backend.set_random_state(key)
-        r1 = tc.backend.implicit_randc(a=[1, 2, 3], shape=(2, 2))
-        r2 = tc.backend.implicit_randc(a=[1, 2, 3], shape=(2, 2), p=[0.1, 0.4, 0.5])
-        return r1, r2
-
-    r41, r42 = random_matrixc(key)
-    np.testing.assert_allclose(r41.shape, [2, 2])
-    assert np.any((r42 > 0) & (r42 < 4))
-
-
-def vqe_energy(inputs, param, n, nlayers):
-    c = tc.Circuit(n, inputs=inputs)
-    paramc = tc.backend.cast(param, "complex64")
-
-    for i in range(n):
-        c.H(i)
-    for j in range(nlayers):
-        for i in range(n - 1):
-            c.ryy(i, i + 1, theta=paramc[2 * j, i])
-            # c.any(
-            #     i,
-            #     i + 1,
-            #     unitary=tc.backend.cos(paramc[2 * j, i]) * iir
-            #     + tc.backend.sin(paramc[2 * j, i]) * 1.0j * yzr,
-            # )
-        for i in range(n):
-            c.rx(i, theta=paramc[2 * j + 1, i])
-    e = 0.0
-    for i in range(n):
-        e += c.expectation((tc.gates.x(), [i]))
-    for i in range(n - 1):  # OBC
-        e += c.expectation((tc.gates.z(), [i]), (tc.gates.z(), [(i + 1) % n]))
-    e = tc.backend.real(e)
-    return e
-
-
-@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
-def test_vvag(backend):
-    n = 4
-    nlayers = 3
-    inp = tc.backend.ones([2**n]) / 2 ** (n / 2)
-    param = tc.backend.ones([2 * nlayers, n])
-    # inp = tc.backend.cast(inp, "complex64")
-    # param = tc.backend.cast(param, "complex64")
-
-    vqe_energy_p = partial(vqe_energy, n=n, nlayers=nlayers)
-
-    vg = tc.backend.value_and_grad(vqe_energy_p, argnums=(0, 1))
-    v0, (g00, g01) = vg(inp, param)
-
-    batch = 8
-    inps = tc.backend.ones([batch, 2**n]) / 2 ** (n / 2)
-    inps = tc.backend.cast(inps, "complex64")
-
-    pvag = tc.backend.vvag(vqe_energy_p, argnums=(0, 1))
-    v1, (g10, g11) = pvag(inps, param)
-    print(v1.shape, g10.shape, g11.shape)
-    np.testing.assert_allclose(v1[0], v0, atol=1e-4)
-    np.testing.assert_allclose(g10[0], g00, atol=1e-4)
-    np.testing.assert_allclose(g11 / batch, g01, atol=1e-4)
-
-
-@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
-def test_vvag_dict(backend):
-    def dict_plus(x, y):
-        a = x["a"]
-        return tc.backend.real((a + y)[0])
-
-    dp_vvag = tc.backend.vvag(dict_plus, vectorized_argnums=1, argnums=0)
-    x = {"a": tc.backend.ones([1])}
-    y = tc.backend.ones([20, 1])
-    v, g = dp_vvag(x, y)
-    np.testing.assert_allclose(v.shape, [20])
-    np.testing.assert_allclose(g["a"], 20.0, atol=1e-4)
-
-
-@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
-def test_vjp(backend):
-    def f(x):
-        return x**2
-
-    inputs = tc.backend.ones([2, 2])
-    v, g = tc.backend.vjp(f, inputs, inputs)
-    np.testing.assert_allclose(v, inputs, atol=1e-5)
-    np.testing.assert_allclose(g, 2 * inputs, atol=1e-5)
-
-    def f2(x, y):
-        return x + y, x - y
-
-    inputs = [tc.backend.ones([2]), tc.backend.ones([2])]
-    v = [2.0 * t for t in inputs]
-    v, g = tc.backend.vjp(f2, inputs, v)
-    np.testing.assert_allclose(v[1], np.zeros([2]), atol=1e-5)
-    np.testing.assert_allclose(g[0], 4 * np.ones([2]), atol=1e-5)
-
-
-@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
-def test_vjp_complex(backend):
-    def f(x):
-        return tc.backend.conj(x)
-
-    inputs = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
-    v = tc.backend.ones([1], dtype="complex64")
-    v, g = tc.backend.vjp(f, inputs, v)
-    np.testing.assert_allclose(tc.backend.numpy(g), np.ones([1]), atol=1e-5)
-
-    def f2(x):
-        return x**2
-
-    inputs = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
-    v = tc.backend.ones([1], dtype="complex64")  # + 1.0j * tc.backend.ones([1])
-    v, g = tc.backend.vjp(f2, inputs, v)
-    # note how vjp definition on complex function is different in jax backend
-    if tc.backend.name == "jax":
-        np.testing.assert_allclose(tc.backend.numpy(g), 2 + 2j, atol=1e-5)
-    else:
-        np.testing.assert_allclose(tc.backend.numpy(g), 2 - 2j, atol=1e-5)
-
-
-# TODO(@refraction-ray): consistent and unified pytree utils for pytorch backend?
-
-
-@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
-def test_vjp_pytree(backend):
-    def f3(d):
-        return d["a"] + d["b"], d["a"]
-
-    inputs = {"a": tc.backend.ones([2]), "b": tc.backend.ones([1])}
-    v = (tc.backend.ones([2]), tc.backend.zeros([2]))
-    v, g = tc.backend.vjp(f3, inputs, v)
-    np.testing.assert_allclose(v[0], 2 * np.ones([2]), atol=1e-5)
-    np.testing.assert_allclose(g["a"], np.ones([2]), atol=1e-5)
-
-
-@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
-def test_jvp(backend):
-    def f(x):
-        return x**2
-
-    inputs = tc.backend.ones([2, 2])
-    v, g = tc.backend.jvp(f, inputs, inputs)
-    np.testing.assert_allclose(v, inputs, atol=1e-5)
-    np.testing.assert_allclose(g, 2 * inputs, atol=1e-5)
-
-    def f2(x, y):
-        return x + y, x - y
-
-    inputs = [tc.backend.ones([2]), tc.backend.ones([2])]
-    v = [2.0 * t for t in inputs]
-    v, g = tc.backend.jvp(f2, inputs, v)
-    np.testing.assert_allclose(v[1], np.zeros([2]), atol=1e-5)
-    np.testing.assert_allclose(g[0], 4 * np.ones([2]), atol=1e-5)
-
-
-@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
-def test_jvp_complex(backend):
-    def f(x):
-        return tc.backend.conj(x)
-
-    inputs = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
-    v = tc.backend.ones([1], dtype="complex64")
-    v, g = tc.backend.jvp(f, inputs, v)
-    # numpy auto numpy doesn't work for torch conjugate tensor
-    np.testing.assert_allclose(tc.backend.numpy(g), np.ones([1]), atol=1e-5)
-
-    def f2(x):
-        return x**2
-
-    inputs = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
-    v = tc.backend.ones([1]) + 1.0j * tc.backend.ones([1])
-    v, g = tc.backend.jvp(f2, inputs, v)
-    np.testing.assert_allclose(tc.backend.numpy(g), 4.0j, atol=1e-5)
-
-
-@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
-def test_jvp_pytree(backend):
-    def f3(d):
-        return d["a"] + d["b"], d["a"]
-
-    inputs = {"a": tc.backend.ones([2]), "b": tc.backend.ones([1])}
-    v = (tc.backend.ones([2]), tc.backend.zeros([2]))
-    v, g = tc.backend.vjp(f3, inputs, v)
-    np.testing.assert_allclose(v[0], 2 * np.ones([2]), atol=1e-5)
-    np.testing.assert_allclose(g["a"], np.ones([2]), atol=1e-5)
-
-
-@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
-@pytest.mark.parametrize("mode", ["jacfwd", "jacrev"])
-def test_jac(backend, mode):
-    # make no sense for torch backend when you have no real vmap interface
-    backend_jac = getattr(tc.backend, mode)
-
-    def f(x):
-        return x**2
-
-    x = tc.backend.ones([3])
-    jacf = backend_jac(f)
-    np.testing.assert_allclose(jacf(x), 2 * np.eye(3), atol=1e-5)
-
-    def f2(x):
-        return x**2, x
-
-    jacf2 = backend_jac(f2)
-    np.testing.assert_allclose(jacf2(x)[1], np.eye(3), atol=1e-5)
-    np.testing.assert_allclose(jacf2(x)[0], 2 * np.eye(3), atol=1e-5)
-
-    def f3(x, y):
-        return x + y**2
-
-    jacf3 = backend_jac(f3, argnums=(0, 1))
-    jacf3jit = tc.backend.jit(backend_jac(f3, argnums=(0, 1)))
-    np.testing.assert_allclose(jacf3jit(x, x)[1], 2 * np.eye(3), atol=1e-5)
-    np.testing.assert_allclose(jacf3(x, x)[1], 2 * np.eye(3), atol=1e-5)
-
-    def f4(x, y):
-        return x**2, y
-
-    # note the subtle difference of two tuples order in jacrev and jacfwd for current API
-    # the value happen to be the same here, though
-    jacf4 = backend_jac(f4, argnums=(0, 1))
-    jacf4jit = tc.backend.jit(backend_jac(f4, argnums=(0, 1)))
-    np.testing.assert_allclose(jacf4jit(x, x)[1][1], np.eye(3), atol=1e-5)
-    np.testing.assert_allclose(jacf4jit(x, x)[0][1], np.zeros([3, 3]), atol=1e-5)
-    np.testing.assert_allclose(jacf4(x, x)[1][1], np.eye(3), atol=1e-5)
-    np.testing.assert_allclose(jacf4(x, x)[0][1], np.zeros([3, 3]), atol=1e-5)
-
-
-@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
-@pytest.mark.parametrize("mode", ["jacfwd", "jacrev"])
-def test_jac_md_input(backend, mode):
-    backend_jac = getattr(tc.backend, mode)
-
-    def f(x):
-        return x**2
-
-    x = tc.backend.ones([2, 3])
-    jacf = backend_jac(f)
-    np.testing.assert_allclose(jacf(x).shape, [2, 3, 2, 3], atol=1e-5)
-
-    def f2(x):
-        return tc.backend.sum(x, axis=0)
-
-    x = tc.backend.ones([2, 3])
-    jacf2 = backend_jac(f2)
-    np.testing.assert_allclose(jacf2(x).shape, [3, 2, 3], atol=1e-5)
-
-
-@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
-@pytest.mark.parametrize("mode", ["jacfwd", "jacrev"])
-def test_jac_tall(backend, mode):
-    backend_jac = getattr(tc.backend, mode)
-
-    h = tc.backend.ones([5, 3])
-
-    def f(x):
-        x = tc.backend.reshape(x, [-1, 1])
-        return tc.backend.reshape(h @ x, [-1])
-
-    x = tc.backend.ones([3])
-    jacf = backend_jac(f)
-    np.testing.assert_allclose(jacf(x), np.ones([5, 3]), atol=1e-5)
-
-
-@pytest.mark.parametrize("backend", [lf("jaxb"), lf("tfb")])
-def test_vvag_has_aux(backend):
-    def f(x):
-        y = tc.backend.sum(x)
-        return tc.backend.real(y**2), y
-
-    fvvag = tc.backend.vvag(f, has_aux=True)
-    (_, v1), _ = fvvag(tc.backend.ones([10, 2]))
-    np.testing.assert_allclose(v1, 2 * tc.backend.ones([10]))
-
-
-def test_jax_svd(jaxb, highp):
-    def l(A):
-        u, _, v, _ = tc.backend.svd(A)
-        return tc.backend.real(u[0, 0] * v[0, 0])
-
-    def numericald(A):
-        eps = 1e-6
-        DA = np.zeros_like(A)
-        for i in range(A.shape[0]):
-            for j in range(A.shape[1]):
-                dA = np.zeros_like(A)
-                dA[i, j] = 1
-                DA[i, j] = (l(A + eps * dA) - l(A)) / eps - 1.0j * (
-                    l(A + eps * 1.0j * dA) - l(A)
-                ) / eps
-        return DA
-
-    def analyticald(A):
-        A = tc.backend.convert_to_tensor(A)
-        g = tc.backend.grad(l)
-        return g(A)
-
-    for shape in [(2, 2), (3, 3), (2, 3), (4, 2)]:
-        m = np.random.normal(size=shape).astype(
-            np.complex128
-        ) + 1.0j * np.random.normal(size=shape).astype(np.complex128)
-        print(m)
-        np.testing.assert_allclose(numericald(m), analyticald(m), atol=1e-3)
-
-
-@pytest.mark.parametrize("backend", [lf("jaxb"), lf("tfb"), lf("torchb")])
-def test_qr(backend, highp):
-    def get_random_complex(shape):
-        result = np.random.random(shape) + np.random.random(shape) * 1j
-        return tc.backend.convert_to_tensor(result.astype(dtype))
-
-    np.random.seed(0)
-    A1 = get_random_complex((2, 2))
-    A2 = tc.backend.convert_to_tensor(np.array([[1.0, 0.0], [0.0, 0.0]]).astype(dtype))
-    X = get_random_complex((2, 2))
-
-    def func(A, x):
-        x = tc.backend.cast(x, "complex64")
-        Q, R = tc.backend.qr(A + X * x)
-        return tc.backend.real(tc.backend.sum(tc.backend.matmul(Q, R)))
-
-    def grad(A, x):
-        return tc.backend.grad(func, argnums=1)(A, x)
-
-    for A in [A1, A2]:
-        epsilon = tc.backend.convert_to_tensor(1e-3)
-        n_grad = (func(A, epsilon) - func(A, -epsilon)) / (2 * epsilon)
-        a_grad = grad(A, tc.backend.convert_to_tensor(0.0))
-        np.testing.assert_allclose(n_grad, a_grad, atol=1e-3)
-
-
-@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
-def test_sparse_methods(backend):
-    values = tc.backend.convert_to_tensor(np.array([1.0, 2.0]))
-    values = tc.backend.cast(values, "complex64")
-    indices = tc.backend.convert_to_tensor(np.array([[0, 0], [1, 1]]))
-    indices = tc.backend.cast(indices, "int64")
-    spa = tc.backend.coo_sparse_matrix(indices, values, shape=[4, 4])
-    vec = tc.backend.ones([4, 1])
-    da = np.array(
-        [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], dtype=np.complex64
-    )
-    assert tc.backend.is_sparse(spa) is True
-    assert tc.backend.is_sparse(vec) is False
-    np.testing.assert_allclose(
-        tc.backend.to_dense(spa),
-        da,
-        atol=1e-5,
-    )
-    np.testing.assert_allclose(
-        tc.backend.sparse_dense_matmul(spa, vec),
-        np.array([[1], [2], [0], [0]], dtype=np.complex64),
-        atol=1e-5,
-    )
-    spa_np = tc.backend.numpy(spa)
-    np.testing.assert_allclose(spa_np.todense(), da, atol=1e-6)
-    np.testing.assert_allclose(
-        tc.backend.to_dense(tc.backend.coo_sparse_matrix_from_numpy(spa_np)),
-        da,
-        atol=1e-5,
-    )
-
-
-@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
-def test_backend_randoms_v2(backend):
-    g = tc.backend.get_random_state(42)
-    for t in tc.backend.stateful_randc(g, 3, [3]):
-        assert t >= 0
-        assert t < 3
-    key = tc.backend.get_random_state(42)
-    r = []
-    for _ in range(2):
-        key, subkey = tc.backend.random_split(key)
-        r.append(tc.backend.stateful_randc(subkey, 3, [5]))
-    assert tuple(r[0]) != tuple(r[1])
-
-
-@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
-def test_backend_randoms_v3(backend):
-    tc.backend.set_random_state(42)
-    for _ in range(2):
-        r1 = tc.backend.implicit_randu()
-    key = tc.backend.get_random_state(42)
-    for _ in range(2):
-        key, subkey = tc.backend.random_split(key)
-        r2 = tc.backend.stateful_randu(subkey)
-    np.testing.assert_allclose(r1, r2, atol=1e-5)
-
-    @tc.backend.jit
-    def f(key):
-        tc.backend.set_random_state(key)
-        r = []
-        for _ in range(3):
-            r.append(tc.backend.implicit_randu()[0])
-        return r
-
-    @tc.backend.jit
-    def f2(key):
-        r = []
-        for _ in range(3):
-            key, subkey = tc.backend.random_split(key)
-            r.append(tc.backend.stateful_randu(subkey)[0])
-        return r
-
-    key = tc.backend.get_random_state(43)
-    r = f(key)
-    key = tc.backend.get_random_state(43)
-    r1 = f2(key)
-    np.testing.assert_allclose(r[-1], r1[-1], atol=1e-5)
-
-
-@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
-def test_function_level_set(backend):
-    def f(x):
-        return tc.backend.ones([x])
-
-    f_jax_128 = tc.set_function_backend("jax")(tc.set_function_dtype("complex128")(f))
-    # note the order to enable complex 128 in jax backend
-
-    assert f_jax_128(3).dtype.__str__() == "complex128"
-
-
-@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
-def test_function_level_set_contractor(backend):
-    @tc.set_function_contractor("branch")
-    def f():
-        return tc.contractor
-
-    print(f())
-    print(tc.contractor)
-
-
-@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
-def test_with_level_set(backend):
-    with tc.runtime_backend("jax"):
-        with tc.runtime_dtype("complex128"):
-            with tc.runtime_contractor("branch"):
-                assert tc.backend.ones([2]).dtype.__str__() == "complex128"
-                print(tc.contractor)
-    print(tc.contractor)
-    print(tc.backend.name)
-
-
-@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
-def test_with_level_set_return(backend):
-    with tc.runtime_backend("jax") as K:
-        assert K.name == "jax"
-
-
-@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
-def test_grad_has_aux(backend):
-    def f(x):
-        return tc.backend.real(x**2), x**3
-
-    vg = tc.backend.value_and_grad(f, has_aux=True)
-
-    np.testing.assert_allclose(
-        vg(tc.backend.ones([]))[1], 2 * tc.backend.ones([]), atol=1e-5
-    )
-
-    def f2(x):
-        return tc.backend.real(x**2), (x**3, tc.backend.ones([3]))
-
-    gs = tc.backend.grad(f2, has_aux=True)
-    np.testing.assert_allclose(gs(tc.backend.ones([]))[0], 2.0, atol=1e-5)
-
-
-@pytest.mark.parametrize("backend", [lf("npb"), lf("jaxb"), lf("tfb")])
-def test_solve(backend):
-    A = np.array([[2, 1, 0], [1, 2, 0], [0, 0, 1]], dtype=np.float32)
-    A = tc.backend.convert_to_tensor(A)
-    x = np.ones([3, 1], dtype=np.float32)
-    x = tc.backend.convert_to_tensor(x)
-    b = (A @ x)[:, 0]
-    print(A.shape, b.shape)
-    xp = tc.backend.solve(A, b, assume_a="her")
-    np.testing.assert_allclose(xp, x[:, 0], atol=1e-5)
-
-
-@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
-def test_treeutils(backend):
-    d0 = {"a": np.ones([2]), "b": [tc.backend.zeros([]), tc.backend.ones([1, 1])]}
-    leaves, treedef = tc.backend.tree_flatten(d0)
-    d1 = tc.backend.tree_unflatten(treedef, leaves)
-    d2 = tc.backend.tree_map(lambda x: 2 * x, d1)
-    np.testing.assert_allclose(2 * np.ones([1, 1]), d2["b"][1])
-
-
-@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
-def test_optimizers(backend):
-    if tc.backend.name == "jax":
-        try:
-            import optax
-        except ImportError:
-            pytest.skip("optax is not installed")
-
-    if tc.backend.name == "pytorch":
-        try:
-            import torch
-        except ImportError:
-            pytest.skip("torch is not installed")
-
-    def f(params, n):
-        c = tc.Circuit(n)
-        c = tc.templates.blocks.example_block(c, params["a"])
-        c = tc.templates.blocks.example_block(c, params["b"])
-        return tc.backend.real(c.expectation([tc.gates.x(), [n // 2]]))
-
-    vgs = tc.backend.jit(tc.backend.value_and_grad(f, argnums=0), static_argnums=1)
-
-    def get_opt():
-        if tc.backend.name == "tensorflow":
-            optimizer1 = tf.keras.optimizers.Adam(5e-2)
-            opt = tc.backend.optimizer(optimizer1)
-        elif tc.backend.name == "jax":
-            optimizer2 = optax.adam(5e-2)
-            opt = tc.backend.optimizer(optimizer2)
-        elif tc.backend.name == "pytorch":
-            optimizer3 = partial(torch.optim.Adam, lr=5e-2)
-            opt = tc.backend.optimizer(optimizer3)
-        else:
-            raise ValueError("%s doesn't support optimizer interface" % tc.backend.name)
-        return opt
-
-    n = 3
-    opt = get_opt()
-
-    params = {
-        "a": tc.backend.ones([4, n], dtype="float32"),
-        "b": tc.backend.ones([4, n], dtype="float32"),
-    }
-
-    for _ in range(20):
-        loss, grads = vgs(params, n)
-        params = opt.update(grads, params)
-        print(loss)
-
-    assert loss < -0.7
-
-    def f2(params, n):
-        c = tc.Circuit(n)
-        c = tc.templates.blocks.example_block(c, params)
-        return tc.backend.real(c.expectation([tc.gates.x(), [n // 2]]))
-
-    vgs2 = tc.backend.jit(tc.backend.value_and_grad(f2, argnums=0), static_argnums=1)
-
-    params = tc.backend.ones([4, n], dtype="float32")
-    opt = get_opt()
-
-    for _ in range(20):
-        loss, grads = vgs2(params, n)
-        params = opt.update(grads, params)
-        print(loss)
-
-    assert loss < -0.7
-
-
-@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
-def test_hessian(backend):
-    # hessian support is now very fragile and especially has potential issues on tf backend
-    def f(param):
-        return tc.backend.sum(param**2)
-
-    hf = tc.backend.hessian(f)
-    param = tc.backend.ones([2])
-    np.testing.assert_allclose(hf(param), 2 * tc.backend.eye(2), atol=1e-5)
-
-    param = tc.backend.ones([2, 2])
-    assert list(hf(param).shape) == [2, 2, 2, 2]  # possible tf retracing?
-
-    g = tc.templates.graphs.Line1D(5)
-
-    def circuit_f(param):
-        c = tc.Circuit(5)
-        c = tc.templates.blocks.example_block(c, param, nlayers=1)
-        return tc.templates.measurements.heisenberg_measurements(c, g)
-
-    param = tc.backend.ones([10])
-    hf = tc.backend.hessian(circuit_f)
-    print(hf(param))  # still upto a conjugate for jax and tf backend.
-
-
-@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
-def test_nested_vmap(backend):
-    def f(x, w):
-        c = tc.Circuit(4)
-        for i in range(4):
-            c.rx(i, theta=x[i])
-            c.ry(i, theta=w[i])
-        return tc.backend.stack([c.expectation_ps(z=[i]) for i in range(4)])
-
-    def fa1(*args):
-        r = tc.backend.vmap(f, vectorized_argnums=1)(*args)
-        return r
-
-    def fa2(*args):
-        r = tc.backend.vmap(fa1, vectorized_argnums=0)(*args)
-        return r
-
-    fa2jit = tc.backend.jit(fa2)
-
-    ya = fa2(tc.backend.ones([3, 4]), tc.backend.ones([7, 4]))
-    yajit = fa2jit(tc.backend.ones([3, 4]), tc.backend.ones([7, 4]))
-
-    def fb1(*args):
-        r = tc.backend.vmap(f, vectorized_argnums=0)(*args)
-        return r
-
-    def fb2(*args):
-        r = tc.backend.vmap(fb1, vectorized_argnums=1)(*args)
-        return r
-
-    fb2jit = tc.backend.jit(fb2)
-
-    yb = fb2(tc.backend.ones([3, 4]), tc.backend.ones([7, 4]))
-    ybjit = fb2jit(tc.backend.ones([3, 4]), tc.backend.ones([7, 4]))
-
-    np.testing.assert_allclose(ya, tc.backend.transpose(yb, [1, 0, 2]), atol=1e-5)
-    np.testing.assert_allclose(ya, yajit, atol=1e-5)
-    np.testing.assert_allclose(yajit, tc.backend.transpose(ybjit, [1, 0, 2]), atol=1e-5)