tensorcircuit-nightly 1.0.2.dev20250108__py3-none-any.whl → 1.4.0.dev20251103__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tensorcircuit-nightly might be problematic.
- tensorcircuit/__init__.py +18 -2
- tensorcircuit/about.py +46 -0
- tensorcircuit/abstractcircuit.py +4 -0
- tensorcircuit/analogcircuit.py +413 -0
- tensorcircuit/applications/layers.py +1 -1
- tensorcircuit/applications/van.py +1 -1
- tensorcircuit/backends/abstract_backend.py +320 -7
- tensorcircuit/backends/cupy_backend.py +3 -1
- tensorcircuit/backends/jax_backend.py +102 -4
- tensorcircuit/backends/jax_ops.py +110 -1
- tensorcircuit/backends/numpy_backend.py +49 -3
- tensorcircuit/backends/pytorch_backend.py +92 -3
- tensorcircuit/backends/tensorflow_backend.py +102 -3
- tensorcircuit/basecircuit.py +157 -98
- tensorcircuit/circuit.py +115 -57
- tensorcircuit/cloud/local.py +1 -1
- tensorcircuit/cloud/quafu_provider.py +1 -1
- tensorcircuit/cloud/tencent.py +1 -1
- tensorcircuit/compiler/simple_compiler.py +2 -2
- tensorcircuit/cons.py +142 -21
- tensorcircuit/densitymatrix.py +43 -14
- tensorcircuit/experimental.py +387 -129
- tensorcircuit/fgs.py +282 -81
- tensorcircuit/gates.py +66 -22
- tensorcircuit/interfaces/__init__.py +1 -3
- tensorcircuit/interfaces/jax.py +189 -0
- tensorcircuit/keras.py +3 -3
- tensorcircuit/mpscircuit.py +154 -65
- tensorcircuit/quantum.py +868 -152
- tensorcircuit/quditcircuit.py +733 -0
- tensorcircuit/quditgates.py +618 -0
- tensorcircuit/results/counts.py +147 -20
- tensorcircuit/results/readout_mitigation.py +4 -1
- tensorcircuit/shadows.py +1 -1
- tensorcircuit/simplify.py +3 -1
- tensorcircuit/stabilizercircuit.py +479 -0
- tensorcircuit/templates/__init__.py +2 -0
- tensorcircuit/templates/blocks.py +2 -2
- tensorcircuit/templates/hamiltonians.py +174 -0
- tensorcircuit/templates/lattice.py +1789 -0
- tensorcircuit/timeevol.py +896 -0
- tensorcircuit/translation.py +10 -3
- tensorcircuit/utils.py +7 -0
- {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +73 -23
- tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
- {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +1 -1
- {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
- tensorcircuit_nightly-1.0.2.dev20250108.dist-info/RECORD +0 -115
- tests/__init__.py +0 -0
- tests/conftest.py +0 -67
- tests/test_backends.py +0 -1031
- tests/test_calibrating.py +0 -149
- tests/test_channels.py +0 -365
- tests/test_circuit.py +0 -1699
- tests/test_cloud.py +0 -219
- tests/test_compiler.py +0 -147
- tests/test_dmcircuit.py +0 -555
- tests/test_ensemble.py +0 -72
- tests/test_fgs.py +0 -310
- tests/test_gates.py +0 -156
- tests/test_interfaces.py +0 -429
- tests/test_keras.py +0 -160
- tests/test_miscs.py +0 -277
- tests/test_mpscircuit.py +0 -341
- tests/test_noisemodel.py +0 -156
- tests/test_qaoa.py +0 -86
- tests/test_qem.py +0 -152
- tests/test_quantum.py +0 -526
- tests/test_quantum_attr.py +0 -42
- tests/test_results.py +0 -347
- tests/test_shadows.py +0 -160
- tests/test_simplify.py +0 -46
- tests/test_templates.py +0 -218
- tests/test_torchnn.py +0 -99
- tests/test_van.py +0 -102
- {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info/licenses}/LICENSE +0 -0
tensorcircuit/experimental.py
CHANGED
@@ -2,17 +2,26 @@
 Experimental features
 """
 
+# pylint: disable=unused-import
+
 from functools import partial
-
+import logging
+from typing import Any, Callable, Dict, Optional, Tuple, List, Sequence, Union
 
 import numpy as np
 
-from .cons import backend, dtypestr,
+from .cons import backend, dtypestr, rdtypestr, get_tn_info
 from .gates import Gate
+from .timeevol import hamiltonian_evol, evol_global, evol_local
+
+
+# for backward compatibility
 
 Tensor = Any
 Circuit = Any
 
+logger = logging.getLogger(__name__)
+
 
 def adaptive_vmap(
     f: Callable[..., Any],
@@ -20,6 +29,16 @@ def adaptive_vmap(
     static_argnums: Optional[Union[int, Sequence[int]]] = None,
     chunk_size: Optional[int] = None,
 ) -> Callable[..., Any]:
+    """
+    Vectorized map with adaptive chunking for memory efficiency.
+
+    :param f: Function to be vectorized
+    :param vectorized_argnums: Arguments to be vectorized over
+    :param static_argnums: Arguments that remain static during vectorization
+    :param chunk_size: Size of chunks for batch processing, None means no chunking
+        (naive vmap)
+    :return: Vectorized function
+    """
     if chunk_size is None:
         return backend.vmap(f, vectorized_argnums)  # type: ignore
 
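Note: the hunk above only documents `adaptive_vmap`; its behavior is unchanged. A minimal usage sketch, assuming the jax backend; the helper `single_energy`, the batch size, and `chunk_size=2` are illustrative and not taken from the diff:

import tensorcircuit as tc

tc.set_backend("jax")

def single_energy(param):
    # toy per-sample function: one expectation value for one parameter vector
    c = tc.Circuit(2)
    c.rx(0, theta=param[0])
    c.ry(1, theta=param[1])
    return tc.backend.real(c.expectation_ps(z=[0]))

# vectorize over the leading batch axis, processing the batch in chunks of 2
batched_energy = tc.experimental.adaptive_vmap(single_energy, chunk_size=2)
params = 0.1 * tc.backend.ones([6, 2])  # 6 parameter vectors of length 2
print(batched_energy(params))  # 6 expectation values, computed chunk by chunk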
@@ -94,6 +113,29 @@ def qng(
     postprocess: Optional[str] = "qng",
     mode: str = "fwd",
 ) -> Callable[..., Tensor]:
+    """
+    Compute quantum natural gradient for quantum circuit optimization.
+
+    :param f: Function that takes parameters and returns quantum state
+    :param kernel: Type of kernel to use ("qng" or "dynamics"), the former has the second term
+    :param postprocess: Post-processing method ("qng" or None)
+    :param mode: Mode of differentiation ("fwd" or "rev")
+    :return: Function computing QNG matrix
+
+    :Example:
+
+    >>> import tensorcircuit as tc
+    >>> def ansatz(params):
+    ...     c = tc.Circuit(2)
+    ...     c.rx(0, theta=params[0])
+    ...     c.ry(1, theta=params[1])
+    ...     return c.state()
+    >>> qng_fn = tc.experimental.qng(ansatz)
+    >>> params = tc.array_to_tensor([0.5, 0.8])
+    >>> qng_matrix = qng_fn(params)
+    >>> print(qng_matrix.shape)  # (2, 2)
+    """
+
     # for both qng and qng2 calculation, we highly recommended complex-dtype but real valued inputs
     def wrapper(params: Tensor, **kws: Any) -> Tensor:
         params = backend.cast(params, dtype=dtypestr)  # R->C protection
@@ -399,130 +441,6 @@ def finite_difference_differentiator(
     return tf_function  # type: ignore
 
 
-def hamiltonian_evol(
-    tlist: Tensor,
-    h: Tensor,
-    psi0: Tensor,
-    callback: Optional[Callable[..., Any]] = None,
-) -> Tensor:
-    """
-    Fast implementation of static full Hamiltonian evolution
-    (default as imaginary time)
-
-    :param tlist: _description_
-    :type tlist: Tensor
-    :param h: _description_
-    :type h: Tensor
-    :param psi0: _description_
-    :type psi0: Tensor
-    :param callback: _description_, defaults to None
-    :type callback: Optional[Callable[..., Any]], optional
-    :return: Tensor
-    :rtype: result dynamics on ``tlist``
-    """
-    es, u = backend.eigh(h)
-    utpsi0 = backend.reshape(
-        backend.transpose(u) @ backend.reshape(psi0, [-1, 1]), [-1]
-    )
-
-    @backend.jit
-    def _evol(t: Tensor) -> Tensor:
-        ebetah_utpsi0 = backend.exp(-t * es) * utpsi0
-        psi_exact = backend.conj(u) @ backend.reshape(ebetah_utpsi0, [-1, 1])
-        psi_exact = backend.reshape(psi_exact, [-1])
-        psi_exact = psi_exact / backend.norm(psi_exact)
-        if callback is None:
-            return psi_exact
-        return callback(psi_exact)
-
-    return backend.stack([_evol(t) for t in tlist])
-
-
-def evol_local(
-    c: Circuit,
-    index: Sequence[int],
-    h_fun: Callable[..., Tensor],
-    t: float,
-    *args: Any,
-    **solver_kws: Any
-) -> Circuit:
-    """
-    ode evolution of time dependent Hamiltonian on circuit of given indices
-    [only jax backend support for now]
-
-    :param c: _description_
-    :type c: Circuit
-    :param index: qubit sites to evolve
-    :type index: Sequence[int]
-    :param h_fun: h_fun should return a dense Hamiltonian matrix
-        with input arguments time and *args
-    :type h_fun: Callable[..., Tensor]
-    :param t: evolution time
-    :type t: float
-    :return: _description_
-    :rtype: Circuit
-    """
-    from jax.experimental.ode import odeint
-
-    s = c.state()
-    n = c._nqubits
-    l = len(index)
-
-    def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
-        y = backend.reshape2(y)
-        y = Gate(y)
-        h = -1.0j * h_fun(t, *args)
-        h = backend.reshape2(h)
-        h = Gate(h)
-        edges = []
-        for i in range(n):
-            if i not in index:
-                edges.append(y[i])
-            else:
-                j = index.index(i)
-                edges.append(h[j])
-                h[j + l] ^ y[i]
-        y = contractor([y, h], output_edge_order=edges)
-        return backend.reshape(y.tensor, [-1])
-
-    ts = backend.stack([0.0, t])
-    ts = backend.cast(ts, dtype=rdtypestr)
-    s1 = odeint(f, s, ts, *args, **solver_kws)
-    return type(c)(n, inputs=s1[-1])
-
-
-def evol_global(
-    c: Circuit, h_fun: Callable[..., Tensor], t: float, *args: Any, **solver_kws: Any
-) -> Circuit:
-    """
-    ode evolution of time dependent Hamiltonian on circuit of all qubits
-    [only jax backend support for now]
-
-    :param c: _description_
-    :type c: Circuit
-    :param h_fun: h_fun should return a **SPARSE** Hamiltonian matrix
-        with input arguments time and *args
-    :type h_fun: Callable[..., Tensor]
-    :param t: _description_
-    :type t: float
-    :return: _description_
-    :rtype: Circuit
-    """
-    from jax.experimental.ode import odeint
-
-    s = c.state()
-    n = c._nqubits
-
-    def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
-        h = -1.0j * h_fun(t, *args)
-        return backend.sparse_dense_matmul(h, y)
-
-    ts = backend.stack([0.0, t])
-    ts = backend.cast(ts, dtype=rdtypestr)
-    s1 = odeint(f, s, ts, *args, **solver_kws)
-    return type(c)(n, inputs=s1[-1])
-
-
 def jax_jitted_function_save(filename: str, f: Callable[..., Any], *args: Any) -> None:
     """
     save a jitted jax function as a file
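Note: the functions removed in this hunk are relocated rather than dropped; the import hunk at the top of the file re-exports `hamiltonian_evol`, `evol_local`, and `evol_global` from the new `tensorcircuit/timeevol.py` module for backward compatibility. A minimal sketch of the backward-compatible call path, assuming the relocated `hamiltonian_evol` keeps the `(tlist, h, psi0, callback)` signature and imaginary-time default shown in the removed block; the Hamiltonian, initial state, and time grid are illustrative:

import numpy as np
import tensorcircuit as tc

# 2-qubit Hamiltonian assembled by hand for illustration: ZZ + 0.5 (X1 + X2)
x = np.array([[0.0, 1.0], [1.0, 0.0]])
z = np.array([[1.0, 0.0], [0.0, -1.0]])
h = np.kron(z, z) + 0.5 * (np.kron(x, np.eye(2)) + np.kron(np.eye(2), x))

psi0 = np.ones(4) / 2.0           # uniform initial state
tlist = np.linspace(0.0, 2.0, 5)  # imaginary-time grid

# same call whether imported from tc.experimental (re-export) or the new module
states = tc.experimental.hamiltonian_evol(
    tc.array_to_tensor(tlist),
    tc.array_to_tensor(h),
    tc.array_to_tensor(psi0),
)
print(states.shape)  # (5, 4): one normalized state per time point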
@@ -534,7 +452,7 @@ def jax_jitted_function_save(filename: str, f: Callable[..., Any], *args: Any) -
     :param args: example function arguments for ``f``
     """
 
-    from jax import export
+    from jax import export  # type: ignore
 
     f_export = export.export(f)(*args)  # type: ignore
     barray = f_export.serialize()
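Note: this hunk and the next only quiet the type checker on the `jax.export` import; the save/load behavior is unchanged. A round-trip sketch for the pair, assuming a recent JAX release that ships `jax.export`; the file name and toy function are illustrative:

import jax
import jax.numpy as jnp
import tensorcircuit as tc

@jax.jit
def f(x):
    # toy jitted function to be exported
    return jnp.sin(x) + x ** 2

# serialize the exported computation, using example arguments to fix shapes/dtypes
tc.experimental.jax_jitted_function_save("f_exported.bin", f, jnp.ones([3]))
g = tc.experimental.jax_jitted_function_load("f_exported.bin")  # returns the exported .call
print(g(jnp.ones([3])))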
@@ -555,14 +473,354 @@ def jax_jitted_function_load(filename: str) -> Callable[..., Any]:
     :return: the loaded function
     :rtype: _type_
     """
-    from jax import export
+    from jax import export  # type: ignore
 
     with open(filename, "rb") as f:
         barray = f.read()
 
     f_load = export.deserialize(barray)  # type: ignore
 
-    return f_load.call
+    return f_load.call  # type: ignore
 
 
 jax_func_load = jax_jitted_function_load
+
+
+PADDING_VALUE = -1
+jaxlib: Any
+ctg: Any
+
+
+class DistributedContractor:
+    """
+    A distributed tensor network contractor that parallelizes computations across multiple devices.
+
+    This class uses cotengra to find optimal contraction paths and distributes the computational
+    load across multiple devices (e.g., GPUs) for efficient tensor network calculations.
+    Particularly useful for large-scale quantum circuit simulations and variational quantum algorithms.
+
+    Example:
+        >>> def nodes_fn(params):
+        ...     c = tc.Circuit(4)
+        ...     c.rx(0, theta=params[0])
+        ...     return c.expectation_before([tc.gates.z(), [0]], reuse=False)
+        >>> dc = DistributedContractor(nodes_fn, params)
+        >>> value, grad = dc.value_and_grad(params)
+
+    :param nodes_fn: Function that takes parameters and returns a list of tensor network nodes
+    :type nodes_fn: Callable[[Tensor], List[Gate]]
+    :param params: Initial parameters used to determine the tensor network structure
+    :type params: Tensor
+    :param cotengra_options: Configuration options passed to the cotengra optimizer. Defaults to None
+    :type cotengra_options: Optional[Dict[str, Any]], optional
+    :param devices: List of devices to use. If None, uses all available local devices
+    :type devices: Optional[List[Any]], optional
+    """
+
+    def __init__(
+        self,
+        nodes_fn: Callable[[Tensor], List[Gate]],
+        params: Tensor,
+        cotengra_options: Optional[Dict[str, Any]] = None,
+        devices: Optional[List[Any]] = None,
+    ) -> None:
+        global jaxlib
+        global ctg
+
+        logger.info("Initializing DistributedContractor...")
+        import cotengra as ctg
+        import jax as jaxlib
+
+        self.nodes_fn = nodes_fn
+        if devices is None:
+            self.num_devices = jaxlib.local_device_count()
+            self.devices = jaxlib.local_devices()
+            # TODO(@refraction-ray): multi host support
+        else:
+            self.devices = devices
+            self.num_devices = len(devices)
+
+        if self.num_devices <= 1:
+            logger.info("DistributedContractor is running on a single device.")
+
+        self._params_template = params
+        self._backend = "jax"
+        self._compiled_v_fns: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ] = {}
+        self._compiled_vg_fns: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ] = {}
+
+        logger.info("Running cotengra pathfinder... (This may take a while)")
+        nodes = self.nodes_fn(self._params_template)
+        tn_info, _ = get_tn_info(nodes)
+        default_cotengra_options = {
+            "slicing_reconf_opts": {"target_size": 2**28},
+            "max_repeats": 128,
+            "progbar": True,
+            "minimize": "write",
+            "parallel": "auto",
+        }
+        if cotengra_options:
+            default_cotengra_options = cotengra_options
+
+        opt = ctg.ReusableHyperOptimizer(**default_cotengra_options)
+        self.tree = opt.search(*tn_info)
+        actual_num_slices = self.tree.nslices
+
+        print("\n--- Contraction Path Info ---")
+        stats = self.tree.contract_stats()
+        print(f"Path found with {actual_num_slices} slices.")
+        print(
+            f"Arithmetic Intensity (higher is better): {self.tree.arithmetic_intensity():.2f}"
+        )
+        print("flops (TFlops):", stats["flops"] / 2**40 / self.num_devices)
+        print("write (GB):", stats["write"] / 2**27 / actual_num_slices)
+        print("size (GB):", stats["size"] / 2**27)
+        print("-----------------------------\n")
+
+        slices_per_device = int(np.ceil(actual_num_slices / self.num_devices))
+        padded_size = slices_per_device * self.num_devices
+        slice_indices = np.arange(actual_num_slices)
+        padded_slice_indices = np.full(padded_size, PADDING_VALUE, dtype=np.int32)
+        padded_slice_indices[:actual_num_slices] = slice_indices
+        self.batched_slice_indices = backend.convert_to_tensor(
+            padded_slice_indices.reshape(self.num_devices, slices_per_device)
+        )
+        print(
+            f"Distributing across {self.num_devices} devices. Each device will sequentially process "
+            f"up to {slices_per_device} slices."
+        )
+
+        self._compiled_vg_fn = None
+        self._compiled_v_fn = None
+
+        logger.info("Initialization complete.")
+
+    def _get_single_slice_contraction_fn(
+        self, op: Optional[Callable[[Tensor], Tensor]] = None
+    ) -> Callable[[Any, Tensor, int], Tensor]:
+        if op is None:
+            op = backend.sum
+
+        def single_slice_contraction(
+            tree: ctg.ContractionTree, params: Tensor, slice_idx: int
+        ) -> Tensor:
+            nodes = self.nodes_fn(params)
+            _, standardized_nodes = get_tn_info(nodes)
+            input_arrays = [node.tensor for node in standardized_nodes]
+            sliced_arrays = tree.slice_arrays(input_arrays, slice_idx)
+            result = tree.contract_core(sliced_arrays, backend=self._backend)
+            return op(result)
+
+        return single_slice_contraction
+
+    def _get_device_sum_vg_fn(
+        self,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Callable[[Any, Tensor, Tensor], Tuple[Tensor, Tensor]]:
+        post_processing = lambda x: backend.real(backend.sum(x))
+        if op is None:
+            op = post_processing
+        base_fn = self._get_single_slice_contraction_fn(op=op)
+        # to ensure the output is real so that can be differentiated
+        single_slice_vg_fn = jaxlib.value_and_grad(base_fn, argnums=1)
+
+        if output_dtype is None:
+            output_dtype = rdtypestr
+
+        def device_sum_fn(
+            tree: ctg.ContractionTree, params: Tensor, slice_indices_for_device: Tensor
+        ) -> Tuple[Tensor, Tensor]:
+            def scan_body(
+                carry: Tuple[Tensor, Tensor], slice_idx: Tensor
+            ) -> Tuple[Tuple[Tensor, Tensor], None]:
+                acc_value, acc_grads = carry
+
+                def compute_and_add() -> Tuple[Tensor, Tensor]:
+                    value_slice, grads_slice = single_slice_vg_fn(
+                        tree, params, slice_idx
+                    )
+                    new_value = acc_value + value_slice
+                    new_grads = jaxlib.tree_util.tree_map(
+                        jaxlib.numpy.add, acc_grads, grads_slice
+                    )
+                    return new_value, new_grads
+
+                def do_nothing() -> Tuple[Tensor, Tensor]:
+                    return acc_value, acc_grads
+
+                return (
+                    jaxlib.lax.cond(
+                        slice_idx == PADDING_VALUE, do_nothing, compute_and_add
+                    ),
+                    None,
+                )
+
+            initial_carry = (
+                backend.cast(backend.convert_to_tensor(0.0), dtype=output_dtype),
+                jaxlib.tree_util.tree_map(lambda x: jaxlib.numpy.zeros_like(x), params),
+            )
+            (final_value, final_grads), _ = jaxlib.lax.scan(
+                scan_body, initial_carry, slice_indices_for_device
+            )
+            return final_value, final_grads
+
+        return device_sum_fn
+
+    def _get_device_sum_v_fn(
+        self,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Callable[[Any, Tensor, Tensor], Tensor]:
+        base_fn = self._get_single_slice_contraction_fn(op=op)
+        if output_dtype is None:
+            output_dtype = dtypestr
+
+        def device_sum_fn(
+            tree: ctg.ContractionTree, params: Tensor, slice_indices_for_device: Tensor
+        ) -> Tensor:
+            def scan_body(
+                carry_value: Tensor, slice_idx: Tensor
+            ) -> Tuple[Tensor, None]:
+                def compute_and_add() -> Tensor:
+                    return carry_value + base_fn(tree, params, slice_idx)
+
+                return (
+                    jaxlib.lax.cond(
+                        slice_idx == PADDING_VALUE, lambda: carry_value, compute_and_add
+                    ),
+                    None,
+                )
+
+            initial_carry = backend.cast(
+                backend.convert_to_tensor(0.0), dtype=output_dtype
+            )
+            final_value, _ = jaxlib.lax.scan(
+                scan_body, initial_carry, slice_indices_for_device
+            )
+            return final_value
+
+        return device_sum_fn
+
+    def _get_or_compile_fn(
+        self,
+        cache: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ],
+        fn_getter: Callable[..., Any],
+        op: Optional[Callable[[Tensor], Tensor]],
+        output_dtype: Optional[str],
+    ) -> Callable[[Any, Tensor, Tensor], Tensor]:
+        """
+        Gets a compiled pmap-ed function from cache or compiles and caches it.
+
+        The cache key is a tuple of (op, output_dtype). Caution on lambda function!
+
+        Returns:
+            The compiled, pmap-ed JAX function.
+        """
+        cache_key = (op, output_dtype)
+        if cache_key not in cache:
+            device_fn = fn_getter(op=op, output_dtype=output_dtype)
+            compiled_fn = jaxlib.pmap(
+                device_fn,
+                in_axes=(
+                    None,
+                    None,
+                    0,
+                ),  # tree: broadcast, params: broadcast, indices: map
+                static_broadcasted_argnums=(0,),  # arg 0 (tree) is a static argument
+                devices=self.devices,
+            )
+            cache[cache_key] = compiled_fn  # type: ignore
+        return cache[cache_key]  # type: ignore
+
+    def value_and_grad(
+        self,
+        params: Tensor,
+        aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Calculates the value and gradient, compiling the pmap function if needed for the first call.
+
+        :param params: Parameters for the `nodes_fn` input
+        :type params: Tensor
+        :param aggregate: Whether to aggregate (sum) the results across devices, defaults to True
+        :type aggregate: bool, optional
+        :param op: Optional post-processing function for the output, defaults to None (corresponding to `backend.real`)
+            op is a cache key, so dont directly pass lambda function for op
+        :type op: Optional[Callable[[Tensor], Tensor]], optional
+        :param output_dtype: dtype str for the output of `nodes_fn`, defaults to None (corresponding to `rdtypestr`)
+        :type output_dtype: Optional[str], optional
+        """
+        compiled_vg_fn = self._get_or_compile_fn(
+            cache=self._compiled_vg_fns,
+            fn_getter=self._get_device_sum_vg_fn,
+            op=op,
+            output_dtype=output_dtype,
+        )
+
+        device_values, device_grads = compiled_vg_fn(
+            self.tree, params, self.batched_slice_indices
+        )
+
+        if aggregate:
+            total_value = backend.sum(device_values)
+            total_grad = jaxlib.tree_util.tree_map(
+                lambda x: backend.sum(x, axis=0), device_grads
+            )
+            return total_value, total_grad
+        return device_values, device_grads
+
+    def value(
+        self,
+        params: Tensor,
+        aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tensor:
+        """
+        Calculates the value, compiling the pmap function for the first call.
+
+        :param params: Parameters for the `nodes_fn` input
+        :type params: Tensor
+        :param aggregate: Whether to aggregate (sum) the results across devices, defaults to True
+        :type aggregate: bool, optional
+        :param op: Optional post-processing function for the output, defaults to None (corresponding to identity)
+            op is a cache key, so dont directly pass lambda function for op
+        :type op: Optional[Callable[[Tensor], Tensor]], optional
+        :param output_dtype: dtype str for the output of `nodes_fn`, defaults to None (corresponding to `dtypestr`)
+        :type output_dtype: Optional[str], optional
+        """
+        compiled_v_fn = self._get_or_compile_fn(
+            cache=self._compiled_v_fns,
+            fn_getter=self._get_device_sum_v_fn,
+            op=op,
+            output_dtype=output_dtype,
+        )
+
+        device_values = compiled_v_fn(self.tree, params, self.batched_slice_indices)
+
+        if aggregate:
+            return backend.sum(device_values)
+        return device_values
+
+    def grad(
+        self,
+        params: Tensor,
+        aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tensor:
+        _, grad = self.value_and_grad(
+            params, aggregate=aggregate, op=op, output_dtype=output_dtype
+        )
+        return grad