tensorcircuit-nightly 1.0.2.dev20250108__py3-none-any.whl → 1.4.0.dev20251103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tensorcircuit-nightly might be problematic. (The original page linked to further details at this point; the link was not preserved.)

Files changed (76)
  1. tensorcircuit/__init__.py +18 -2
  2. tensorcircuit/about.py +46 -0
  3. tensorcircuit/abstractcircuit.py +4 -0
  4. tensorcircuit/analogcircuit.py +413 -0
  5. tensorcircuit/applications/layers.py +1 -1
  6. tensorcircuit/applications/van.py +1 -1
  7. tensorcircuit/backends/abstract_backend.py +320 -7
  8. tensorcircuit/backends/cupy_backend.py +3 -1
  9. tensorcircuit/backends/jax_backend.py +102 -4
  10. tensorcircuit/backends/jax_ops.py +110 -1
  11. tensorcircuit/backends/numpy_backend.py +49 -3
  12. tensorcircuit/backends/pytorch_backend.py +92 -3
  13. tensorcircuit/backends/tensorflow_backend.py +102 -3
  14. tensorcircuit/basecircuit.py +157 -98
  15. tensorcircuit/circuit.py +115 -57
  16. tensorcircuit/cloud/local.py +1 -1
  17. tensorcircuit/cloud/quafu_provider.py +1 -1
  18. tensorcircuit/cloud/tencent.py +1 -1
  19. tensorcircuit/compiler/simple_compiler.py +2 -2
  20. tensorcircuit/cons.py +142 -21
  21. tensorcircuit/densitymatrix.py +43 -14
  22. tensorcircuit/experimental.py +387 -129
  23. tensorcircuit/fgs.py +282 -81
  24. tensorcircuit/gates.py +66 -22
  25. tensorcircuit/interfaces/__init__.py +1 -3
  26. tensorcircuit/interfaces/jax.py +189 -0
  27. tensorcircuit/keras.py +3 -3
  28. tensorcircuit/mpscircuit.py +154 -65
  29. tensorcircuit/quantum.py +868 -152
  30. tensorcircuit/quditcircuit.py +733 -0
  31. tensorcircuit/quditgates.py +618 -0
  32. tensorcircuit/results/counts.py +147 -20
  33. tensorcircuit/results/readout_mitigation.py +4 -1
  34. tensorcircuit/shadows.py +1 -1
  35. tensorcircuit/simplify.py +3 -1
  36. tensorcircuit/stabilizercircuit.py +479 -0
  37. tensorcircuit/templates/__init__.py +2 -0
  38. tensorcircuit/templates/blocks.py +2 -2
  39. tensorcircuit/templates/hamiltonians.py +174 -0
  40. tensorcircuit/templates/lattice.py +1789 -0
  41. tensorcircuit/timeevol.py +896 -0
  42. tensorcircuit/translation.py +10 -3
  43. tensorcircuit/utils.py +7 -0
  44. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +73 -23
  45. tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
  46. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +1 -1
  47. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
  48. tensorcircuit_nightly-1.0.2.dev20250108.dist-info/RECORD +0 -115
  49. tests/__init__.py +0 -0
  50. tests/conftest.py +0 -67
  51. tests/test_backends.py +0 -1031
  52. tests/test_calibrating.py +0 -149
  53. tests/test_channels.py +0 -365
  54. tests/test_circuit.py +0 -1699
  55. tests/test_cloud.py +0 -219
  56. tests/test_compiler.py +0 -147
  57. tests/test_dmcircuit.py +0 -555
  58. tests/test_ensemble.py +0 -72
  59. tests/test_fgs.py +0 -310
  60. tests/test_gates.py +0 -156
  61. tests/test_interfaces.py +0 -429
  62. tests/test_keras.py +0 -160
  63. tests/test_miscs.py +0 -277
  64. tests/test_mpscircuit.py +0 -341
  65. tests/test_noisemodel.py +0 -156
  66. tests/test_qaoa.py +0 -86
  67. tests/test_qem.py +0 -152
  68. tests/test_quantum.py +0 -526
  69. tests/test_quantum_attr.py +0 -42
  70. tests/test_results.py +0 -347
  71. tests/test_shadows.py +0 -160
  72. tests/test_simplify.py +0 -46
  73. tests/test_templates.py +0 -218
  74. tests/test_torchnn.py +0 -99
  75. tests/test_van.py +0 -102
  76. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info/licenses}/LICENSE +0 -0
