tensorcircuit-nightly 1.3.0.dev20250728-py3-none-any.whl → 1.4.0.dev20251103-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.

Files changed (72)
  1. tensorcircuit/__init__.py +5 -1
  2. tensorcircuit/abstractcircuit.py +4 -0
  3. tensorcircuit/analogcircuit.py +413 -0
  4. tensorcircuit/applications/layers.py +1 -1
  5. tensorcircuit/applications/van.py +1 -1
  6. tensorcircuit/backends/abstract_backend.py +312 -5
  7. tensorcircuit/backends/cupy_backend.py +3 -1
  8. tensorcircuit/backends/jax_backend.py +92 -3
  9. tensorcircuit/backends/jax_ops.py +108 -0
  10. tensorcircuit/backends/numpy_backend.py +49 -3
  11. tensorcircuit/backends/pytorch_backend.py +92 -3
  12. tensorcircuit/backends/tensorflow_backend.py +102 -3
  13. tensorcircuit/basecircuit.py +123 -82
  14. tensorcircuit/circuit.py +67 -57
  15. tensorcircuit/cloud/local.py +1 -1
  16. tensorcircuit/cloud/quafu_provider.py +1 -1
  17. tensorcircuit/cloud/tencent.py +1 -1
  18. tensorcircuit/compiler/simple_compiler.py +2 -2
  19. tensorcircuit/cons.py +1 -0
  20. tensorcircuit/densitymatrix.py +16 -11
  21. tensorcircuit/experimental.py +7 -152
  22. tensorcircuit/fgs.py +5 -6
  23. tensorcircuit/gates.py +66 -22
  24. tensorcircuit/keras.py +3 -3
  25. tensorcircuit/mpscircuit.py +109 -61
  26. tensorcircuit/quantum.py +697 -133
  27. tensorcircuit/quditcircuit.py +733 -0
  28. tensorcircuit/quditgates.py +618 -0
  29. tensorcircuit/results/counts.py +45 -31
  30. tensorcircuit/shadows.py +1 -1
  31. tensorcircuit/simplify.py +3 -1
  32. tensorcircuit/stabilizercircuit.py +4 -2
  33. tensorcircuit/templates/blocks.py +2 -2
  34. tensorcircuit/templates/hamiltonians.py +29 -8
  35. tensorcircuit/templates/lattice.py +676 -335
  36. tensorcircuit/timeevol.py +896 -0
  37. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +50 -25
  38. tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
  39. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
  40. tensorcircuit_nightly-1.3.0.dev20250728.dist-info/RECORD +0 -122
  41. tests/__init__.py +0 -0
  42. tests/conftest.py +0 -67
  43. tests/test_backends.py +0 -1035
  44. tests/test_calibrating.py +0 -149
  45. tests/test_channels.py +0 -409
  46. tests/test_circuit.py +0 -1713
  47. tests/test_cloud.py +0 -219
  48. tests/test_compiler.py +0 -147
  49. tests/test_dmcircuit.py +0 -555
  50. tests/test_ensemble.py +0 -72
  51. tests/test_fgs.py +0 -318
  52. tests/test_gates.py +0 -156
  53. tests/test_hamiltonians.py +0 -159
  54. tests/test_interfaces.py +0 -557
  55. tests/test_keras.py +0 -160
  56. tests/test_lattice.py +0 -1666
  57. tests/test_miscs.py +0 -334
  58. tests/test_mpscircuit.py +0 -341
  59. tests/test_noisemodel.py +0 -156
  60. tests/test_qaoa.py +0 -86
  61. tests/test_qem.py +0 -152
  62. tests/test_quantum.py +0 -549
  63. tests/test_quantum_attr.py +0 -42
  64. tests/test_results.py +0 -379
  65. tests/test_shadows.py +0 -160
  66. tests/test_simplify.py +0 -46
  67. tests/test_stabilizer.py +0 -226
  68. tests/test_templates.py +0 -218
  69. tests/test_torchnn.py +0 -99
  70. tests/test_van.py +0 -102
  71. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +0 -0
  72. {tensorcircuit_nightly-1.3.0.dev20250728.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/licenses/LICENSE +0 -0
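For readers who want to verify the table above rather than trust the registry page, here is a minimal sketch that recomputes a per-file +N -M summary from the two wheels using only the Python standard library. The wheel filenames are the ones from the header and are assumed to be present locally (e.g. fetched via pip download); binary members are decoded leniently, so treat their counts as approximate.

# Sketch: per-file +/- summary between two wheels (wheels are zip archives).
import difflib
import zipfile

OLD = "tensorcircuit_nightly-1.3.0.dev20250728-py3-none-any.whl"
NEW = "tensorcircuit_nightly-1.4.0.dev20251103-py3-none-any.whl"

def read_all(path):
    # Map member name -> decoded text for every file in the wheel.
    with zipfile.ZipFile(path) as zf:
        return {n: zf.read(n).decode("utf-8", "replace") for n in zf.namelist()}

old, new = read_all(OLD), read_all(NEW)
for name in sorted(old.keys() | new.keys()):
    a = old.get(name, "").splitlines()
    b = new.get(name, "").splitlines()
    if a == b:
        continue  # unchanged files are omitted, as in the table above
    plus = minus = 0
    for line in difflib.unified_diff(a, b, lineterm=""):
        if line.startswith("+") and not line.startswith("+++"):
            plus += 1
        elif line.startswith("-") and not line.startswith("---"):
            minus += 1
    print(f"{name} +{plus} -{minus}")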
tests/test_interfaces.py DELETED
@@ -1,557 +0,0 @@
- # pylint: disable=invalid-name
-
- import os
- import sys
- from functools import partial
- import pytest
- from pytest_lazyfixture import lazy_fixture as lf
- from scipy import optimize
- import tensorflow as tf
- import jax
- from jax import numpy as jnp
-
- thisfile = os.path.abspath(__file__)
- modulepath = os.path.dirname(os.path.dirname(thisfile))
-
- sys.path.insert(0, modulepath)
-
- try:
-     import torch
-
-     is_torch = True
- except ImportError:
-     is_torch = False
-
- import numpy as np
- import tensorcircuit as tc
-
-
- @pytest.mark.skipif(is_torch is False, reason="torch not installed")
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
- def test_torch_interface(backend):
-     n = 4
-
-     def f(param):
-         c = tc.Circuit(n)
-         c = tc.templates.blocks.example_block(c, param)
-         loss = c.expectation(
-             [
-                 tc.gates.x(),
-                 [
-                     1,
-                 ],
-             ]
-         )
-         return tc.backend.real(loss)
-
-     f_jit = tc.backend.jit(f)
-
-     f_jit_torch = tc.interfaces.torch_interface(f_jit, enable_dlpack=True)
-
-     param = torch.ones([4, n], requires_grad=True)
-     l = f_jit_torch(param)
-     l = l**2
-     l.backward()
-
-     pg = param.grad
-     np.testing.assert_allclose(pg.shape, [4, n])
-     np.testing.assert_allclose(pg[0, 1], -2.146e-3, atol=1e-5)
-
-     def f2(paramzz, paramx):
-         c = tc.Circuit(n)
-         for i in range(n):
-             c.H(i)
-         for j in range(2):
-             for i in range(n - 1):
-                 c.exp1(i, i + 1, unitary=tc.gates._zz_matrix, theta=paramzz[j, i])
-             for i in range(n):
-                 c.rx(i, theta=paramx[j, i])
-         loss1 = c.expectation(
-             [
-                 tc.gates.x(),
-                 [
-                     1,
-                 ],
-             ]
-         )
-         loss2 = c.expectation(
-             [
-                 tc.gates.x(),
-                 [
-                     2,
-                 ],
-             ]
-         )
-         return tc.backend.real(loss1), tc.backend.real(loss2)
-
-     f2_torch = tc.interfaces.torch_interface(f2, jit=True, enable_dlpack=True)
-
-     paramzz = torch.ones([2, n], requires_grad=True)
-     paramx = torch.ones([2, n], requires_grad=True)
-
-     l1, l2 = f2_torch(paramzz, paramx)
-     l = l1 - l2
-     l.backward()
-
-     pg = paramzz.grad
-     np.testing.assert_allclose(pg.shape, [2, n])
-     np.testing.assert_allclose(pg[0, 0], -0.41609, atol=1e-5)
-
-     def f3(x):
-         return tc.backend.real(x**2)
-
-     f3_torch = tc.interfaces.torch_interface(f3)
-     param3 = torch.ones([2], dtype=torch.complex64, requires_grad=True)
-     l3 = f3_torch(param3)
-     l3 = torch.sum(l3)
-     l3.backward()
-     pg = param3.grad
-     np.testing.assert_allclose(pg, 2 * np.ones([2]).astype(np.complex64), atol=1e-5)
-
-
- @pytest.mark.skipif(is_torch is False, reason="torch not installed")
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
- def test_torch_interface_kws(backend):
-     def f(param, n):
-         c = tc.Circuit(n)
-         c = tc.templates.blocks.example_block(c, param)
-         loss = c.expectation(
-             [
-                 tc.gates.x(),
-                 [
-                     1,
-                 ],
-             ]
-         )
-         return tc.backend.real(loss)
-
-     f_jit_torch = tc.interfaces.torch_interface_kws(f, jit=True, enable_dlpack=True)
-
-     param = torch.ones([4, 4], requires_grad=True)
-     l = f_jit_torch(param, n=4)
-     l = l**2
-     l.backward()
-
-     pg = param.grad
-     np.testing.assert_allclose(pg.shape, [4, 4])
-     np.testing.assert_allclose(pg[0, 1], -2.146e-3, atol=1e-5)
-
-     def f2(paramzz, paramx, n, nlayer):
-         c = tc.Circuit(n)
-         for i in range(n):
-             c.H(i)
-         for j in range(nlayer):  # 2
-             for i in range(n - 1):
-                 c.exp1(i, i + 1, unitary=tc.gates._zz_matrix, theta=paramzz[j, i])
-             for i in range(n):
-                 c.rx(i, theta=paramx[j, i])
-         loss1 = c.expectation(
-             [
-                 tc.gates.x(),
-                 [
-                     1,
-                 ],
-             ]
-         )
-         loss2 = c.expectation(
-             [
-                 tc.gates.x(),
-                 [
-                     2,
-                 ],
-             ]
-         )
-         return tc.backend.real(loss1), tc.backend.real(loss2)
-
-     f2_torch = tc.interfaces.torch_interface_kws(f2, jit=True, enable_dlpack=True)
-
-     paramzz = torch.ones([2, 4], requires_grad=True)
-     paramx = torch.ones([2, 4], requires_grad=True)
-
-     l1, l2 = f2_torch(paramzz, paramx, n=4, nlayer=2)
-     l = l1 - l2
-     l.backward()
-
-     pg = paramzz.grad
-     np.testing.assert_allclose(pg.shape, [2, 4])
-     np.testing.assert_allclose(pg[0, 0], -0.41609, atol=1e-5)
-
-
- @pytest.mark.skipif(is_torch is False, reason="torch not installed")
- @pytest.mark.xfail(
-     (int(tf.__version__.split(".")[1]) < 9)
-     or (int("".join(jax.__version__.split(".")[1:])) < 314),
-     reason="version too low for tf or jax",
- )
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
- def test_torch_interface_dlpack_complex(backend):
-     def f3(x):
-         return tc.backend.real(x**2)
-
-     f3_torch = tc.interfaces.torch_interface(f3, enable_dlpack=True)
-     param3 = torch.ones([2], dtype=torch.complex64, requires_grad=True)
-     l3 = f3_torch(param3)
-     l3 = torch.sum(l3)
-     l3.backward()
-     pg = param3.grad
-     np.testing.assert_allclose(pg, 2 * np.ones([2]).astype(np.complex64), atol=1e-5)
-
-
- @pytest.mark.skipif(is_torch is False, reason="torch not installed")
- @pytest.mark.xfail(reason="see comment link below")
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
- def test_torch_interface_pytree(backend):
-     # pytree cannot support in pytorch autograd function...
-     # https://github.com/pytorch/pytorch/issues/55509
-     def f4(x):
-         return tc.backend.sum(x["a"] ** 2), tc.backend.sum(x["b"] ** 3)
-
-     f4_torch = tc.interfaces.torch_interface(f4, jit=False)
-     param4 = {
-         "a": torch.ones([2], requires_grad=True),
-         "b": torch.ones([2], requires_grad=True),
-     }
-
-     def f4_post(x):
-         r1, r2 = f4_torch(param4)
-         l4 = r1 + r2
-         return l4
-
-     pg = tc.get_backend("pytorch").grad(f4_post)(param4)
-     np.testing.assert_allclose(
-         pg["a"], 2 * np.ones([2]).astype(np.complex64), atol=1e-5
-     )
-
-
- @pytest.mark.parametrize("backend", [lf("jaxb")])
- def test_tf_interface(backend):
-     def f0(params):
-         c = tc.Circuit(1)
-         c.rx(0, theta=params[0])
-         c.ry(0, theta=params[1])
-         return tc.backend.real(c.expectation([tc.gates.z(), [0]]))
-
-     f = tc.interfaces.tf_interface(f0, ydtype=tf.float32, jit=True, enable_dlpack=True)
-
-     tfb = tc.get_backend("tensorflow")
-     grads = tfb.jit(tfb.grad(f))(tfb.ones([2], dtype="float32"))
-     np.testing.assert_allclose(
-         tfb.real(grads), np.array([-0.45464867, -0.45464873]), atol=1e-5
-     )
-
-     f = tc.interfaces.tf_interface(f0, ydtype="float32", jit=False)
-
-     grads = tfb.grad(f)(tf.ones([2]))
-     np.testing.assert_allclose(grads, np.array([-0.45464867, -0.45464873]), atol=1e-5)
-
-
- @pytest.mark.parametrize("backend", [lf("jaxb")])
- def test_tf_interface_2(backend):
-     def f1(a, b):
-         sa, sb = tc.backend.sum(a), tc.backend.sum(b)
-         return sa + sb, sa - sb
-
-     f = tc.interfaces.tf_interface(f1, ydtype=["float32", "float32"], jit=True)
-
-     def f_post(a, b):
-         p, m = f(a, b)
-         return p + m
-
-     tfb = tc.get_backend("tensorflow")
-
-     grads = tfb.jit(tfb.grad(f_post))(
-         tf.ones([2], dtype=tf.float32), tf.ones([2], dtype=tf.float32)
-     )
-
-     np.testing.assert_allclose(grads, 2 * np.ones([2]), atol=1e-5)
-
-
- @pytest.mark.parametrize("backend", [lf("jaxb")])
- def test_tf_interface_3(backend, highp):
-     def f1(a, b):
-         sa, sb = tc.backend.sum(a), tc.backend.sum(b)
-         return sa + sb
-
-     f = tc.interfaces.tf_interface(f1, ydtype="float64", jit=True)
-
-     tfb = tc.get_backend("tensorflow")
-
-     grads = tfb.jit(tfb.grad(f))(
-         tf.ones([2], dtype=tf.float64), tf.ones([2], dtype=tf.float64)
-     )
-     np.testing.assert_allclose(grads, np.ones([2]), atol=1e-5)
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
- def test_scipy_interface(backend):
-     n = 3
-
-     def f(param):
-         c = tc.Circuit(n)
-         for i in range(n):
-             c.rx(i, theta=param[0, i])
-             c.rz(i, theta=param[1, i])
-         loss = c.expectation(
-             [
-                 tc.gates.y(),
-                 [
-                     0,
-                 ],
-             ]
-         )
-         return tc.backend.real(loss)
-
-     if tc.backend.name != "numpy":
-         f_scipy = tc.interfaces.scipy_optimize_interface(f, shape=[2, n])
-         r = optimize.minimize(f_scipy, np.zeros([2 * n]), method="L-BFGS-B", jac=True)
-         # L-BFGS-B may has issue with float32
-         # see: https://github.com/scipy/scipy/issues/5832
-         np.testing.assert_allclose(r["fun"], -1.0, atol=1e-5)
-
-     f_scipy = tc.interfaces.scipy_optimize_interface(f, shape=[2, n], gradient=False)
-     r = optimize.minimize(f_scipy, np.zeros([2 * n]), method="COBYLA")
-     np.testing.assert_allclose(r["fun"], -1.0, atol=1e-5)
-
-
- @pytest.mark.parametrize("backend", [lf("torchb"), lf("tfb"), lf("jaxb")])
- def test_numpy_interface(backend):
-     def f(params, n):
-         c = tc.Circuit(n)
-         for i in range(n):
-             c.rx(i, theta=params[i])
-         for i in range(n - 1):
-             c.cnot(i, i + 1)
-         r = tc.backend.real(c.expectation_ps(z=[n - 1]))
-         return r
-
-     n = 3
-     f_np = tc.interfaces.numpy_interface(f, jit=False)
-     r = f_np(np.ones([n]), n)
-     np.testing.assert_allclose(r, 0.1577285, atol=1e-5)
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
- def test_args_transformation(backend):
-     ans = tc.interfaces.general_args_to_numpy(
-         (
-             tc.backend.ones([2]),
-             {
-                 "a": tc.get_backend("tensorflow").ones([]),
-                 "b": [tc.get_backend("numpy").zeros([2, 1])],
-             },
-         )
-     )
-     print(ans)
-     np.testing.assert_allclose(ans[1]["b"][0], np.zeros([2, 1], dtype=np.complex64))
-     ans1 = tc.interfaces.numpy_args_to_backend(
-         ans, target_backend="jax", dtype="float32"
-     )
-     print(ans1[1]["a"].dtype)
-     ans1 = tc.interfaces.numpy_args_to_backend(
-         ans,
-         target_backend="jax",
-         dtype=("complex64", {"a": "float32", "b": ["complex64"]}),
-     )
-     print(ans1[1]["a"].dtype)
-
-
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
- def test_dlpack_transformation(backend):
-     blist = ["tensorflow", "jax"]
-     if is_torch is True:
-         blist.append("pytorch")
-     for b in blist:
-         ans = tc.interfaces.general_args_to_backend(
-             args=tc.backend.ones([2], dtype="float32"),
-             target_backend=b,
-             enable_dlpack=True,
-         )
-         ans = tc.interfaces.which_backend(ans).device_move(ans, "cpu")
-         np.testing.assert_allclose(ans, np.ones([2]))
-
-
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
- def test_args_to_tensor(backend):
-     @partial(
-         tc.interfaces.args_to_tensor,
-         argnums=[0, 1, 2],
-         gate_to_tensor=True,
-         qop_to_tensor=True,
-     )
-     def f(a, b, c, d):
-         return a, b, c, d
-
-     r = f(np.ones([2]), tc.backend.ones([1, 2]), {"a": [tf.zeros([3])]}, np.ones([2]))
-     a = r[0]
-     b = r[1]
-     c = r[2]["a"][0]
-     d = r[3]
-     assert tc.interfaces.which_backend(a, return_backend=False) == tc.backend.name
-     assert tc.interfaces.which_backend(b, return_backend=False) == tc.backend.name
-     assert tc.interfaces.which_backend(c, return_backend=False) == tc.backend.name
-     assert tc.interfaces.which_backend(d, return_backend=False) == "numpy"
-     # print(f([np.ones([2]), np.ones([1])], {"a": np.ones([3])}))
-     # print(f([tc.Gate(np.ones([2, 2])), tc.Gate(np.ones([2, 2, 2, 2]))], np.ones([2])))
-
-     a, b, c, d = f(
-         [tc.Gate(np.ones([2, 2])), tc.Gate(np.ones([2, 2, 2, 2]))],
-         tc.QuOperator.from_tensor(np.ones([2, 2, 2, 2, 2, 2])),
-         np.ones([2, 2, 2, 2]),
-         tf.zeros([1, 2]),
-     )
-     assert tc.interfaces.which_backend(a[0], return_backend=False) == tc.backend.name
-     assert tc.backend.shape_tuple(a[1]) == (4, 4)
-     assert tc.interfaces.which_backend(b, return_backend=False) == tc.backend.name
-     assert tc.interfaces.which_backend(d, return_backend=False) == "tensorflow"
-     assert tc.backend.shape_tuple(b) == (8, 8)
-     assert tc.backend.shape_tuple(c) == (2, 2, 2, 2)
-
-     @partial(
-         tc.interfaces.args_to_tensor,
-         argnums=[0, 1, 2],
-         tensor_as_matrix=False,
-         gate_to_tensor=True,
-         gate_as_matrix=False,
-         qop_to_tensor=True,
-         qop_as_matrix=False,
-     )
-     def g(a, b, c):
-         return a, b, c
-
-     a, b, c = g(
-         [tc.Gate(np.ones([2, 2])), tc.Gate(np.ones([2, 2, 2, 2]))],
-         tc.QuOperator.from_tensor(np.ones([2, 2, 2, 2, 2, 2])),
-         np.ones([2, 2, 2, 2]),
-     )
-
-     assert tc.interfaces.which_backend(a[0], return_backend=False) == tc.backend.name
-     assert tc.backend.shape_tuple(a[1]) == (2, 2, 2, 2)
-     assert tc.backend.shape_tuple(b.eval()) == (2, 2, 2, 2, 2, 2)
-     assert tc.backend.shape_tuple(c) == (2, 2, 2, 2)
-
-
- def test_jax_interface_basic(tfb):
-     def f(params):
-         c = tc.Circuit(1)
-         c.rx(0, theta=params[0])
-         c.ry(0, theta=params[1])
-         return tc.backend.real(c.expectation_ps(z=[0]))
-
-     f_jax = tc.interfaces.jax_interface(f, jit=True)
-     params = jnp.ones(2)
-
-     # Test forward pass
-     val = f_jax(params)
-     assert isinstance(val, jnp.ndarray)
-     np.testing.assert_allclose(val, 0.291927, atol=1e-5)
-
-     # Test gradient computation
-     val, grad = jax.value_and_grad(f_jax)(params)
-     assert isinstance(grad, jnp.ndarray)
-     assert grad.shape == params.shape
-
-
- def test_jax_interface_multiple_inputs(tfb):
-     def f(params1, params2):
-         c = tc.Circuit(2)
-         c.rx(0, theta=params1[0])
-         c.ry(1, theta=params2[0])
-         return tc.backend.real(c.expectation([tc.gates.z(), [0]]))
-
-     f_jax = tc.interfaces.jax_interface(f, jit=False)
-     p1 = jnp.array([1.0])
-     p2 = jnp.array([2.0])
-
-     # Test forward pass
-     val = f_jax(p1, p2)
-     assert isinstance(val, jnp.ndarray)
-
-     # Test gradient computation
-
-     val, (grad1, grad2) = jax.value_and_grad(f_jax, argnums=(0, 1))(p1, p2)
-     assert isinstance(grad1, jnp.ndarray)
-     assert isinstance(grad2, jnp.ndarray)
-     assert grad1.shape == p1.shape
-     assert grad2.shape == p2.shape
-
-
- @pytest.mark.skip(
-     reason="might fail when testing with other function",
- )
- def test_jax_interface_jit_dlpack(tfb):
-     def f(params):
-         c = tc.Circuit(2)
-         c.rx(range(2), theta=params)
-         return tc.backend.real(c.expectation([tc.gates.z(), [0]]))
-
-     # Test with JIT
-     f_jax = tc.interfaces.jax_interface(f, jit=True, enable_dlpack=True)
-     params = jnp.array([np.pi, np.pi], dtype=jnp.float32)
-
-     # First call compiles
-     val1 = f_jax(params)
-     # Second call should be faster
-     val2, gs = jax.value_and_grad(f_jax)(params)
-
-     assert isinstance(val1, jnp.ndarray)
-     assert isinstance(gs, jnp.ndarray)
-     np.testing.assert_allclose(val1, val2, atol=1e-5)
-
-
- def test_jax_interface_pure_callback(tfb):
-     def f(params):
-         # Use TF operation to test pure_callback
-         return tf.square(params)
-
-     def f_jax1(params):
-         return jnp.sum(tc.interfaces.jax_interface(f)(params))
-
-     def f_jax2(params):
-         return jnp.sum(
-             tc.interfaces.jax_interface(
-                 f, jit=True, output_shape=[2], output_dtype=jnp.float32
-             )(params)
-         )
-
-     params = jnp.array([1.0, 2.0])
-
-     for f_jax in [f_jax1, f_jax2]:
-         val = f_jax(params)
-         assert isinstance(val, jnp.ndarray)
-         np.testing.assert_allclose(val, 5.0, atol=1e-5)
-
-         # Test gradient
-         grad = jax.grad(f_jax)(params)
-         assert isinstance(grad, jnp.ndarray)
-         np.testing.assert_allclose(grad, [2.0, 4.0], atol=1e-5)
-
-
- def test_jax_interface_multiple_outputs(tfb):
-     def f(params):
-         # Use TF operation to test pure_callback
-         return tf.square(params), params
-
-     def f_jax1(params):
-         r = tc.interfaces.jax_interface(f)(params)
-         return jnp.sum(r[0] + r[1] ** 2) / 2
-
-     def f_jax2(params):
-         r = tc.interfaces.jax_interface(
-             f,
-             jit=True,
-             output_shape=([2], [2]),
-             output_dtype=(jnp.float32, jnp.float32),
-         )(params)
-         return jnp.sum(r[0] + r[1] ** 2) / 2
-
-     params = jnp.array([1.0, 2.0])
-
-     for f_jax in [f_jax1, f_jax2]:
-         val = f_jax(params)
-         assert isinstance(val, jnp.ndarray)
-         np.testing.assert_allclose(val, 5.0, atol=1e-5)
-
-         # Test gradient
-         grad = jax.grad(f_jax)(params)
-         assert isinstance(grad, jnp.ndarray)
-         np.testing.assert_allclose(grad, [2.0, 4.0], atol=1e-5)
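The deleted file above was the wheel's coverage of tensorcircuit's interface layer (torch_interface, tf_interface, scipy_optimize_interface, numpy_interface, jax_interface); removing tests/ from the wheel removes only the test suite, not these APIs. As a reminder of the pattern the tests exercised, here is a minimal sketch assuming the interface API shown above; the four-qubit circuit and parameter values are illustrative stand-ins for example_block.

# Wrap a TensorFlow/JAX-backend tensorcircuit function so that
# torch.autograd can differentiate through it (cf. test_torch_interface).
import torch
import tensorcircuit as tc

tc.set_backend("tensorflow")  # the deleted tests ran under the tfb/jaxb fixtures
n = 4

def f(param):
    c = tc.Circuit(n)
    for i in range(n):
        c.rx(i, theta=param[i])
    return tc.backend.real(c.expectation_ps(z=[0]))

f_torch = tc.interfaces.torch_interface(f, jit=True, enable_dlpack=True)

param = torch.ones([n], requires_grad=True)
loss = f_torch(param) ** 2
loss.backward()  # gradients flow back into the torch tensor
print(param.grad.shape)  # torch.Size([4])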
tests/test_keras.py DELETED
@@ -1,160 +0,0 @@
- import os
- import sys
- from functools import partial
-
- thisfile = os.path.abspath(__file__)
- modulepath = os.path.dirname(os.path.dirname(thisfile))
-
- sys.path.insert(0, modulepath)
-
- import numpy as np
- import tensorflow as tf
- import tensorcircuit as tc
-
-
- dtype = np.complex128
- tfdtype = tf.complex128
-
- ii = np.eye(4, dtype=dtype)
- iir = tf.constant(ii.reshape([2, 2, 2, 2]))
- zz = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], dtype=dtype)
- zzr = tf.constant(zz.reshape([2, 2, 2, 2]))
-
-
- def tfi_energy(c, j=1.0, h=-1.0):
-     e = 0.0
-     n = c._nqubits
-     for i in range(n):
-         e += h * c.expectation((tc.gates.x(), [i]))
-     for i in range(n - 1):  # OBC
-         e += j * c.expectation((tc.gates.z(), [i]), (tc.gates.z(), [(i + 1) % n]))
-     return e
-
-
- def vqe_f2(inputs, xweights, zzweights, nlayers, n):
-     c = tc.Circuit(n)
-     paramx = tf.cast(xweights, tfdtype)
-     paramzz = tf.cast(zzweights, tfdtype)
-     for i in range(n):
-         c.H(i)
-     for j in range(nlayers):
-         for i in range(n - 1):
-             c.any(
-                 i,
-                 i + 1,
-                 unitary=tf.math.cos(paramzz[j, i]) * iir
-                 + tf.math.sin(paramzz[j, i]) * 1.0j * zzr,
-             )
-         for i in range(n):
-             c.rx(i, theta=paramx[j, i])
-     e = tfi_energy(c)
-     e = tf.math.real(e)
-     return e
-
-
- def test_vqe_layer2(tfb, highp):
-     vqe_fp = partial(vqe_f2, nlayers=3, n=6)
-     vqe_layer = tc.KerasLayer(vqe_fp, [(3, 6), (3, 6)])
-     inputs = np.zeros([1])
-     with tf.GradientTape() as tape:
-         e = vqe_layer(inputs)
-     print(e, tape.gradient(e, vqe_layer.variables))
-     model = tf.keras.Sequential([vqe_layer])
-     model.compile(
-         loss=tc.keras.output_asis_loss, optimizer=tf.keras.optimizers.Adam(0.01)
-     )
-     model.fit(np.zeros([1, 1]), np.zeros([1]), batch_size=1, epochs=300)
-
-
- def vqe_f(inputs, weights, nlayers, n):
-     c = tc.Circuit(n)
-     paramc = tf.cast(weights, tfdtype)
-     for i in range(n):
-         c.H(i)
-     for j in range(nlayers):
-         for i in range(n - 1):
-             c.any(
-                 i,
-                 i + 1,
-                 unitary=tf.math.cos(paramc[2 * j, i]) * iir
-                 + tf.math.sin(paramc[2 * j, i]) * 1.0j * zzr,
-             )
-         for i in range(n):
-             c.rx(i, theta=paramc[2 * j + 1, i])
-     e = tfi_energy(c)
-     e = tf.math.real(e)
-     return e
-
-
- def test_vqe_layer(tfb, highp):
-     vqe_fp = partial(vqe_f, nlayers=6, n=6)
-     vqe_layer = tc.keras.QuantumLayer(vqe_fp, (6 * 2, 6))
-     inputs = np.zeros([1])
-     inputs = tf.constant(inputs)
-     model = tf.keras.Sequential([vqe_layer])
-
-     model.compile(
-         loss=tc.keras.output_asis_loss, optimizer=tf.keras.optimizers.Adam(0.01)
-     )
-
-     model.fit(np.zeros([2, 1]), np.zeros([2, 1]), batch_size=2, epochs=500)
-
-     np.testing.assert_allclose(model.predict(np.zeros([1])), -7.27, atol=5e-2)
-
-
- def test_function_io(tfb, tmp_path, highp):
-     vqe_f_p = partial(vqe_f, inputs=tf.ones([1]))
-
-     vqe_f_p = tf.function(vqe_f_p)
-     vqe_f_p(weights=tf.ones([6, 6], dtype=tf.float64), nlayers=3, n=6)
-     tc.keras.save_func(vqe_f_p, str(tmp_path))
-     loaded = tc.keras.load_func(str(tmp_path), fallback=vqe_f_p)
-     print(loaded(weights=tf.ones([6, 6], dtype=tf.float64), nlayers=3, n=6))
-     print(loaded(weights=tf.ones([6, 6], dtype=tf.float64), nlayers=3, n=6))
-
-
- def test_keras_hardware(tfb):
-     n = 2
-
-     def qf(inputs, param):
-         c = tc.Circuit(n)
-         c.rx(0, theta=inputs[0])
-         c.rx(1, theta=inputs[1])
-         c.h(1)
-         c.rzz(0, 1, theta=param[0])
-         return tc.backend.stack([c.expectation_ps(z=[i]) for i in range(n)])
-
-     ql = tc.keras.HardwareLayer(qf, [1], regularizer=tf.keras.regularizers.l2(1e-3))
-     print(ql(tf.ones([1, 2])))
-
-
- def test_keras_layer_inputs_dict(tfb):
-     # https://github.com/tensorflow/tensorflow/issues/65306
-     # keras3 for tf2.16+ fails to accept complex valued input for keras layers
-     # which is vital for quantum applications
-     n = 3
-     p = 0.1
-     K = tc.backend
-
-     def f(inputs, weights):
-         state = inputs["state"]
-         noise = inputs["noise"]
-         c = tc.Circuit(n, inputs=state)
-         for i in range(n):
-             c.rz(i, theta=weights[i])
-         for i in range(n):
-             c.depolarizing(i, px=p, py=p, pz=p, status=noise[i])
-         return K.real(c.expectation_ps(x=[0]))
-
-     layer = tc.KerasLayer(f, [n])
-     v = {"state": K.ones([1, 2**n]) / 2 ** (n / 2), "noise": 0.2 * K.ones([1, n])}
-     with tf.GradientTape() as tape:
-         l = layer(v)
-     g1 = tape.gradient(l, layer.trainable_variables)
-
-     v = {"state": K.ones([2**n]) / 2 ** (n / 2), "noise": 0.2 * K.ones([n])}
-     with tf.GradientTape() as tape:
-         l = layer(v)
-     g2 = tape.gradient(l, layer.trainable_variables)
-
-     np.testing.assert_allclose(g1[0], g2[0], atol=1e-5)
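Likewise, the deleted Keras tests covered wrapping a circuit function as a trainable Keras layer. Here is a minimal sketch of that pattern, assuming the tc.KerasLayer API exercised above; the three-qubit rz circuit and the uniform input state are illustrative, and the layer is driven with a single tensor input rather than the dict input of test_keras_layer_inputs_dict.

# A circuit function becomes a Keras layer; the second argument to
# tc.KerasLayer declares the trainable weight shape ([n] -> one vector).
import tensorflow as tf
import tensorcircuit as tc

tc.set_backend("tensorflow")
K = tc.backend
n = 3

def f(inputs, weights):
    # inputs: state vector fed to the circuit; weights: trainable layer weights
    c = tc.Circuit(n, inputs=inputs)
    for i in range(n):
        c.rz(i, theta=weights[i])
    return K.real(c.expectation_ps(x=[0]))

layer = tc.KerasLayer(f, [n])
state = K.ones([2**n]) / 2 ** (n / 2)  # normalized uniform superposition
with tf.GradientTape() as tape:
    loss = layer(state)
grads = tape.gradient(loss, layer.trainable_variables)
print([g.shape for g in grads])  # one gradient of shape (3,)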