tensorcircuit-nightly 1.2.0.dev20250326__py3-none-any.whl → 1.4.0.dev20251128__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tensorcircuit-nightly might be problematic; more details are available from the registry.

Files changed (77)
  1. tensorcircuit/__init__.py +5 -1
  2. tensorcircuit/abstractcircuit.py +4 -0
  3. tensorcircuit/analogcircuit.py +413 -0
  4. tensorcircuit/applications/layers.py +1 -1
  5. tensorcircuit/applications/van.py +1 -1
  6. tensorcircuit/backends/abstract_backend.py +312 -5
  7. tensorcircuit/backends/cupy_backend.py +3 -1
  8. tensorcircuit/backends/jax_backend.py +100 -4
  9. tensorcircuit/backends/jax_ops.py +108 -0
  10. tensorcircuit/backends/numpy_backend.py +49 -3
  11. tensorcircuit/backends/pytorch_backend.py +92 -3
  12. tensorcircuit/backends/tensorflow_backend.py +102 -3
  13. tensorcircuit/basecircuit.py +157 -98
  14. tensorcircuit/circuit.py +115 -57
  15. tensorcircuit/cloud/local.py +1 -1
  16. tensorcircuit/cloud/quafu_provider.py +1 -1
  17. tensorcircuit/cloud/tencent.py +1 -1
  18. tensorcircuit/compiler/simple_compiler.py +2 -2
  19. tensorcircuit/cons.py +105 -23
  20. tensorcircuit/densitymatrix.py +16 -11
  21. tensorcircuit/experimental.py +733 -153
  22. tensorcircuit/fgs.py +254 -73
  23. tensorcircuit/gates.py +66 -22
  24. tensorcircuit/interfaces/jax.py +5 -3
  25. tensorcircuit/interfaces/tensortrans.py +6 -2
  26. tensorcircuit/interfaces/torch.py +14 -4
  27. tensorcircuit/keras.py +3 -3
  28. tensorcircuit/mpscircuit.py +154 -65
  29. tensorcircuit/quantum.py +698 -134
  30. tensorcircuit/quditcircuit.py +733 -0
  31. tensorcircuit/quditgates.py +618 -0
  32. tensorcircuit/results/counts.py +131 -18
  33. tensorcircuit/results/readout_mitigation.py +4 -1
  34. tensorcircuit/shadows.py +1 -1
  35. tensorcircuit/simplify.py +3 -1
  36. tensorcircuit/stabilizercircuit.py +29 -17
  37. tensorcircuit/templates/__init__.py +2 -0
  38. tensorcircuit/templates/blocks.py +2 -2
  39. tensorcircuit/templates/hamiltonians.py +174 -0
  40. tensorcircuit/templates/lattice.py +1789 -0
  41. tensorcircuit/timeevol.py +896 -0
  42. tensorcircuit/translation.py +10 -3
  43. tensorcircuit/utils.py +7 -0
  44. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/METADATA +66 -29
  45. tensorcircuit_nightly-1.4.0.dev20251128.dist-info/RECORD +96 -0
  46. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/WHEEL +1 -1
  47. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/top_level.txt +0 -1
  48. tensorcircuit_nightly-1.2.0.dev20250326.dist-info/RECORD +0 -118
  49. tests/__init__.py +0 -0
  50. tests/conftest.py +0 -67
  51. tests/test_backends.py +0 -1035
  52. tests/test_calibrating.py +0 -149
  53. tests/test_channels.py +0 -409
  54. tests/test_circuit.py +0 -1699
  55. tests/test_cloud.py +0 -219
  56. tests/test_compiler.py +0 -147
  57. tests/test_dmcircuit.py +0 -555
  58. tests/test_ensemble.py +0 -72
  59. tests/test_fgs.py +0 -310
  60. tests/test_gates.py +0 -156
  61. tests/test_interfaces.py +0 -562
  62. tests/test_keras.py +0 -160
  63. tests/test_miscs.py +0 -282
  64. tests/test_mpscircuit.py +0 -341
  65. tests/test_noisemodel.py +0 -156
  66. tests/test_qaoa.py +0 -86
  67. tests/test_qem.py +0 -152
  68. tests/test_quantum.py +0 -549
  69. tests/test_quantum_attr.py +0 -42
  70. tests/test_results.py +0 -380
  71. tests/test_shadows.py +0 -160
  72. tests/test_simplify.py +0 -46
  73. tests/test_stabilizer.py +0 -217
  74. tests/test_templates.py +0 -218
  75. tests/test_torchnn.py +0 -99
  76. tests/test_van.py +0 -102
  77. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/licenses/LICENSE +0 -0
