tensorcircuit-nightly 1.2.0.dev20250326__py3-none-any.whl → 1.4.0.dev20251128__py3-none-any.whl
This diff shows the changes between publicly available versions of this package as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of tensorcircuit-nightly has been flagged as possibly problematic.
- tensorcircuit/__init__.py +5 -1
- tensorcircuit/abstractcircuit.py +4 -0
- tensorcircuit/analogcircuit.py +413 -0
- tensorcircuit/applications/layers.py +1 -1
- tensorcircuit/applications/van.py +1 -1
- tensorcircuit/backends/abstract_backend.py +312 -5
- tensorcircuit/backends/cupy_backend.py +3 -1
- tensorcircuit/backends/jax_backend.py +100 -4
- tensorcircuit/backends/jax_ops.py +108 -0
- tensorcircuit/backends/numpy_backend.py +49 -3
- tensorcircuit/backends/pytorch_backend.py +92 -3
- tensorcircuit/backends/tensorflow_backend.py +102 -3
- tensorcircuit/basecircuit.py +157 -98
- tensorcircuit/circuit.py +115 -57
- tensorcircuit/cloud/local.py +1 -1
- tensorcircuit/cloud/quafu_provider.py +1 -1
- tensorcircuit/cloud/tencent.py +1 -1
- tensorcircuit/compiler/simple_compiler.py +2 -2
- tensorcircuit/cons.py +105 -23
- tensorcircuit/densitymatrix.py +16 -11
- tensorcircuit/experimental.py +733 -153
- tensorcircuit/fgs.py +254 -73
- tensorcircuit/gates.py +66 -22
- tensorcircuit/interfaces/jax.py +5 -3
- tensorcircuit/interfaces/tensortrans.py +6 -2
- tensorcircuit/interfaces/torch.py +14 -4
- tensorcircuit/keras.py +3 -3
- tensorcircuit/mpscircuit.py +154 -65
- tensorcircuit/quantum.py +698 -134
- tensorcircuit/quditcircuit.py +733 -0
- tensorcircuit/quditgates.py +618 -0
- tensorcircuit/results/counts.py +131 -18
- tensorcircuit/results/readout_mitigation.py +4 -1
- tensorcircuit/shadows.py +1 -1
- tensorcircuit/simplify.py +3 -1
- tensorcircuit/stabilizercircuit.py +29 -17
- tensorcircuit/templates/__init__.py +2 -0
- tensorcircuit/templates/blocks.py +2 -2
- tensorcircuit/templates/hamiltonians.py +174 -0
- tensorcircuit/templates/lattice.py +1789 -0
- tensorcircuit/timeevol.py +896 -0
- tensorcircuit/translation.py +10 -3
- tensorcircuit/utils.py +7 -0
- {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/METADATA +66 -29
- tensorcircuit_nightly-1.4.0.dev20251128.dist-info/RECORD +96 -0
- {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/WHEEL +1 -1
- {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/top_level.txt +0 -1
- tensorcircuit_nightly-1.2.0.dev20250326.dist-info/RECORD +0 -118
- tests/__init__.py +0 -0
- tests/conftest.py +0 -67
- tests/test_backends.py +0 -1035
- tests/test_calibrating.py +0 -149
- tests/test_channels.py +0 -409
- tests/test_circuit.py +0 -1699
- tests/test_cloud.py +0 -219
- tests/test_compiler.py +0 -147
- tests/test_dmcircuit.py +0 -555
- tests/test_ensemble.py +0 -72
- tests/test_fgs.py +0 -310
- tests/test_gates.py +0 -156
- tests/test_interfaces.py +0 -562
- tests/test_keras.py +0 -160
- tests/test_miscs.py +0 -282
- tests/test_mpscircuit.py +0 -341
- tests/test_noisemodel.py +0 -156
- tests/test_qaoa.py +0 -86
- tests/test_qem.py +0 -152
- tests/test_quantum.py +0 -549
- tests/test_quantum_attr.py +0 -42
- tests/test_results.py +0 -380
- tests/test_shadows.py +0 -160
- tests/test_simplify.py +0 -46
- tests/test_stabilizer.py +0 -217
- tests/test_templates.py +0 -218
- tests/test_torchnn.py +0 -99
- tests/test_van.py +0 -102
- {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/licenses/LICENSE +0 -0
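
The most visible API movement in this release is that the time-evolution helpers (hamiltonian_evol, evol_local, evol_global) leave tensorcircuit/experimental.py for the new tensorcircuit/timeevol.py module, with re-imports kept in experimental.py for backward compatibility (see the experimental.py diff below). A minimal usage sketch follows, adapted from the docstring example of the removed code; the tc.timeevol module path is an assumption based on the new file, while the tc.experimental path is confirmed by the re-import line in the diff:

    import tensorcircuit as tc

    # A simple 2-qubit Hamiltonian and the |00> initial state,
    # taken from the removed function's docstring example.
    h = tc.array_to_tensor(
        [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, -1.0, 2.0, 0.0],
            [0.0, 2.0, -1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ]
    )
    psi0 = tc.array_to_tensor([1.0, 0.0, 0.0, 0.0])
    times = tc.array_to_tensor([0.0, 0.5, 1.0])

    # Imaginary-time evolution by default; the state is normalized at each time point.
    states = tc.experimental.hamiltonian_evol(times, h, psi0)  # backward-compatible path
    # states = tc.timeevol.hamiltonian_evol(times, h, psi0)    # assumed new canonical path
    print(states.shape)  # (3, 4)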
tensorcircuit/experimental.py
CHANGED
@@ -2,17 +2,30 @@
 Experimental features
 """
 
+# pylint: disable=unused-import
+
 from functools import partial
-
+import logging
+from typing import Any, Callable, Dict, Optional, Tuple, List, Sequence, Union
+import pickle
+import uuid
+import time
+import os
 
 import numpy as np
 
-from .cons import backend, dtypestr,
+from .cons import backend, dtypestr, rdtypestr, get_tn_info
 from .gates import Gate
+from .timeevol import hamiltonian_evol, evol_global, evol_local
+
+
+# for backward compatibility
 
 Tensor = Any
 Circuit = Any
 
+logger = logging.getLogger(__name__)
+
 
 def adaptive_vmap(
     f: Callable[..., Any],
@@ -432,157 +445,6 @@ def finite_difference_differentiator(
     return tf_function  # type: ignore
 
 
-def hamiltonian_evol(
-    tlist: Tensor,
-    h: Tensor,
-    psi0: Tensor,
-    callback: Optional[Callable[..., Any]] = None,
-) -> Tensor:
-    """
-    Fast implementation of time independent Hamiltonian evolution using eigendecomposition.
-    By default, performs imaginary time evolution.
-
-    :param tlist: Time points for evolution
-    :type tlist: Tensor
-    :param h: Time-independent Hamiltonian matrix
-    :type h: Tensor
-    :param psi0: Initial state vector
-    :type psi0: Tensor
-    :param callback: Optional function to process state at each time point
-    :type callback: Optional[Callable[..., Any]], optional
-    :return: Evolution results at each time point. If callback is None, returns state vectors;
-        otherwise returns callback results
-    :rtype: Tensor
-
-    :Example:
-
-    >>> import tensorcircuit as tc
-    >>> import numpy as np
-    >>> # Define a simple 2-qubit Hamiltonian
-    >>> h = tc.array_to_tensor([
-    ...     [1.0, 0.0, 0.0, 0.0],
-    ...     [0.0, -1.0, 2.0, 0.0],
-    ...     [0.0, 2.0, -1.0, 0.0],
-    ...     [0.0, 0.0, 0.0, 1.0]
-    ... ])
-    >>> # Initial state |00⟩
-    >>> psi0 = tc.array_to_tensor([1.0, 0.0, 0.0, 0.0])
-    >>> # Evolution times
-    >>> times = tc.array_to_tensor([0.0, 0.5, 1.0])
-    >>> # Evolve and get states
-    >>> states = tc.experimental.hamiltonian_evol(times, h, psi0)
-    >>> print(states.shape)  # (3, 4)
-
-
-    Note:
-        1. The Hamiltonian must be time-independent
-        2. For time-dependent Hamiltonians, use ``evol_local`` or ``evol_global`` instead
-        3. The evolution is performed in imaginary time by default (factor -t in exponential)
-        4. The state is automatically normalized at each time point
-    """
-    es, u = backend.eigh(h)
-    utpsi0 = backend.reshape(
-        backend.transpose(u) @ backend.reshape(psi0, [-1, 1]), [-1]
-    )
-
-    @backend.jit
-    def _evol(t: Tensor) -> Tensor:
-        ebetah_utpsi0 = backend.exp(-t * es) * utpsi0
-        psi_exact = backend.conj(u) @ backend.reshape(ebetah_utpsi0, [-1, 1])
-        psi_exact = backend.reshape(psi_exact, [-1])
-        psi_exact = psi_exact / backend.norm(psi_exact)
-        if callback is None:
-            return psi_exact
-        return callback(psi_exact)
-
-    return backend.stack([_evol(t) for t in tlist])
-
-
-def evol_local(
-    c: Circuit,
-    index: Sequence[int],
-    h_fun: Callable[..., Tensor],
-    t: float,
-    *args: Any,
-    **solver_kws: Any
-) -> Circuit:
-    """
-    ode evolution of time dependent Hamiltonian on circuit of given indices
-    [only jax backend support for now]
-
-    :param c: _description_
-    :type c: Circuit
-    :param index: qubit sites to evolve
-    :type index: Sequence[int]
-    :param h_fun: h_fun should return a dense Hamiltonian matrix
-        with input arguments time and *args
-    :type h_fun: Callable[..., Tensor]
-    :param t: evolution time
-    :type t: float
-    :return: _description_
-    :rtype: Circuit
-    """
-    from jax.experimental.ode import odeint
-
-    s = c.state()
-    n = c._nqubits
-    l = len(index)
-
-    def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
-        y = backend.reshape2(y)
-        y = Gate(y)
-        h = -1.0j * h_fun(t, *args)
-        h = backend.reshape2(h)
-        h = Gate(h)
-        edges = []
-        for i in range(n):
-            if i not in index:
-                edges.append(y[i])
-            else:
-                j = index.index(i)
-                edges.append(h[j])
-                h[j + l] ^ y[i]
-        y = contractor([y, h], output_edge_order=edges)
-        return backend.reshape(y.tensor, [-1])
-
-    ts = backend.stack([0.0, t])
-    ts = backend.cast(ts, dtype=rdtypestr)
-    s1 = odeint(f, s, ts, *args, **solver_kws)
-    return type(c)(n, inputs=s1[-1])
-
-
-def evol_global(
-    c: Circuit, h_fun: Callable[..., Tensor], t: float, *args: Any, **solver_kws: Any
-) -> Circuit:
-    """
-    ode evolution of time dependent Hamiltonian on circuit of all qubits
-    [only jax backend support for now]
-
-    :param c: _description_
-    :type c: Circuit
-    :param h_fun: h_fun should return a **SPARSE** Hamiltonian matrix
-        with input arguments time and *args
-    :type h_fun: Callable[..., Tensor]
-    :param t: _description_
-    :type t: float
-    :return: _description_
-    :rtype: Circuit
-    """
-    from jax.experimental.ode import odeint
-
-    s = c.state()
-    n = c._nqubits
-
-    def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
-        h = -1.0j * h_fun(t, *args)
-        return backend.sparse_dense_matmul(h, y)
-
-    ts = backend.stack([0.0, t])
-    ts = backend.cast(ts, dtype=rdtypestr)
-    s1 = odeint(f, s, ts, *args, **solver_kws)
-    return type(c)(n, inputs=s1[-1])
-
-
 def jax_jitted_function_save(filename: str, f: Callable[..., Any], *args: Any) -> None:
     """
     save a jitted jax function as a file
@@ -626,3 +488,721 @@ def jax_jitted_function_load(filename: str) -> Callable[..., Any]:
 
 
 jax_func_load = jax_jitted_function_load
+
+
+PADDING_VALUE = -1
+jaxlib: Any
+ctg: Any
+Mesh: Any
+NamedSharding: Any
+P: Any
+
+
+def broadcast_py_object(obj: Any, shared_dir: Optional[str] = None) -> Any:
+    """
+    Broadcast a picklable Python object from process 0 to all other processes,
+    with a fallback mechanism from gRPC to a file system based approach.
+
+    This function first attempts to use gRPC-based broadcast. If that fails due to
+    pickling issues, it falls back to a file system based approach that is more robust.
+
+    :param obj: The Python object to broadcast. It must be picklable.
+        This object should exist on process 0 and can be None on others.
+    :type obj: Any
+    :param shared_dir: Directory path for shared file system broadcast fallback.
+        If None, uses current directory. Only used in fallback mode.
+    :type shared_dir: Optional[str], optional
+    :return: The broadcasted object, now present on all processes.
+    :rtype: Any
+    """
+    import jax
+    from jax.experimental import multihost_utils
+
+    try:
+        result = broadcast_py_object_jax(obj)
+        return result
+
+    except pickle.UnpicklingError as e:
+        # This block is executed if any process fails during the gRPC attempt.
+
+        multihost_utils.sync_global_devices("grpc_broadcast_failed_fallback_sync")
+
+        if jax.process_index() == 0:
+            border = "=" * 80
+            logger.warning(
+                "\n%s\nJAX gRPC broadcast failed with error: %s\n"
+                "--> Falling back to robust Shared File System broadcast method.\n%s",
+                border,
+                e,
+                border,
+            )
+
+        return broadcast_py_object_fs(obj, shared_dir)
+
+
+def broadcast_py_object_jax(obj: Any) -> Any:
+    """
+    Broadcast a picklable Python object from process 0 to all other processes
+    within the jax distribution system.
+
+    This function uses a two-step broadcast: first the size, then the data.
+    This is necessary because `broadcast_one_to_all` requires the same
+    shaped array on all hosts.
+
+    :param obj: The Python object to broadcast. It must be picklable.
+        This object should exist on process 0 and can be None on others.
+
+    :return: The broadcasted object, now present on all processes.
+    """
+    import jax as jaxlib
+    import pickle
+    from jax.experimental import multihost_utils
+
+    # Serialize to bytes on process 0, empty bytes on others
+    if jaxlib.process_index() == 0:
+        if obj is None:
+            raise ValueError("Object to broadcast from process 0 cannot be None.")
+        data = pickle.dumps(obj)
+        logger.info(
+            f"--- Size of object to be broadcast: {len(data) / 1024**2:.3f} MB ---"
+        )
+
+    else:
+        data = b""
+
+    # Step 1: Broadcast the length of the serialized data.
+    # We send a single-element int32 array.
+    length = np.array([len(data)], dtype=np.int32)
+    length = multihost_utils.broadcast_one_to_all(length)
+
+    length = int(length[0])  # type: ignore
+
+    # Step 2: Broadcast the actual data.
+    # Convert byte string to a uint8 array for broadcasting.
+    send_arr_uint8 = np.frombuffer(data, dtype=np.uint8)
+    padded_length = (length + 3) // 4 * 4
+    if send_arr_uint8.size < padded_length:
+        send_arr_uint8 = np.pad(  # type: ignore
+            send_arr_uint8, (0, padded_length - send_arr_uint8.size), mode="constant"
+        )
+    send_arr_int32 = send_arr_uint8.astype(np.int32)
+    # send_arr_int32 = jaxlib.numpy.array(send_arr_int32, dtype=np.int32)
+    send_arr_int32 = jaxlib.device_put(send_arr_int32)
+
+    jaxlib.experimental.multihost_utils.sync_global_devices("bulk_before")
+
+    received_arr = multihost_utils.broadcast_one_to_all(send_arr_int32)
+
+    received_arr = np.array(received_arr)
+    received_arr_uint8 = received_arr.astype(np.uint8)
+
+    # Step 3: Reconstruct the object from the received bytes.
+    # Convert the NumPy array back to bytes, truncate any padding, and unpickle.
+    received_data = received_arr_uint8[:length].tobytes()
+    # if jaxlib.process_index() == 0:
+    #     logger.info(f"Broadcasted object {obj}")
+    return pickle.loads(received_data)
+
+
+def broadcast_py_object_fs(
+    obj: Any, shared_dir: Optional[str] = None, timeout_seconds: int = 300
+) -> Any:
+    """
+    Broadcast a picklable Python object from process 0 to all other processes
+    using a shared file system approach.
+
+    This is a fallback method when gRPC-based broadcast fails. It uses UUID-based
+    file communication to share objects between processes through a shared file system.
+
+    :param obj: The Python object to broadcast. Must be picklable.
+        Should exist on process 0, can be None on others.
+    :type obj: Any
+    :param shared_dir: Directory path for shared file system communication.
+        If None, uses current directory.
+    :type shared_dir: Optional[str], optional
+    :param timeout_seconds: Maximum time to wait for file operations before timing out.
+        Defaults to 300 seconds.
+    :type timeout_seconds: int, optional
+    :return: The broadcasted object, now present on all processes.
+    :rtype: Any
+    """
+    # to avoid very subtle bugs when broadcasting tree_data on A800 clusters
+    import jax
+    from jax.experimental import multihost_utils
+
+    if shared_dir is None:
+        shared_dir = "."
+    if jax.process_index() == 0:
+        os.makedirs(shared_dir, exist_ok=True)
+
+    id_comm_path = os.path.join(shared_dir, f".broadcast_temp_12318")
+    transfer_id = ""
+
+    if jax.process_index() == 0:
+        transfer_id = str(uuid.uuid4())
+        # print(f"[Process 0] Generated unique transfer ID: {transfer_id}", flush=True)
+        with open(id_comm_path, "w") as f:
+            f.write(transfer_id)
+
+    multihost_utils.sync_global_devices("fs_broadcast_id_written")
+
+    if jax.process_index() != 0:
+        start_time = time.time()
+        while not os.path.exists(id_comm_path):
+            time.sleep(0.1)
+            if time.time() - start_time > timeout_seconds:
+                raise TimeoutError(
+                    f"Process {jax.process_index()} timed out waiting for ID file: {id_comm_path}"
+                )
+        with open(id_comm_path, "r") as f:
+            transfer_id = f.read()
+
+    multihost_utils.sync_global_devices("fs_broadcast_id_read")
+    if jax.process_index() == 0:
+        try:
+            os.remove(id_comm_path)
+        except OSError:
+            pass  # If the file has already been cleaned up by another process, ignore the error
+
+    # Define the data file and done-flag file paths used for this transfer
+    data_path = os.path.join(shared_dir, f"{transfer_id}.data")
+    done_path = os.path.join(shared_dir, f"{transfer_id}.done")
+
+    result_obj = None
+
+    if jax.process_index() == 0:
+        if obj is None:
+            raise ValueError("None cannot be broadcasted.")
+
+        # print(f"[Process 0] Pickling object...", flush=True)
+        pickled_data = pickle.dumps(obj)
+        logger.info(
+            f"[Process 0] Writing {len(pickled_data) / 1024**2:.3f} MB to {data_path}"
+        )
+        with open(data_path, "wb") as f:
+            f.write(pickled_data)
+
+        with open(done_path, "w") as f:
+            pass
+        logger.info(f"[Process 0] Write complete.")
+        result_obj = obj
+    else:
+        # print(f"[Process {jax.process_index()}] Waiting for done file: {done_path}", flush=True)
+        start_time = time.time()
+        while not os.path.exists(done_path):
+            time.sleep(0.1)
+            if time.time() - start_time > timeout_seconds:
+                raise TimeoutError(
+                    f"Process {jax.process_index()} timed out waiting for done file: {done_path}"
+                )
+
+        # print(f"[Process {jax.process_index()}] Done file found. Reading data from {data_path}", flush=True)
+        with open(data_path, "rb") as f:
+            pickled_data = f.read()
+
+        result_obj = pickle.loads(pickled_data)
+        logger.info(f"[Process {jax.process_index()}] Object successfully loaded.")
+
+    multihost_utils.sync_global_devices("fs_broadcast_read_complete")
+
+    if jax.process_index() == 0:
+        try:
+            os.remove(data_path)
+            os.remove(done_path)
+            # print(f"[Process 0] Cleaned up temporary files for transfer {transfer_id}.", flush=True)
+        except OSError as e:
+            logger.info(
+                f"[Process 0]: Failed to clean up temporary files: {e}",
+            )
+
+    return result_obj
+
+
+class DistributedContractor:
+    """
+    A distributed tensor network contractor that parallelizes computations across multiple devices.
+
+    This class uses cotengra to find optimal contraction paths and distributes the computational
+    load across multiple devices (e.g., GPUs) for efficient tensor network calculations.
+    Particularly useful for large-scale quantum circuit simulations and variational quantum algorithms.
+
+    Example:
+        >>> def nodes_fn(params):
+        ...     c = tc.Circuit(4)
+        ...     c.rx(0, theta=params[0])
+        ...     return c.expectation_before([tc.gates.z(), [0]], reuse=False)
+        >>> dc = DistributedContractor(nodes_fn, params)
+        >>> value, grad = dc.value_and_grad(params)
+
+    :param nodes_fn: Function that takes parameters and returns a list of tensor network nodes
+    :type nodes_fn: Callable[[Tensor], List[Gate]]
+    :param params: Initial parameters used to determine the tensor network structure
+    :type params: Tensor
+    :param cotengra_options: Configuration options passed to the cotengra optimizer. Defaults to None
+    :type cotengra_options: Optional[Dict[str, Any]], optional
+    :param devices: List of devices to use. If None, uses all available devices
+    :type devices: Optional[List[Any]], optional
+    :param mesh: Mesh object to use for distributed computation. If None, uses all available devices
+    :type mesh: Optional[Any], optional
+    """
+
+    def __init__(
+        self,
+        nodes_fn: Callable[[Tensor], List[Gate]],
+        params: Tensor,
+        cotengra_options: Optional[Dict[str, Any]] = None,
+        devices: Optional[List[Any]] = None,  # backward compatibility
+        mesh: Optional[Any] = None,
+        tree_data: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        global jaxlib
+        global ctg
+        global Mesh
+        global NamedSharding
+        global P
+
+        logger.info("Initializing DistributedContractor...")
+        import cotengra as ctg
+        from cotengra import ContractionTree
+        import jax as jaxlib
+        from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+
+        self.nodes_fn = nodes_fn
+        if mesh is not None:
+            self.mesh = mesh
+        elif devices is not None:
+            self.mesh = Mesh(devices, axis_names=("devices",))
+        else:
+            self.mesh = Mesh(jaxlib.devices(), axis_names=("devices",))
+        self.num_devices = len(self.mesh.devices)
+
+        if self.num_devices <= 1:
+            logger.info("DistributedContractor is running on a single device.")
+
+        self._params_template = params
+        self.params_sharding = jaxlib.tree_util.tree_map(
+            lambda x: NamedSharding(self.mesh, P(*((None,) * x.ndim))),
+            self._params_template,
+        )
+        self._backend = "jax"
+        self._compiled_v_fns: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ] = {}
+        self._compiled_vg_fns: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ] = {}
+
+        logger.info("Running cotengra pathfinder... (This may take a while)")
+        if tree_data is None:
+            if params is None:
+                raise ValueError("Please provide specific circuit parameters array.")
+            if jaxlib.process_index() == 0:
+                logger.info("Process 0: Running cotengra pathfinder...")
+                tree_data = self._get_tree_data(
+                    self.nodes_fn, self._params_template, cotengra_options  # type: ignore
+                )
+
+            # Step 2: Use the robust helper function to broadcast the tree object.
+            # Process 0 sends its computed `tree_object`.
+            # Other processes send `None`, but receive the object from process 0.
+
+            if jaxlib.process_count() > 1:
+                # self.tree = broadcast_py_object(tree_object)
+                jaxlib.experimental.multihost_utils.sync_global_devices("tree_before")
+                logger.info(
+                    f"Process {jaxlib.process_index()}: Synchronizing contraction path..."
+                )
+                tree_data = broadcast_py_object(tree_data)
+                jaxlib.experimental.multihost_utils.sync_global_devices("tree_after")
+        else:
+            logger.info("Using pre-computed contraction path.")
+            if tree_data is None:
+                raise ValueError("Contraction path data is missing.")
+
+        self.tree = ContractionTree.from_path(
+            inputs=tree_data["inputs"],
+            output=tree_data["output"],
+            size_dict=tree_data["size_dict"],
+            path=tree_data["path"],
+        )
+
+        # Restore slicing information
+        for ind, _ in tree_data["sliced_inds"].items():
+            self.tree.remove_ind_(ind)
+
+        logger.info(
+            f"Process {jaxlib.process_index()}: Contraction path successfully synchronized."
+        )
+        actual_num_slices = self.tree.nslices
+
+        self._report_tree_info()
+
+        slices_per_device = int(np.ceil(actual_num_slices / self.num_devices))
+        padded_size = slices_per_device * self.num_devices
+        slice_indices = np.arange(actual_num_slices)
+        padded_slice_indices = np.full(padded_size, PADDING_VALUE, dtype=np.int32)
+        padded_slice_indices[:actual_num_slices] = slice_indices
+
+        # Reshape for distribution and define the sharding rule
+        batched_indices = padded_slice_indices.reshape(
+            self.num_devices, slices_per_device
+        )
+        # Sharding rule: split the first axis (the one for devices) across the 'devices' mesh axis
+        self.sharding = NamedSharding(self.mesh, P("devices", None))
+        # Place the tensor on devices according to the rule
+        self.batched_slice_indices = jaxlib.device_put(batched_indices, self.sharding)
+
+        # self.batched_slice_indices = backend.convert_to_tensor(
+        #     padded_slice_indices.reshape(self.num_devices, slices_per_device)
+        # )
+        print(
+            f"Distributing across {self.num_devices} devices. Each device will sequentially process "
+            f"up to {slices_per_device} slices."
+        )
+
+        self._compiled_vg_fn = None
+        self._compiled_v_fn = None
+
+        logger.info("Initialization complete.")
+
+    def _report_tree_info(self) -> None:
+        print("\n--- Contraction Path Info ---")
+        actual_num_slices = self.tree.nslices
+        stats = self.tree.contract_stats()
+        print(f"Path found with {actual_num_slices} slices.")
+        print(
+            f"Arithmetic Intensity (higher is better): {self.tree.arithmetic_intensity():.2f}"
+        )
+        print("flops (TFlops):", stats["flops"] / 2**40 / self.num_devices)
+        print("write (GB):", stats["write"] / 2**27 / actual_num_slices)
+        print("size (GB):", stats["size"] / 2**27)
+        print("-----------------------------\n")
+
+    @staticmethod
+    def _get_tree_data(
+        nodes_fn: Callable[[Tensor], List[Gate]],
+        params: Tensor,
+        cotengra_options: Optional[Dict[str, Any]] = None,
+    ) -> Dict[str, Any]:
+        global ctg
+
+        import cotengra as ctg
+
+        local_cotengra_options = (cotengra_options or {}).copy()
+
+        nodes = nodes_fn(params)
+        tn_info, _ = get_tn_info(nodes)
+        default_cotengra_options = {
+            "slicing_reconf_opts": {"target_size": 2**28},
+            "max_repeats": 128,
+            "minimize": "write",
+            "parallel": "auto",
+            "progbar": True,
+        }
+        default_cotengra_options.update(local_cotengra_options)
+
+        opt = ctg.ReusableHyperOptimizer(**default_cotengra_options)
+        tree_object = opt.search(*tn_info)
+        tree_data = {
+            "inputs": tree_object.inputs,
+            "output": tree_object.output,
+            "size_dict": tree_object.size_dict,
+            "path": tree_object.get_path(),
+            "sliced_inds": tree_object.sliced_inds,
+        }
+        return tree_data
+
+    @staticmethod
+    def find_path(
+        nodes_fn: Callable[[Tensor], Tensor],
+        params: Tensor,
+        cotengra_options: Optional[Dict[str, Any]] = None,
+        filepath: Optional[str] = None,
+    ) -> None:
+        tree_data = DistributedContractor._get_tree_data(
+            nodes_fn, params, cotengra_options
+        )
+        if filepath is not None:
+            with open(filepath, "wb") as f:
+                pickle.dump(tree_data, f)
+            logger.info(f"Contraction path data successfully saved to '{filepath}'.")
+
+    @classmethod
+    def from_path(
+        cls,
+        filepath: str,
+        nodes_fn: Callable[[Tensor], List[Gate]],
+        devices: Optional[List[Any]] = None,  # backward compatibility
+        mesh: Optional[Any] = None,
+        params: Any = None,
+    ) -> "DistributedContractor":
+        with open(filepath, "rb") as f:
+            tree_data = pickle.load(f)
+
+        # Each process loads the file independently. No broadcast is needed.
+        # We pass the loaded `tree_data` directly to __init__ to trigger the second workflow.
+        return cls(
+            nodes_fn=nodes_fn,
+            params=params,
+            mesh=mesh,
+            devices=devices,
+            tree_data=tree_data,
+        )
+
+    def _get_single_slice_contraction_fn(
+        self, op: Optional[Callable[[Tensor], Tensor]] = None
+    ) -> Callable[[Any, Tensor, int], Tensor]:
+        if op is None:
+            op = backend.sum
+
+        def single_slice_contraction(
+            tree: ctg.ContractionTree, params: Tensor, slice_idx: int
+        ) -> Tensor:
+            nodes = self.nodes_fn(params)
+            _, standardized_nodes = get_tn_info(nodes)
+            input_arrays = [node.tensor for node in standardized_nodes]
+            sliced_arrays = tree.slice_arrays(input_arrays, slice_idx)
+            result = tree.contract_core(sliced_arrays, backend=self._backend)
+            return op(result)
+
+        return single_slice_contraction
+
+    def _get_device_sum_vg_fn(
+        self,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Callable[[Any, Tensor, Tensor], Tuple[Tensor, Tensor]]:
+        post_processing = lambda x: backend.real(backend.sum(x))
+        if op is None:
+            op = post_processing
+        base_fn = self._get_single_slice_contraction_fn(op=op)
+        # to ensure the output is real so that it can be differentiated
+        single_slice_vg_fn = jaxlib.value_and_grad(base_fn, argnums=1)
+
+        if output_dtype is None:
+            output_dtype = rdtypestr
+
+        def device_sum_fn(
+            tree: ctg.ContractionTree, params: Tensor, slice_indices_for_device: Tensor
+        ) -> Tuple[Tensor, Tensor]:
+            def scan_body(
+                carry: Tuple[Tensor, Tensor], slice_idx: Tensor
+            ) -> Tuple[Tuple[Tensor, Tensor], None]:
+                acc_value, acc_grads = carry
+
+                def compute_and_add() -> Tuple[Tensor, Tensor]:
+                    value_slice, grads_slice = single_slice_vg_fn(
+                        tree, params, slice_idx
+                    )
+                    new_value = acc_value + value_slice
+                    new_grads = jaxlib.tree_util.tree_map(
+                        jaxlib.numpy.add, acc_grads, grads_slice
+                    )
+                    return new_value, new_grads
+
+                def do_nothing() -> Tuple[Tensor, Tensor]:
+                    return acc_value, acc_grads
+
+                return (
+                    jaxlib.lax.cond(
+                        slice_idx == PADDING_VALUE, do_nothing, compute_and_add
+                    ),
+                    None,
+                )
+
+            initial_carry = (
+                backend.cast(backend.convert_to_tensor(0.0), dtype=output_dtype),
+                jaxlib.tree_util.tree_map(lambda x: jaxlib.numpy.zeros_like(x), params),
+            )
+            (final_value, final_grads), _ = jaxlib.lax.scan(
+                scan_body, initial_carry, slice_indices_for_device
+            )
+            return final_value, final_grads
+
+        return device_sum_fn
+
+    def _get_device_sum_v_fn(
+        self,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Callable[[Any, Tensor, Tensor], Tensor]:
+        base_fn = self._get_single_slice_contraction_fn(op=op)
+        if output_dtype is None:
+            output_dtype = dtypestr
+
+        def device_sum_fn(
+            tree: ctg.ContractionTree, params: Tensor, slice_indices_for_device: Tensor
+        ) -> Tensor:
+            def scan_body(
+                carry_value: Tensor, slice_idx: Tensor
+            ) -> Tuple[Tensor, None]:
+                def compute_and_add() -> Tensor:
+                    return carry_value + base_fn(tree, params, slice_idx)
+
+                return (
+                    jaxlib.lax.cond(
+                        slice_idx == PADDING_VALUE, lambda: carry_value, compute_and_add
+                    ),
+                    None,
+                )
+
+            initial_carry = backend.cast(
+                backend.convert_to_tensor(0.0), dtype=output_dtype
+            )
+            final_value, _ = jaxlib.lax.scan(
+                scan_body, initial_carry, slice_indices_for_device
+            )
+            return final_value
+
+        return device_sum_fn
+
+    def _get_or_compile_fn(
+        self,
+        cache: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ],
+        fn_getter: Callable[..., Any],
+        op: Optional[Callable[[Tensor], Tensor]],
+        output_dtype: Optional[str],
+        is_grad_fn: bool,
+    ) -> Callable[[Any, Tensor, Tensor], Tensor]:
+        """
+        Gets a compiled pmap-ed function from cache or compiles and caches it.
+
+        The cache key is a tuple of (op, output_dtype). Caution on lambda function!
+
+        Returns:
+            The compiled, pmap-ed JAX function.
+        """
+        cache_key = (op, output_dtype)
+        if cache_key not in cache:
+            device_fn = fn_getter(op=op, output_dtype=output_dtype)
+
+            def global_aggregated_fn(
+                tree: Any, params: Any, batched_slice_indices: Tensor
+            ) -> Any:
+                # Use jax.vmap to apply the per-device function across the sharded data.
+                # vmap maps `device_fn` over the first axis (0) of `batched_slice_indices`.
+                # `tree` and `params` are broadcasted (in_axes=None) to each call.
+                vmapped_device_fn = jaxlib.vmap(
+                    device_fn, in_axes=(None, None, 0), out_axes=0
+                )
+                device_results = vmapped_device_fn(tree, params, batched_slice_indices)
+
+                # Now, `device_results` is a sharded PyTree (one result per device).
+                # We aggregate them using jnp.sum, which JAX automatically compiles
+                # into a cross-device AllReduce operation.
+
+                if is_grad_fn:
+                    # `device_results` is a (value, grad) tuple of sharded arrays
+                    device_values, device_grads = device_results
+
+                    # Replace psum with jnp.sum
+                    global_value = jaxlib.numpy.sum(device_values, axis=0)
+                    global_grad = jaxlib.tree_util.tree_map(
+                        lambda g: jaxlib.numpy.sum(g, axis=0), device_grads
+                    )
+                    return global_value, global_grad
+                else:
+                    # `device_results` is just the sharded values
+                    return jaxlib.numpy.sum(device_results, axis=0)
+
+            # Compile the global function with jax.jit and specify shardings.
+            # `params` are replicated (available everywhere).
+
+            in_shardings = (self.params_sharding, self.sharding)
+
+            if is_grad_fn:
+                # Returns (value, grad), so out_sharding must be a 2-tuple.
+                # `value` is a replicated scalar -> P()
+                sharding_for_value = NamedSharding(self.mesh, P())
+                # `grad` is a replicated PyTree with the same structure as params.
+                sharding_for_grad = self.params_sharding
+                out_shardings = (sharding_for_value, sharding_for_grad)
+            else:
+                # Returns a single scalar value -> P()
+                out_shardings = NamedSharding(self.mesh, P())
+
+            compiled_fn = jaxlib.jit(
+                global_aggregated_fn,
+                # `tree` is a static argument, its value is compiled into the function.
+                static_argnums=(0,),
+                # Specify how inputs are sharded.
+                in_shardings=in_shardings,
+                # Specify how the output should be sharded.
+                out_shardings=out_shardings,
+            )
+            cache[cache_key] = compiled_fn  # type: ignore
+        return cache[cache_key]  # type: ignore
+
+    def value_and_grad(
+        self,
+        params: Tensor,
+        # aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Calculates the value and gradient, compiling the pmap function if needed for the first call.
+
+        :param params: Parameters for the `nodes_fn` input
+        :type params: Tensor
+        :param op: Optional post-processing function for the output, defaults to None (corresponding to `backend.real`)
+            op is a cache key, so don't directly pass a lambda function for op
+        :type op: Optional[Callable[[Tensor], Tensor]], optional
+        :param output_dtype: dtype str for the output of `nodes_fn`, defaults to None (corresponding to `rdtypestr`)
+        :type output_dtype: Optional[str], optional
+        """
+        compiled_vg_fn = self._get_or_compile_fn(
+            cache=self._compiled_vg_fns,
+            fn_getter=self._get_device_sum_vg_fn,
+            op=op,
+            output_dtype=output_dtype,
+            is_grad_fn=True,
+        )
+
+        total_value, total_grad = compiled_vg_fn(
+            self.tree, params, self.batched_slice_indices
+        )
+        return total_value, total_grad
+
+    def value(
+        self,
+        params: Tensor,
+        # aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tensor:
+        """
+        Calculates the value, compiling the pmap function for the first call.
+
+        :param params: Parameters for the `nodes_fn` input
+        :type params: Tensor
+        :param op: Optional post-processing function for the output, defaults to None (corresponding to identity)
+            op is a cache key, so don't directly pass a lambda function for op
+        :type op: Optional[Callable[[Tensor], Tensor]], optional
+        :param output_dtype: dtype str for the output of `nodes_fn`, defaults to None (corresponding to `dtypestr`)
+        :type output_dtype: Optional[str], optional
+        """
+        compiled_v_fn = self._get_or_compile_fn(
+            cache=self._compiled_v_fns,
+            fn_getter=self._get_device_sum_v_fn,
+            op=op,
+            output_dtype=output_dtype,
+            is_grad_fn=False,
+        )
+
+        total_value = compiled_v_fn(self.tree, params, self.batched_slice_indices)
+        return total_value
+
+    def grad(
+        self,
+        params: Tensor,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tensor:
+        _, grad = self.value_and_grad(params, op=op, output_dtype=output_dtype)
+        return grad