tests/test_interfaces.py DELETED
@@ -1,429 +0,0 @@
1
- # pylint: disable=invalid-name
2
-
3
- import os
4
- import sys
5
- from functools import partial
6
- import pytest
7
- from pytest_lazyfixture import lazy_fixture as lf
8
- from scipy import optimize
9
- import tensorflow as tf
10
- import jax
11
-
12
- thisfile = os.path.abspath(__file__)
13
- modulepath = os.path.dirname(os.path.dirname(thisfile))
14
-
15
- sys.path.insert(0, modulepath)
16
-
17
- try:
18
- import torch
19
-
20
- is_torch = True
21
- except ImportError:
22
- is_torch = False
23
-
24
- import numpy as np
25
- import tensorcircuit as tc
26
-
27
-
28
- @pytest.mark.skipif(is_torch is False, reason="torch not installed")
29
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
30
- def test_torch_interface(backend):
31
- n = 4
32
-
33
- def f(param):
34
- c = tc.Circuit(n)
35
- c = tc.templates.blocks.example_block(c, param)
36
- loss = c.expectation(
37
- [
38
- tc.gates.x(),
39
- [
40
- 1,
41
- ],
42
- ]
43
- )
44
- return tc.backend.real(loss)
45
-
46
- f_jit = tc.backend.jit(f)
47
-
48
- f_jit_torch = tc.interfaces.torch_interface(f_jit, enable_dlpack=True)
49
-
50
- param = torch.ones([4, n], requires_grad=True)
51
- l = f_jit_torch(param)
52
- l = l**2
53
- l.backward()
54
-
55
- pg = param.grad
56
- np.testing.assert_allclose(pg.shape, [4, n])
57
- np.testing.assert_allclose(pg[0, 1], -2.146e-3, atol=1e-5)
58
-
59
- def f2(paramzz, paramx):
60
- c = tc.Circuit(n)
61
- for i in range(n):
62
- c.H(i)
63
- for j in range(2):
64
- for i in range(n - 1):
65
- c.exp1(i, i + 1, unitary=tc.gates._zz_matrix, theta=paramzz[j, i])
66
- for i in range(n):
67
- c.rx(i, theta=paramx[j, i])
68
- loss1 = c.expectation(
69
- [
70
- tc.gates.x(),
71
- [
72
- 1,
73
- ],
74
- ]
75
- )
76
- loss2 = c.expectation(
77
- [
78
- tc.gates.x(),
79
- [
80
- 2,
81
- ],
82
- ]
83
- )
84
- return tc.backend.real(loss1), tc.backend.real(loss2)
85
-
86
- f2_torch = tc.interfaces.torch_interface(f2, jit=True, enable_dlpack=True)
87
-
88
- paramzz = torch.ones([2, n], requires_grad=True)
89
- paramx = torch.ones([2, n], requires_grad=True)
90
-
91
- l1, l2 = f2_torch(paramzz, paramx)
92
- l = l1 - l2
93
- l.backward()
94
-
95
- pg = paramzz.grad
96
- np.testing.assert_allclose(pg.shape, [2, n])
97
- np.testing.assert_allclose(pg[0, 0], -0.41609, atol=1e-5)
98
-
99
- def f3(x):
100
- return tc.backend.real(x**2)
101
-
102
- f3_torch = tc.interfaces.torch_interface(f3)
103
- param3 = torch.ones([2], dtype=torch.complex64, requires_grad=True)
104
- l3 = f3_torch(param3)
105
- l3 = torch.sum(l3)
106
- l3.backward()
107
- pg = param3.grad
108
- np.testing.assert_allclose(pg, 2 * np.ones([2]).astype(np.complex64), atol=1e-5)
109
-
110
-
111
- @pytest.mark.skipif(is_torch is False, reason="torch not installed")
112
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
113
- def test_torch_interface_kws(backend):
114
- def f(param, n):
115
- c = tc.Circuit(n)
116
- c = tc.templates.blocks.example_block(c, param)
117
- loss = c.expectation(
118
- [
119
- tc.gates.x(),
120
- [
121
- 1,
122
- ],
123
- ]
124
- )
125
- return tc.backend.real(loss)
126
-
127
- f_jit_torch = tc.interfaces.torch_interface_kws(f, jit=True, enable_dlpack=True)
128
-
129
- param = torch.ones([4, 4], requires_grad=True)
130
- l = f_jit_torch(param, n=4)
131
- l = l**2
132
- l.backward()
133
-
134
- pg = param.grad
135
- np.testing.assert_allclose(pg.shape, [4, 4])
136
- np.testing.assert_allclose(pg[0, 1], -2.146e-3, atol=1e-5)
137
-
138
- def f2(paramzz, paramx, n, nlayer):
139
- c = tc.Circuit(n)
140
- for i in range(n):
141
- c.H(i)
142
- for j in range(nlayer): # 2
143
- for i in range(n - 1):
144
- c.exp1(i, i + 1, unitary=tc.gates._zz_matrix, theta=paramzz[j, i])
145
- for i in range(n):
146
- c.rx(i, theta=paramx[j, i])
147
- loss1 = c.expectation(
148
- [
149
- tc.gates.x(),
150
- [
151
- 1,
152
- ],
153
- ]
154
- )
155
- loss2 = c.expectation(
156
- [
157
- tc.gates.x(),
158
- [
159
- 2,
160
- ],
161
- ]
162
- )
163
- return tc.backend.real(loss1), tc.backend.real(loss2)
164
-
165
- f2_torch = tc.interfaces.torch_interface_kws(f2, jit=True, enable_dlpack=True)
166
-
167
- paramzz = torch.ones([2, 4], requires_grad=True)
168
- paramx = torch.ones([2, 4], requires_grad=True)
169
-
170
- l1, l2 = f2_torch(paramzz, paramx, n=4, nlayer=2)
171
- l = l1 - l2
172
- l.backward()
173
-
174
- pg = paramzz.grad
175
- np.testing.assert_allclose(pg.shape, [2, 4])
176
- np.testing.assert_allclose(pg[0, 0], -0.41609, atol=1e-5)
177
-
178
-
179
- @pytest.mark.skipif(is_torch is False, reason="torch not installed")
180
- @pytest.mark.xfail(
181
- (int(tf.__version__.split(".")[1]) < 9)
182
- or (int("".join(jax.__version__.split(".")[1:])) < 314),
183
- reason="version too low for tf or jax",
184
- )
185
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
186
- def test_torch_interface_dlpack_complex(backend):
187
- def f3(x):
188
- return tc.backend.real(x**2)
189
-
190
- f3_torch = tc.interfaces.torch_interface(f3, enable_dlpack=True)
191
- param3 = torch.ones([2], dtype=torch.complex64, requires_grad=True)
192
- l3 = f3_torch(param3)
193
- l3 = torch.sum(l3)
194
- l3.backward()
195
- pg = param3.grad
196
- np.testing.assert_allclose(pg, 2 * np.ones([2]).astype(np.complex64), atol=1e-5)
197
-
198
-
199
- @pytest.mark.skipif(is_torch is False, reason="torch not installed")
200
- @pytest.mark.xfail(reason="see comment link below")
201
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
202
- def test_torch_interface_pytree(backend):
203
- # pytree cannot support in pytorch autograd function...
204
- # https://github.com/pytorch/pytorch/issues/55509
205
- def f4(x):
206
- return tc.backend.sum(x["a"] ** 2), tc.backend.sum(x["b"] ** 3)
207
-
208
- f4_torch = tc.interfaces.torch_interface(f4, jit=False)
209
- param4 = {
210
- "a": torch.ones([2], requires_grad=True),
211
- "b": torch.ones([2], requires_grad=True),
212
- }
213
-
214
- def f4_post(x):
215
- r1, r2 = f4_torch(param4)
216
- l4 = r1 + r2
217
- return l4
218
-
219
- pg = tc.get_backend("pytorch").grad(f4_post)(param4)
220
- np.testing.assert_allclose(
221
- pg["a"], 2 * np.ones([2]).astype(np.complex64), atol=1e-5
222
- )
223
-
224
-
225
- @pytest.mark.parametrize("backend", [lf("jaxb")])
226
- def test_tf_interface(backend):
227
- def f0(params):
228
- c = tc.Circuit(1)
229
- c.rx(0, theta=params[0])
230
- c.ry(0, theta=params[1])
231
- return tc.backend.real(c.expectation([tc.gates.z(), [0]]))
232
-
233
- f = tc.interfaces.tf_interface(f0, ydtype=tf.float32, jit=True, enable_dlpack=True)
234
-
235
- tfb = tc.get_backend("tensorflow")
236
- grads = tfb.jit(tfb.grad(f))(tfb.ones([2], dtype="float32"))
237
- np.testing.assert_allclose(
238
- tfb.real(grads), np.array([-0.45464867, -0.45464873]), atol=1e-5
239
- )
240
-
241
- f = tc.interfaces.tf_interface(f0, ydtype="float32", jit=False)
242
-
243
- grads = tfb.grad(f)(tf.ones([2]))
244
- np.testing.assert_allclose(grads, np.array([-0.45464867, -0.45464873]), atol=1e-5)
245
-
246
-
247
- @pytest.mark.parametrize("backend", [lf("jaxb")])
248
- def test_tf_interface_2(backend):
249
- def f1(a, b):
250
- sa, sb = tc.backend.sum(a), tc.backend.sum(b)
251
- return sa + sb, sa - sb
252
-
253
- f = tc.interfaces.tf_interface(f1, ydtype=["float32", "float32"], jit=True)
254
-
255
- def f_post(a, b):
256
- p, m = f(a, b)
257
- return p + m
258
-
259
- tfb = tc.get_backend("tensorflow")
260
-
261
- grads = tfb.jit(tfb.grad(f_post))(
262
- tf.ones([2], dtype=tf.float32), tf.ones([2], dtype=tf.float32)
263
- )
264
-
265
- np.testing.assert_allclose(grads, 2 * np.ones([2]), atol=1e-5)
266
-
267
-
268
- @pytest.mark.parametrize("backend", [lf("jaxb")])
269
- def test_tf_interface_3(backend, highp):
270
- def f1(a, b):
271
- sa, sb = tc.backend.sum(a), tc.backend.sum(b)
272
- return sa + sb
273
-
274
- f = tc.interfaces.tf_interface(f1, ydtype="float64", jit=True)
275
-
276
- tfb = tc.get_backend("tensorflow")
277
-
278
- grads = tfb.jit(tfb.grad(f))(
279
- tf.ones([2], dtype=tf.float64), tf.ones([2], dtype=tf.float64)
280
- )
281
- np.testing.assert_allclose(grads, np.ones([2]), atol=1e-5)
282
-
283
-
284
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
285
- def test_scipy_interface(backend):
286
- n = 3
287
-
288
- def f(param):
289
- c = tc.Circuit(n)
290
- for i in range(n):
291
- c.rx(i, theta=param[0, i])
292
- c.rz(i, theta=param[1, i])
293
- loss = c.expectation(
294
- [
295
- tc.gates.y(),
296
- [
297
- 0,
298
- ],
299
- ]
300
- )
301
- return tc.backend.real(loss)
302
-
303
- if tc.backend.name != "numpy":
304
- f_scipy = tc.interfaces.scipy_optimize_interface(f, shape=[2, n])
305
- r = optimize.minimize(f_scipy, np.zeros([2 * n]), method="L-BFGS-B", jac=True)
306
- # L-BFGS-B may has issue with float32
307
- # see: https://github.com/scipy/scipy/issues/5832
308
- np.testing.assert_allclose(r["fun"], -1.0, atol=1e-5)
309
-
310
- f_scipy = tc.interfaces.scipy_optimize_interface(f, shape=[2, n], gradient=False)
311
- r = optimize.minimize(f_scipy, np.zeros([2 * n]), method="COBYLA")
312
- np.testing.assert_allclose(r["fun"], -1.0, atol=1e-5)
313
-
314
-
315
- @pytest.mark.parametrize("backend", [lf("torchb"), lf("tfb"), lf("jaxb")])
316
- def test_numpy_interface(backend):
317
- def f(params, n):
318
- c = tc.Circuit(n)
319
- for i in range(n):
320
- c.rx(i, theta=params[i])
321
- for i in range(n - 1):
322
- c.cnot(i, i + 1)
323
- r = tc.backend.real(c.expectation_ps(z=[n - 1]))
324
- return r
325
-
326
- n = 3
327
- f_np = tc.interfaces.numpy_interface(f, jit=False)
328
- r = f_np(np.ones([n]), n)
329
- np.testing.assert_allclose(r, 0.1577285, atol=1e-5)
330
-
331
-
332
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
333
- def test_args_transformation(backend):
334
- ans = tc.interfaces.general_args_to_numpy(
335
- (
336
- tc.backend.ones([2]),
337
- {
338
- "a": tc.get_backend("tensorflow").ones([]),
339
- "b": [tc.get_backend("numpy").zeros([2, 1])],
340
- },
341
- )
342
- )
343
- print(ans)
344
- np.testing.assert_allclose(ans[1]["b"][0], np.zeros([2, 1], dtype=np.complex64))
345
- ans1 = tc.interfaces.numpy_args_to_backend(
346
- ans, target_backend="jax", dtype="float32"
347
- )
348
- print(ans1[1]["a"].dtype)
349
- ans1 = tc.interfaces.numpy_args_to_backend(
350
- ans,
351
- target_backend="jax",
352
- dtype=("complex64", {"a": "float32", "b": ["complex64"]}),
353
- )
354
- print(ans1[1]["a"].dtype)
355
-
356
-
357
- @pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb"), lf("torchb")])
358
- def test_dlpack_transformation(backend):
359
- blist = ["tensorflow", "jax"]
360
- if is_torch is True:
361
- blist.append("pytorch")
362
- for b in blist:
363
- ans = tc.interfaces.general_args_to_backend(
364
- args=tc.backend.ones([2], dtype="float32"),
365
- target_backend=b,
366
- enable_dlpack=True,
367
- )
368
- ans = tc.interfaces.which_backend(ans).device_move(ans, "cpu")
369
- np.testing.assert_allclose(ans, np.ones([2]))
370
-
371
-
372
- @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
373
- def test_args_to_tensor(backend):
374
- @partial(
375
- tc.interfaces.args_to_tensor,
376
- argnums=[0, 1, 2],
377
- gate_to_tensor=True,
378
- qop_to_tensor=True,
379
- )
380
- def f(a, b, c, d):
381
- return a, b, c, d
382
-
383
- r = f(np.ones([2]), tc.backend.ones([1, 2]), {"a": [tf.zeros([3])]}, np.ones([2]))
384
- a = r[0]
385
- b = r[1]
386
- c = r[2]["a"][0]
387
- d = r[3]
388
- assert tc.interfaces.which_backend(a, return_backend=False) == tc.backend.name
389
- assert tc.interfaces.which_backend(b, return_backend=False) == tc.backend.name
390
- assert tc.interfaces.which_backend(c, return_backend=False) == tc.backend.name
391
- assert tc.interfaces.which_backend(d, return_backend=False) == "numpy"
392
- # print(f([np.ones([2]), np.ones([1])], {"a": np.ones([3])}))
393
- # print(f([tc.Gate(np.ones([2, 2])), tc.Gate(np.ones([2, 2, 2, 2]))], np.ones([2])))
394
-
395
- a, b, c, d = f(
396
- [tc.Gate(np.ones([2, 2])), tc.Gate(np.ones([2, 2, 2, 2]))],
397
- tc.QuOperator.from_tensor(np.ones([2, 2, 2, 2, 2, 2])),
398
- np.ones([2, 2, 2, 2]),
399
- tf.zeros([1, 2]),
400
- )
401
- assert tc.interfaces.which_backend(a[0], return_backend=False) == tc.backend.name
402
- assert tc.backend.shape_tuple(a[1]) == (4, 4)
403
- assert tc.interfaces.which_backend(b, return_backend=False) == tc.backend.name
404
- assert tc.interfaces.which_backend(d, return_backend=False) == "tensorflow"
405
- assert tc.backend.shape_tuple(b) == (8, 8)
406
- assert tc.backend.shape_tuple(c) == (2, 2, 2, 2)
407
-
408
- @partial(
409
- tc.interfaces.args_to_tensor,
410
- argnums=[0, 1, 2],
411
- tensor_as_matrix=False,
412
- gate_to_tensor=True,
413
- gate_as_matrix=False,
414
- qop_to_tensor=True,
415
- qop_as_matrix=False,
416
- )
417
- def g(a, b, c):
418
- return a, b, c
419
-
420
- a, b, c = g(
421
- [tc.Gate(np.ones([2, 2])), tc.Gate(np.ones([2, 2, 2, 2]))],
422
- tc.QuOperator.from_tensor(np.ones([2, 2, 2, 2, 2, 2])),
423
- np.ones([2, 2, 2, 2]),
424
- )
425
-
426
- assert tc.interfaces.which_backend(a[0], return_backend=False) == tc.backend.name
427
- assert tc.backend.shape_tuple(a[1]) == (2, 2, 2, 2)
428
- assert tc.backend.shape_tuple(b.eval()) == (2, 2, 2, 2, 2, 2)
429
- assert tc.backend.shape_tuple(c) == (2, 2, 2, 2)
tests/test_keras.py DELETED
@@ -1,160 +0,0 @@
1
- import os
2
- import sys
3
- from functools import partial
4
-
5
- thisfile = os.path.abspath(__file__)
6
- modulepath = os.path.dirname(os.path.dirname(thisfile))
7
-
8
- sys.path.insert(0, modulepath)
9
-
10
- import numpy as np
11
- import tensorflow as tf
12
- import tensorcircuit as tc
13
-
14
-
15
- dtype = np.complex128
16
- tfdtype = tf.complex128
17
-
18
- ii = np.eye(4, dtype=dtype)
19
- iir = tf.constant(ii.reshape([2, 2, 2, 2]))
20
- zz = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], dtype=dtype)
21
- zzr = tf.constant(zz.reshape([2, 2, 2, 2]))
22
-
23
-
24
- def tfi_energy(c, j=1.0, h=-1.0):
25
- e = 0.0
26
- n = c._nqubits
27
- for i in range(n):
28
- e += h * c.expectation((tc.gates.x(), [i]))
29
- for i in range(n - 1): # OBC
30
- e += j * c.expectation((tc.gates.z(), [i]), (tc.gates.z(), [(i + 1) % n]))
31
- return e
32
-
33
-
34
- def vqe_f2(inputs, xweights, zzweights, nlayers, n):
35
- c = tc.Circuit(n)
36
- paramx = tf.cast(xweights, tfdtype)
37
- paramzz = tf.cast(zzweights, tfdtype)
38
- for i in range(n):
39
- c.H(i)
40
- for j in range(nlayers):
41
- for i in range(n - 1):
42
- c.any(
43
- i,
44
- i + 1,
45
- unitary=tf.math.cos(paramzz[j, i]) * iir
46
- + tf.math.sin(paramzz[j, i]) * 1.0j * zzr,
47
- )
48
- for i in range(n):
49
- c.rx(i, theta=paramx[j, i])
50
- e = tfi_energy(c)
51
- e = tf.math.real(e)
52
- return e
53
-
54
-
55
- def test_vqe_layer2(tfb, highp):
56
- vqe_fp = partial(vqe_f2, nlayers=3, n=6)
57
- vqe_layer = tc.KerasLayer(vqe_fp, [(3, 6), (3, 6)])
58
- inputs = np.zeros([1])
59
- with tf.GradientTape() as tape:
60
- e = vqe_layer(inputs)
61
- print(e, tape.gradient(e, vqe_layer.variables))
62
- model = tf.keras.Sequential([vqe_layer])
63
- model.compile(
64
- loss=tc.keras.output_asis_loss, optimizer=tf.keras.optimizers.Adam(0.01)
65
- )
66
- model.fit(np.zeros([1, 1]), np.zeros([1]), batch_size=1, epochs=300)
67
-
68
-
69
- def vqe_f(inputs, weights, nlayers, n):
70
- c = tc.Circuit(n)
71
- paramc = tf.cast(weights, tfdtype)
72
- for i in range(n):
73
- c.H(i)
74
- for j in range(nlayers):
75
- for i in range(n - 1):
76
- c.any(
77
- i,
78
- i + 1,
79
- unitary=tf.math.cos(paramc[2 * j, i]) * iir
80
- + tf.math.sin(paramc[2 * j, i]) * 1.0j * zzr,
81
- )
82
- for i in range(n):
83
- c.rx(i, theta=paramc[2 * j + 1, i])
84
- e = tfi_energy(c)
85
- e = tf.math.real(e)
86
- return e
87
-
88
-
89
- def test_vqe_layer(tfb, highp):
90
- vqe_fp = partial(vqe_f, nlayers=6, n=6)
91
- vqe_layer = tc.keras.QuantumLayer(vqe_fp, (6 * 2, 6))
92
- inputs = np.zeros([1])
93
- inputs = tf.constant(inputs)
94
- model = tf.keras.Sequential([vqe_layer])
95
-
96
- model.compile(
97
- loss=tc.keras.output_asis_loss, optimizer=tf.keras.optimizers.Adam(0.01)
98
- )
99
-
100
- model.fit(np.zeros([2, 1]), np.zeros([2, 1]), batch_size=2, epochs=500)
101
-
102
- np.testing.assert_allclose(model.predict(np.zeros([1])), -7.27, atol=5e-2)
103
-
104
-
105
- def test_function_io(tfb, tmp_path, highp):
106
- vqe_f_p = partial(vqe_f, inputs=tf.ones([1]))
107
-
108
- vqe_f_p = tf.function(vqe_f_p)
109
- vqe_f_p(weights=tf.ones([6, 6], dtype=tf.float64), nlayers=3, n=6)
110
- tc.keras.save_func(vqe_f_p, str(tmp_path))
111
- loaded = tc.keras.load_func(str(tmp_path), fallback=vqe_f_p)
112
- print(loaded(weights=tf.ones([6, 6], dtype=tf.float64), nlayers=3, n=6))
113
- print(loaded(weights=tf.ones([6, 6], dtype=tf.float64), nlayers=3, n=6))
114
-
115
-
116
- def test_keras_hardware(tfb):
117
- n = 2
118
-
119
- def qf(inputs, param):
120
- c = tc.Circuit(n)
121
- c.rx(0, theta=inputs[0])
122
- c.rx(1, theta=inputs[1])
123
- c.h(1)
124
- c.rzz(0, 1, theta=param[0])
125
- return tc.backend.stack([c.expectation_ps(z=[i]) for i in range(n)])
126
-
127
- ql = tc.keras.HardwareLayer(qf, [1], regularizer=tf.keras.regularizers.l2(1e-3))
128
- print(ql(tf.ones([1, 2])))
129
-
130
-
131
- def test_keras_layer_inputs_dict(tfb):
132
- # https://github.com/tensorflow/tensorflow/issues/65306
133
- # keras3 for tf2.16+ fails to accept complex valued input for keras layers
134
- # which is vital for quantum applications
135
- n = 3
136
- p = 0.1
137
- K = tc.backend
138
-
139
- def f(inputs, weights):
140
- state = inputs["state"]
141
- noise = inputs["noise"]
142
- c = tc.Circuit(n, inputs=state)
143
- for i in range(n):
144
- c.rz(i, theta=weights[i])
145
- for i in range(n):
146
- c.depolarizing(i, px=p, py=p, pz=p, status=noise[i])
147
- return K.real(c.expectation_ps(x=[0]))
148
-
149
- layer = tc.KerasLayer(f, [n])
150
- v = {"state": K.ones([1, 2**n]) / 2 ** (n / 2), "noise": 0.2 * K.ones([1, n])}
151
- with tf.GradientTape() as tape:
152
- l = layer(v)
153
- g1 = tape.gradient(l, layer.trainable_variables)
154
-
155
- v = {"state": K.ones([2**n]) / 2 ** (n / 2), "noise": 0.2 * K.ones([n])}
156
- with tf.GradientTape() as tape:
157
- l = layer(v)
158
- g2 = tape.gradient(l, layer.trainable_variables)
159
-
160
- np.testing.assert_allclose(g1[0], g2[0], atol=1e-5)