tests/test_interfaces.py DELETED
@@ -1,562 +0,0 @@
1
- # pylint: disable=invalid-name
2
-
3
- import os
4
- import sys
5
- from functools import partial
6
- import pytest
7
- from pytest_lazyfixture import lazy_fixture as lf
8
- from scipy import optimize
9
- import tensorflow as tf
10
- import jax
11
- from jax import numpy as jnp
12
-
13
# Prepend the repository root (the directory above ``tests/``) to ``sys.path``
# so the local checkout of ``tensorcircuit`` takes precedence on import.
thisfile = os.path.abspath(__file__)
modulepath = os.path.dirname(os.path.dirname(thisfile))
sys.path.insert(0, modulepath)

# torch is optional; ``is_torch`` gates the tests that need it.
try:
    import torch
except ImportError:
    is_torch = False
else:
    is_torch = True
24
-
25
- import numpy as np
26
- import tensorcircuit as tc
27
-
28
-
29
@pytest.mark.skipif(not is_torch, reason="torch not installed")
@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
def test_torch_interface(backend):
    """Check ``tc.interfaces.torch_interface`` values and torch gradients.

    Three wrappers are exercised on the active backend: a pre-jitted
    single-output circuit function, a two-output function jitted by the
    interface itself, and an elementwise function fed a complex torch tensor.
    """
    n = 4

    def f(param):
        circuit = tc.templates.blocks.example_block(tc.Circuit(n), param)
        return tc.backend.real(circuit.expectation([tc.gates.x(), [1]]))

    f_jit_torch = tc.interfaces.torch_interface(
        tc.backend.jit(f), enable_dlpack=True
    )
    param = torch.ones([4, n], requires_grad=True)
    squared = f_jit_torch(param) ** 2
    squared.backward()
    np.testing.assert_allclose(param.grad.shape, [4, n])
    np.testing.assert_allclose(param.grad[0, 1], -2.146e-3, atol=1e-5)

    def f2(paramzz, paramx):
        circuit = tc.Circuit(n)
        for q in range(n):
            circuit.H(q)
        for layer in range(2):
            for q in range(n - 1):
                circuit.exp1(
                    q, q + 1, unitary=tc.gates._zz_matrix, theta=paramzz[layer, q]
                )
            for q in range(n):
                circuit.rx(q, theta=paramx[layer, q])
        x1 = circuit.expectation([tc.gates.x(), [1]])
        x2 = circuit.expectation([tc.gates.x(), [2]])
        return tc.backend.real(x1), tc.backend.real(x2)

    f2_torch = tc.interfaces.torch_interface(f2, jit=True, enable_dlpack=True)
    paramzz = torch.ones([2, n], requires_grad=True)
    paramx = torch.ones([2, n], requires_grad=True)
    out1, out2 = f2_torch(paramzz, paramx)
    difference = out1 - out2
    difference.backward()
    np.testing.assert_allclose(paramzz.grad.shape, [2, n])
    np.testing.assert_allclose(paramzz.grad[0, 0], -0.41609, atol=1e-5)

    # Complex torch input, wrapped without passing ``enable_dlpack``.
    def f3(x):
        return tc.backend.real(x**2)

    f3_torch = tc.interfaces.torch_interface(f3)
    param3 = torch.ones([2], dtype=torch.complex64, requires_grad=True)
    total = torch.sum(f3_torch(param3))
    total.backward()
    np.testing.assert_allclose(
        param3.grad, 2 * np.ones([2]).astype(np.complex64), atol=1e-5
    )
110
-
111
-
112
@pytest.mark.skipif(not is_torch, reason="torch not installed")
@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
def test_torch_interface_kws(backend):
    """Check ``torch_interface_kws``, which also forwards keyword arguments.

    Circuit sizes (``n``, ``nlayer``) are passed by keyword next to the
    positional torch tensors whose gradients are checked.
    """

    def f(param, n):
        circuit = tc.templates.blocks.example_block(tc.Circuit(n), param)
        return tc.backend.real(circuit.expectation([tc.gates.x(), [1]]))

    f_jit_torch = tc.interfaces.torch_interface_kws(f, jit=True, enable_dlpack=True)
    param = torch.ones([4, 4], requires_grad=True)
    squared = f_jit_torch(param, n=4) ** 2
    squared.backward()
    np.testing.assert_allclose(param.grad.shape, [4, 4])
    np.testing.assert_allclose(param.grad[0, 1], -2.146e-3, atol=1e-5)

    def f2(paramzz, paramx, n, nlayer):
        circuit = tc.Circuit(n)
        for q in range(n):
            circuit.H(q)
        for layer in range(nlayer):
            for q in range(n - 1):
                circuit.exp1(
                    q, q + 1, unitary=tc.gates._zz_matrix, theta=paramzz[layer, q]
                )
            for q in range(n):
                circuit.rx(q, theta=paramx[layer, q])
        x1 = circuit.expectation([tc.gates.x(), [1]])
        x2 = circuit.expectation([tc.gates.x(), [2]])
        return tc.backend.real(x1), tc.backend.real(x2)

    f2_torch = tc.interfaces.torch_interface_kws(f2, jit=True, enable_dlpack=True)
    paramzz = torch.ones([2, 4], requires_grad=True)
    paramx = torch.ones([2, 4], requires_grad=True)
    out1, out2 = f2_torch(paramzz, paramx, n=4, nlayer=2)
    difference = out1 - out2
    difference.backward()
    np.testing.assert_allclose(paramzz.grad.shape, [2, 4])
    np.testing.assert_allclose(paramzz.grad[0, 0], -0.41609, atol=1e-5)
178
-
179
-
180
def _version_tuple(version, length):
    """Return the first ``length`` numeric components of ``version`` as ints.

    Pre-release/dev suffixes are ignored (``"2.16.0rc0"`` -> ``(2, 16, 0)``).
    A component with no leading digit counts as ``0``.
    """
    import re

    parts = []
    for piece in version.split(".")[:length]:
        digits = re.match(r"\d+", piece)
        parts.append(int(digits.group()) if digits else 0)
    return tuple(parts)


@pytest.mark.skipif(not is_torch, reason="torch not installed")
@pytest.mark.xfail(
    # Require tensorflow >= 2.9 and jax >= 0.3.14, compared as integer tuples.
    # Gluing jax's minor and patch digits into one integer misorders releases
    # such as 0.5.0 ("50" < "314"), and int() fails on suffixes like "rc0".
    _version_tuple(tf.__version__, 2) < (2, 9)
    or _version_tuple(jax.__version__, 3) < (0, 3, 14),
    reason="version too low for tf or jax",
)
@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
def test_torch_interface_dlpack_complex(backend):
    """Pass a complex torch tensor through ``torch_interface`` with dlpack on."""

    def f3(x):
        return tc.backend.real(x**2)

    f3_torch = tc.interfaces.torch_interface(f3, enable_dlpack=True)
    param3 = torch.ones([2], dtype=torch.complex64, requires_grad=True)
    total = torch.sum(f3_torch(param3))
    total.backward()
    np.testing.assert_allclose(
        param3.grad, 2 * np.ones([2]).astype(np.complex64), atol=1e-5
    )
198
-
199
-
200
@pytest.mark.skipif(not is_torch, reason="torch not installed")
@pytest.mark.xfail(reason="see comment link below")
@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
def test_torch_interface_pytree(backend):
    """Differentiate a dict-input (pytree) function through ``torch_interface``."""
    # Expected to fail: pytree inputs are not supported by pytorch autograd
    # functions, see https://github.com/pytorch/pytorch/issues/55509

    def f4(x):
        return tc.backend.sum(x["a"] ** 2), tc.backend.sum(x["b"] ** 3)

    f4_torch = tc.interfaces.torch_interface(f4, jit=False)
    param4 = {
        "a": torch.ones([2], requires_grad=True),
        "b": torch.ones([2], requires_grad=True),
    }

    def f4_post(x):
        # Use the argument handed in by ``grad``, not the outer ``param4``.
        # Otherwise the loss does not depend on the differentiated input.
        r1, r2 = f4_torch(x)
        return r1 + r2

    pg = tc.get_backend("pytorch").grad(f4_post)(param4)
    np.testing.assert_allclose(
        pg["a"], 2 * np.ones([2]).astype(np.complex64), atol=1e-5
    )
224
-
225
-
226
@pytest.mark.parametrize("backend", [lf("jaxb")])
def test_tf_interface(backend):
    """Expose a jax-backend circuit to TensorFlow through ``tf_interface``.

    The circuit is wrapped twice: jitted with dlpack and a ``tf`` dtype, then
    un-jitted with a string ``ydtype``. Both TF gradients must match the same
    reference values, i.e. d/da and d/db of ``cos(a) * cos(b)`` at (1, 1).
    """
    expected = np.array([-0.45464867, -0.45464873])

    def f0(params):
        c = tc.Circuit(1)
        c.rx(0, theta=params[0])
        c.ry(0, theta=params[1])
        return tc.backend.real(c.expectation([tc.gates.z(), [0]]))

    wrapped_jit = tc.interfaces.tf_interface(
        f0, ydtype=tf.float32, jit=True, enable_dlpack=True
    )
    tfb = tc.get_backend("tensorflow")
    grads = tfb.jit(tfb.grad(wrapped_jit))(tfb.ones([2], dtype="float32"))
    np.testing.assert_allclose(tfb.real(grads), expected, atol=1e-5)

    wrapped_plain = tc.interfaces.tf_interface(f0, ydtype="float32", jit=False)
    grads = tfb.grad(wrapped_plain)(tf.ones([2]))
    np.testing.assert_allclose(grads, expected, atol=1e-5)
246
-
247
-
248
@pytest.mark.parametrize("backend", [lf("jaxb")])
def test_tf_interface_2(backend):
    """Use ``tf_interface`` on a two-input, two-output function (list ``ydtype``)."""

    def f1(a, b):
        sa, sb = tc.backend.sum(a), tc.backend.sum(b)
        return sa + sb, sa - sb

    f = tc.interfaces.tf_interface(f1, ydtype=["float32", "float32"], jit=True)

    def f_post(a, b):
        plus, minus = f(a, b)
        # (sa + sb) + (sa - sb) == 2 * sum(a), so d/da is 2 elementwise
        return plus + minus

    tfb = tc.get_backend("tensorflow")
    first = tf.ones([2], dtype=tf.float32)
    second = tf.ones([2], dtype=tf.float32)
    grads = tfb.jit(tfb.grad(f_post))(first, second)
    np.testing.assert_allclose(grads, 2 * np.ones([2]), atol=1e-5)
267
-
268
-
269
@pytest.mark.parametrize("backend", [lf("jaxb")])
def test_tf_interface_3(backend, highp):
    """Run ``tf_interface`` with a float64 ``ydtype`` under the ``highp`` fixture."""

    def f1(a, b):
        return tc.backend.sum(a) + tc.backend.sum(b)

    wrapped = tc.interfaces.tf_interface(f1, ydtype="float64", jit=True)
    tfb = tc.get_backend("tensorflow")
    first = tf.ones([2], dtype=tf.float64)
    second = tf.ones([2], dtype=tf.float64)
    # gradient w.r.t. the first argument of sum(a) + sum(b) is all ones
    grads = tfb.jit(tfb.grad(wrapped))(first, second)
    np.testing.assert_allclose(grads, np.ones([2]), atol=1e-5)
283
-
284
-
285
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
def test_scipy_interface(backend):
    """Minimise ``<Y_0>`` of a 3-qubit RX/RZ circuit with scipy.

    Uses ``scipy_optimize_interface`` with and without gradients. The
    minimum of ``<Y_0>`` is -1, and every optimiser run must reach it.
    """
    n = 3

    def f(param):
        c = tc.Circuit(n)
        for q in range(n):
            c.rx(q, theta=param[0, q])
            c.rz(q, theta=param[1, q])
        return tc.backend.real(c.expectation([tc.gates.y(), [0]]))

    # The gradient-based run is skipped on the numpy backend (presumably no
    # autodiff is available there).
    if tc.backend.name != "numpy":
        objective = tc.interfaces.scipy_optimize_interface(f, shape=[2, n])
        result = optimize.minimize(
            objective, np.zeros([2 * n]), method="L-BFGS-B", jac=True
        )
        # L-BFGS-B may have issues with float32 precision,
        # see https://github.com/scipy/scipy/issues/5832
        np.testing.assert_allclose(result["fun"], -1.0, atol=1e-5)

    objective = tc.interfaces.scipy_optimize_interface(
        f, shape=[2, n], gradient=False
    )
    result = optimize.minimize(objective, np.zeros([2 * n]), method="COBYLA")
    np.testing.assert_allclose(result["fun"], -1.0, atol=1e-5)
314
-
315
-
316
@pytest.mark.parametrize("backend", [lf("torchb"), lf("tfb"), lf("jaxb")])
def test_numpy_interface(backend):
    """Call a backend circuit function on numpy input via ``numpy_interface``."""

    def f(params, n):
        c = tc.Circuit(n)
        for q in range(n):
            c.rx(q, theta=params[q])
        for q in range(n - 1):
            c.cnot(q, q + 1)
        last_z = c.expectation_ps(z=[n - 1])
        return tc.backend.real(last_z)

    n = 3
    f_np = tc.interfaces.numpy_interface(f, jit=False)
    np.testing.assert_allclose(f_np(np.ones([n]), n), 0.1577285, atol=1e-5)
331
-
332
-
333
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
def test_args_transformation(backend):
    """Convert a nested argument pytree to numpy, then on to jax tensors."""
    nested = (
        tc.backend.ones([2]),
        {
            "a": tc.get_backend("tensorflow").ones([]),
            "b": [tc.get_backend("numpy").zeros([2, 1])],
        },
    )
    ans = tc.interfaces.general_args_to_numpy(nested)
    print(ans)
    np.testing.assert_allclose(
        ans[1]["b"][0], np.zeros([2, 1], dtype=np.complex64)
    )

    # a single dtype string for the whole structure
    as_jax = tc.interfaces.numpy_args_to_backend(
        ans, target_backend="jax", dtype="float32"
    )
    print(as_jax[1]["a"].dtype)

    # a dtype pytree mirroring the structure of ``ans``
    as_jax = tc.interfaces.numpy_args_to_backend(
        ans,
        target_backend="jax",
        dtype=("complex64", {"a": "float32", "b": ["complex64"]}),
    )
    print(as_jax[1]["a"].dtype)
356
-
357
-
358
@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
def test_dlpack_transformation(backend):
    """Move a tensor to each available backend with dlpack and read it back."""
    targets = ["tensorflow", "jax"]
    if is_torch:
        targets.append("pytorch")
    for target in targets:
        moved = tc.interfaces.general_args_to_backend(
            args=tc.backend.ones([2], dtype="float32"),
            target_backend=target,
            enable_dlpack=True,
        )
        on_cpu = tc.interfaces.which_backend(moved).device_move(moved, "cpu")
        np.testing.assert_allclose(on_cpu, np.ones([2]))
371
-
372
-
373
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
def test_args_to_tensor(backend):
    """Check ``args_to_tensor`` on plain, nested, Gate and QuOperator arguments.

    Only positions listed in ``argnums`` are converted to the active backend.
    Arguments outside it keep their original type.
    """

    def backend_name(t):
        return tc.interfaces.which_backend(t, return_backend=False)

    @partial(
        tc.interfaces.args_to_tensor,
        argnums=[0, 1, 2],
        gate_to_tensor=True,
        qop_to_tensor=True,
    )
    def f(a, b, c, d):
        return a, b, c, d

    a, b, nested, d = f(
        np.ones([2]), tc.backend.ones([1, 2]), {"a": [tf.zeros([3])]}, np.ones([2])
    )
    assert backend_name(a) == tc.backend.name
    assert backend_name(b) == tc.backend.name
    assert backend_name(nested["a"][0]) == tc.backend.name
    # position 3 is not in ``argnums`` and stays numpy
    assert backend_name(d) == "numpy"

    # In this call, Gate/QuOperator arguments come back as matrices, while
    # the plain rank-4 array keeps its shape.
    a, b, c, d = f(
        [tc.Gate(np.ones([2, 2])), tc.Gate(np.ones([2, 2, 2, 2]))],
        tc.QuOperator.from_tensor(np.ones([2, 2, 2, 2, 2, 2])),
        np.ones([2, 2, 2, 2]),
        tf.zeros([1, 2]),
    )
    assert backend_name(a[0]) == tc.backend.name
    assert tc.backend.shape_tuple(a[1]) == (4, 4)
    assert backend_name(b) == tc.backend.name
    assert backend_name(d) == "tensorflow"
    assert tc.backend.shape_tuple(b) == (8, 8)
    assert tc.backend.shape_tuple(c) == (2, 2, 2, 2)

    @partial(
        tc.interfaces.args_to_tensor,
        argnums=[0, 1, 2],
        tensor_as_matrix=False,
        gate_to_tensor=True,
        gate_as_matrix=False,
        qop_to_tensor=True,
        qop_as_matrix=False,
    )
    def g(a, b, c):
        return a, b, c

    # With every *_as_matrix flag off, tensor shapes are preserved; the
    # QuOperator result is read through ``.eval()``.
    a, b, c = g(
        [tc.Gate(np.ones([2, 2])), tc.Gate(np.ones([2, 2, 2, 2]))],
        tc.QuOperator.from_tensor(np.ones([2, 2, 2, 2, 2, 2])),
        np.ones([2, 2, 2, 2]),
    )
    assert backend_name(a[0]) == tc.backend.name
    assert tc.backend.shape_tuple(a[1]) == (2, 2, 2, 2)
    assert tc.backend.shape_tuple(b.eval()) == (2, 2, 2, 2, 2, 2)
    assert tc.backend.shape_tuple(c) == (2, 2, 2, 2)
431
-
432
-
433
def test_jax_interface_basic(tfb):
    """Check forward value and jax gradient through a jitted ``jax_interface``."""

    def f(params):
        c = tc.Circuit(1)
        c.rx(0, theta=params[0])
        c.ry(0, theta=params[1])
        return tc.backend.real(c.expectation_ps(z=[0]))

    f_jax = tc.interfaces.jax_interface(f, jit=True)
    params = jnp.ones(2)

    # forward pass: <Z> = cos(1) * cos(1) ~= 0.291927
    value = f_jax(params)
    assert isinstance(value, jnp.ndarray)
    np.testing.assert_allclose(value, 0.291927, atol=1e-5)

    # gradient pass
    _, grad = jax.value_and_grad(f_jax)(params)
    assert isinstance(grad, jnp.ndarray)
    assert grad.shape == params.shape
453
-
454
-
455
def test_jax_interface_multiple_inputs(tfb):
    """Differentiate an un-jitted ``jax_interface`` w.r.t. both array arguments."""

    def f(params1, params2):
        c = tc.Circuit(2)
        c.rx(0, theta=params1[0])
        c.ry(1, theta=params2[0])
        return tc.backend.real(c.expectation([tc.gates.z(), [0]]))

    f_jax = tc.interfaces.jax_interface(f, jit=False)
    p1, p2 = jnp.array([1.0]), jnp.array([2.0])

    assert isinstance(f_jax(p1, p2), jnp.ndarray)

    _, grads = jax.value_and_grad(f_jax, argnums=(0, 1))(p1, p2)
    for grad, param in zip(grads, (p1, p2)):
        assert isinstance(grad, jnp.ndarray)
        assert grad.shape == param.shape
478
-
479
-
480
@pytest.mark.skip(
    reason="might fail when testing with other function",
)
def test_jax_interface_jit_dlpack(tfb):
    """Check that a plain call and value_and_grad agree for a jitted dlpack wrapper."""

    def f(params):
        c = tc.Circuit(2)
        c.rx(range(2), theta=params)
        return tc.backend.real(c.expectation([tc.gates.z(), [0]]))

    f_jax = tc.interfaces.jax_interface(f, jit=True, enable_dlpack=True)
    params = jnp.array([np.pi, np.pi], dtype=jnp.float32)

    plain_value = f_jax(params)  # first call compiles
    grad_value, grads = jax.value_and_grad(f_jax)(params)

    assert isinstance(plain_value, jnp.ndarray)
    assert isinstance(grads, jnp.ndarray)
    np.testing.assert_allclose(plain_value, grad_value, atol=1e-5)
502
-
503
-
504
def test_jax_interface_pure_callback(tfb):
    """Wrap a raw TensorFlow op with ``jax_interface`` and differentiate it in jax.

    Covers the default wrapper and a jitted one with explicit output
    shape/dtype. The loss is sum(x**2) at x = [1, 2]: value 5, gradient 2x.
    """

    def f(params):
        # a TensorFlow operation, to exercise the pure_callback path
        return tf.square(params)

    def f_jax1(params):
        return jnp.sum(tc.interfaces.jax_interface(f)(params))

    def f_jax2(params):
        wrapped = tc.interfaces.jax_interface(
            f, jit=True, output_shape=[2], output_dtype=jnp.float32
        )
        return jnp.sum(wrapped(params))

    params = jnp.array([1.0, 2.0])
    for f_jax in (f_jax1, f_jax2):
        value = f_jax(params)
        assert isinstance(value, jnp.ndarray)
        np.testing.assert_allclose(value, 5.0, atol=1e-5)

        grad = jax.grad(f_jax)(params)
        assert isinstance(grad, jnp.ndarray)
        np.testing.assert_allclose(grad, [2.0, 4.0], atol=1e-5)
531
-
532
-
533
def test_jax_interface_multiple_outputs(tfb):
    """Run ``jax_interface`` on a TensorFlow function that returns two tensors.

    Both outputs are folded into sum(x**2), so at x = [1, 2] the value is 5
    and the gradient is 2x.
    """

    def f(params):
        # a TensorFlow operation, to exercise the pure_callback path
        return tf.square(params), params

    def fold(outputs):
        squared, identity = outputs
        return jnp.sum(squared + identity**2) / 2

    def f_jax1(params):
        return fold(tc.interfaces.jax_interface(f)(params))

    def f_jax2(params):
        wrapped = tc.interfaces.jax_interface(
            f,
            jit=True,
            output_shape=([2], [2]),
            output_dtype=(jnp.float32, jnp.float32),
        )
        return fold(wrapped(params))

    params = jnp.array([1.0, 2.0])
    for f_jax in (f_jax1, f_jax2):
        value = f_jax(params)
        assert isinstance(value, jnp.ndarray)
        np.testing.assert_allclose(value, 5.0, atol=1e-5)

        grad = jax.grad(f_jax)(params)
        assert isinstance(grad, jnp.ndarray)
        np.testing.assert_allclose(grad, [2.0, 4.0], atol=1e-5)
tests/test_keras.py DELETED
@@ -1,160 +0,0 @@
1
- import os
2
- import sys
3
- from functools import partial
4
-
5
# Prepend the repository root (the directory above ``tests/``) to ``sys.path``
# so the local checkout of ``tensorcircuit`` takes precedence on import.
thisfile = os.path.abspath(__file__)
modulepath = os.path.dirname(os.path.dirname(thisfile))
sys.path.insert(0, modulepath)
9
-
10
- import numpy as np
11
- import tensorflow as tf
12
- import tensorcircuit as tc
13
-
14
-
15
# Complex dtype shared by the numpy operators below and the TF casts in the
# ansatz functions.
dtype = np.complex128
tfdtype = tf.complex128

# Two-qubit identity and Z(x)Z as 4x4 matrices, plus (2, 2, 2, 2) tensor
# views that are combined into the ``unitary`` passed to ``Circuit.any``.
ii = np.eye(4, dtype=dtype)
iir = tf.constant(ii.reshape([2, 2, 2, 2]))
zz = np.diag(np.array([1, -1, -1, 1], dtype=dtype))
zzr = tf.constant(zz.reshape([2, 2, 2, 2]))
22
-
23
-
24
def tfi_energy(c, j=1.0, h=-1.0):
    """Transverse-field Ising energy of circuit ``c`` with open boundaries.

    Adds ``h * <X_i>`` for every qubit, then ``j * <Z_i Z_{i+1}>`` for each
    neighbouring pair. There is no wrap-around term.
    """
    n = c._nqubits
    energy = 0.0
    for site in range(n):
        energy += h * c.expectation((tc.gates.x(), [site]))
    for site in range(n - 1):
        # site + 1 < n here, so no modular wrap is needed
        energy += j * c.expectation(
            (tc.gates.z(), [site]), (tc.gates.z(), [site + 1])
        )
    return energy
32
-
33
-
34
def vqe_f2(inputs, xweights, zzweights, nlayers, n):
    """Layered TFIM ansatz with separate ZZ and RX weight tensors.

    Applies H on every qubit, then ``nlayers`` rounds of neighbouring
    ``cos(t) I + i sin(t) ZZ`` gates (angles ``zzweights[layer, q]``) and RX
    rotations (angles ``xweights[layer, q]``). ``inputs`` is unused. Returns
    the real part of :func:`tfi_energy`.
    """
    c = tc.Circuit(n)
    paramx = tf.cast(xweights, tfdtype)
    paramzz = tf.cast(zzweights, tfdtype)
    for q in range(n):
        c.H(q)
    for layer in range(nlayers):
        for q in range(n - 1):
            theta = paramzz[layer, q]
            gate = tf.math.cos(theta) * iir + tf.math.sin(theta) * 1.0j * zzr
            c.any(q, q + 1, unitary=gate)
        for q in range(n):
            c.rx(q, theta=paramx[layer, q])
    return tf.math.real(tfi_energy(c))
53
-
54
-
55
def test_vqe_layer2(tfb, highp):
    """Smoke-test ``tc.KerasLayer`` with two weight tensors.

    Takes one taped gradient, then fits the layer inside a Sequential model.
    """
    layer = tc.KerasLayer(partial(vqe_f2, nlayers=3, n=6), [(3, 6), (3, 6)])
    dummy = np.zeros([1])
    with tf.GradientTape() as tape:
        energy = layer(dummy)
    print(energy, tape.gradient(energy, layer.variables))

    model = tf.keras.Sequential([layer])
    model.compile(
        loss=tc.keras.output_asis_loss, optimizer=tf.keras.optimizers.Adam(0.01)
    )
    model.fit(np.zeros([1, 1]), np.zeros([1]), batch_size=1, epochs=300)
67
-
68
-
69
def vqe_f(inputs, weights, nlayers, n):
    """Layered TFIM ansatz driven by a single weight matrix.

    Row ``2 * layer`` holds the ZZ angles and row ``2 * layer + 1`` the RX
    angles of each layer. ``inputs`` is unused. Returns the real part of
    :func:`tfi_energy`.
    """
    c = tc.Circuit(n)
    paramc = tf.cast(weights, tfdtype)
    for q in range(n):
        c.H(q)
    for layer in range(nlayers):
        zz_row, x_row = 2 * layer, 2 * layer + 1
        for q in range(n - 1):
            theta = paramc[zz_row, q]
            gate = tf.math.cos(theta) * iir + tf.math.sin(theta) * 1.0j * zzr
            c.any(q, q + 1, unitary=gate)
        for q in range(n):
            c.rx(q, theta=paramc[x_row, q])
    return tf.math.real(tfi_energy(c))
87
-
88
-
89
def test_vqe_layer(tfb, highp):
    """Train a 6-qubit, 6-layer ``QuantumLayer`` and check the fitted output.

    The loss is ``output_asis_loss``, so training minimises the circuit
    energy. After 500 epochs the prediction should sit near -7.27.
    """
    vqe_fp = partial(vqe_f, nlayers=6, n=6)
    vqe_layer = tc.keras.QuantumLayer(vqe_fp, (6 * 2, 6))
    model = tf.keras.Sequential([vqe_layer])
    model.compile(
        loss=tc.keras.output_asis_loss, optimizer=tf.keras.optimizers.Adam(0.01)
    )
    # vqe_f ignores its inputs, so all-zero dummy data is enough to drive fit().
    model.fit(np.zeros([2, 1]), np.zeros([2, 1]), batch_size=2, epochs=500)
    np.testing.assert_allclose(model.predict(np.zeros([1])), -7.27, atol=5e-2)
103
-
104
-
105
def test_function_io(tfb, tmp_path, highp):
    """Round-trip a ``tf.function`` through ``tc.keras.save_func``/``load_func``.

    The reloaded function is called twice, and each result must match the
    original function's output for the same arguments.
    """
    vqe_f_p = tf.function(partial(vqe_f, inputs=tf.ones([1])))
    weights = tf.ones([6, 6], dtype=tf.float64)
    # Call once before saving (presumably so save_func has a traced signature)
    # and keep the result as the reference value.
    expected = vqe_f_p(weights=weights, nlayers=3, n=6)
    tc.keras.save_func(vqe_f_p, str(tmp_path))
    loaded = tc.keras.load_func(str(tmp_path), fallback=vqe_f_p)
    for _ in range(2):
        np.testing.assert_allclose(
            loaded(weights=weights, nlayers=3, n=6), expected, atol=1e-6
        )
114
-
115
-
116
def test_keras_hardware(tfb):
    """Build a ``tc.keras.HardwareLayer`` with an L2 regularizer and call it once."""
    n = 2

    def qf(inputs, param):
        c = tc.Circuit(n)
        for q in range(n):
            c.rx(q, theta=inputs[q])
        c.h(1)
        c.rzz(0, 1, theta=param[0])
        return tc.backend.stack([c.expectation_ps(z=[q]) for q in range(n)])

    layer = tc.keras.HardwareLayer(
        qf, [1], regularizer=tf.keras.regularizers.l2(1e-3)
    )
    print(layer(tf.ones([1, 2])))
129
-
130
-
131
def test_keras_layer_inputs_dict(tfb):
    """Check that ``tc.KerasLayer`` accepts a dict of inputs, batched or not.

    Guards against https://github.com/tensorflow/tensorflow/issues/65306:
    keras3 (tf 2.16+) rejects complex-valued layer inputs, which quantum
    applications rely on. Batched and unbatched inputs must give the same
    gradient.
    """
    n = 3
    p = 0.1
    K = tc.backend

    def f(inputs, weights):
        noise = inputs["noise"]
        c = tc.Circuit(n, inputs=inputs["state"])
        for q in range(n):
            c.rz(q, theta=weights[q])
        for q in range(n):
            c.depolarizing(q, px=p, py=p, pz=p, status=noise[q])
        return K.real(c.expectation_ps(x=[0]))

    layer = tc.KerasLayer(f, [n])
    norm = 2 ** (n / 2)
    batched = {"state": K.ones([1, 2**n]) / norm, "noise": 0.2 * K.ones([1, n])}
    unbatched = {"state": K.ones([2**n]) / norm, "noise": 0.2 * K.ones([n])}

    grads = []
    for feed in (batched, unbatched):
        with tf.GradientTape() as tape:
            out = layer(feed)
        grads.append(tape.gradient(out, layer.trainable_variables))

    np.testing.assert_allclose(grads[0][0], grads[1][0], atol=1e-5)