tensorcircuit-nightly 1.0.2.dev20250108__py3-none-any.whl → 1.4.0.dev20251103__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tensorcircuit-nightly might be problematic.

Files changed (76)
  1. tensorcircuit/__init__.py +18 -2
  2. tensorcircuit/about.py +46 -0
  3. tensorcircuit/abstractcircuit.py +4 -0
  4. tensorcircuit/analogcircuit.py +413 -0
  5. tensorcircuit/applications/layers.py +1 -1
  6. tensorcircuit/applications/van.py +1 -1
  7. tensorcircuit/backends/abstract_backend.py +320 -7
  8. tensorcircuit/backends/cupy_backend.py +3 -1
  9. tensorcircuit/backends/jax_backend.py +102 -4
  10. tensorcircuit/backends/jax_ops.py +110 -1
  11. tensorcircuit/backends/numpy_backend.py +49 -3
  12. tensorcircuit/backends/pytorch_backend.py +92 -3
  13. tensorcircuit/backends/tensorflow_backend.py +102 -3
  14. tensorcircuit/basecircuit.py +157 -98
  15. tensorcircuit/circuit.py +115 -57
  16. tensorcircuit/cloud/local.py +1 -1
  17. tensorcircuit/cloud/quafu_provider.py +1 -1
  18. tensorcircuit/cloud/tencent.py +1 -1
  19. tensorcircuit/compiler/simple_compiler.py +2 -2
  20. tensorcircuit/cons.py +142 -21
  21. tensorcircuit/densitymatrix.py +43 -14
  22. tensorcircuit/experimental.py +387 -129
  23. tensorcircuit/fgs.py +282 -81
  24. tensorcircuit/gates.py +66 -22
  25. tensorcircuit/interfaces/__init__.py +1 -3
  26. tensorcircuit/interfaces/jax.py +189 -0
  27. tensorcircuit/keras.py +3 -3
  28. tensorcircuit/mpscircuit.py +154 -65
  29. tensorcircuit/quantum.py +868 -152
  30. tensorcircuit/quditcircuit.py +733 -0
  31. tensorcircuit/quditgates.py +618 -0
  32. tensorcircuit/results/counts.py +147 -20
  33. tensorcircuit/results/readout_mitigation.py +4 -1
  34. tensorcircuit/shadows.py +1 -1
  35. tensorcircuit/simplify.py +3 -1
  36. tensorcircuit/stabilizercircuit.py +479 -0
  37. tensorcircuit/templates/__init__.py +2 -0
  38. tensorcircuit/templates/blocks.py +2 -2
  39. tensorcircuit/templates/hamiltonians.py +174 -0
  40. tensorcircuit/templates/lattice.py +1789 -0
  41. tensorcircuit/timeevol.py +896 -0
  42. tensorcircuit/translation.py +10 -3
  43. tensorcircuit/utils.py +7 -0
  44. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/METADATA +73 -23
  45. tensorcircuit_nightly-1.4.0.dev20251103.dist-info/RECORD +96 -0
  46. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/WHEEL +1 -1
  47. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info}/top_level.txt +0 -1
  48. tensorcircuit_nightly-1.0.2.dev20250108.dist-info/RECORD +0 -115
  49. tests/__init__.py +0 -0
  50. tests/conftest.py +0 -67
  51. tests/test_backends.py +0 -1031
  52. tests/test_calibrating.py +0 -149
  53. tests/test_channels.py +0 -365
  54. tests/test_circuit.py +0 -1699
  55. tests/test_cloud.py +0 -219
  56. tests/test_compiler.py +0 -147
  57. tests/test_dmcircuit.py +0 -555
  58. tests/test_ensemble.py +0 -72
  59. tests/test_fgs.py +0 -310
  60. tests/test_gates.py +0 -156
  61. tests/test_interfaces.py +0 -429
  62. tests/test_keras.py +0 -160
  63. tests/test_miscs.py +0 -277
  64. tests/test_mpscircuit.py +0 -341
  65. tests/test_noisemodel.py +0 -156
  66. tests/test_qaoa.py +0 -86
  67. tests/test_qem.py +0 -152
  68. tests/test_quantum.py +0 -526
  69. tests/test_quantum_attr.py +0 -42
  70. tests/test_results.py +0 -347
  71. tests/test_shadows.py +0 -160
  72. tests/test_simplify.py +0 -46
  73. tests/test_templates.py +0 -218
  74. tests/test_torchnn.py +0 -99
  75. tests/test_van.py +0 -102
  76. {tensorcircuit_nightly-1.0.2.dev20250108.dist-info → tensorcircuit_nightly-1.4.0.dev20251103.dist-info/licenses}/LICENSE +0 -0
tensorcircuit/experimental.py

@@ -2,17 +2,26 @@
 Experimental features
 """

+# pylint: disable=unused-import
+
 from functools import partial
-from typing import Any, Callable, Optional, Tuple, Sequence, Union
+import logging
+from typing import Any, Callable, Dict, Optional, Tuple, List, Sequence, Union

 import numpy as np

-from .cons import backend, dtypestr, contractor, rdtypestr
+from .cons import backend, dtypestr, rdtypestr, get_tn_info
 from .gates import Gate
+from .timeevol import hamiltonian_evol, evol_global, evol_local
+
+
+# for backward compatibility

 Tensor = Any
 Circuit = Any

+logger = logging.getLogger(__name__)
+

 def adaptive_vmap(
     f: Callable[..., Any],
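The hunk above relocates the time-evolution helpers to the new tensorcircuit/timeevol.py module (item 41 in the file list) and re-imports them here, so code that reached them through tc.experimental keeps working. A minimal sketch of the two equivalent import paths, assuming the re-exports are plain aliases as the diff shows:

    # old path, still valid via the backward-compatibility re-import
    from tensorcircuit.experimental import hamiltonian_evol, evol_global, evol_local
    # new canonical path introduced by this release
    from tensorcircuit.timeevol import hamiltonian_evol as hamiltonian_evol_new

    assert hamiltonian_evol is hamiltonian_evol_new  # same object, just re-exported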
@@ -20,6 +29,16 @@ def adaptive_vmap(
     static_argnums: Optional[Union[int, Sequence[int]]] = None,
     chunk_size: Optional[int] = None,
 ) -> Callable[..., Any]:
+    """
+    Vectorized map with adaptive chunking for memory efficiency.
+
+    :param f: Function to be vectorized
+    :param vectorized_argnums: Arguments to be vectorized over
+    :param static_argnums: Arguments that remain static during vectorization
+    :param chunk_size: Size of chunks for batch processing, None means no chunking
+        (naive vmap)
+    :return: Vectorized function
+    """
     if chunk_size is None:
         return backend.vmap(f, vectorized_argnums)  # type: ignore

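The new adaptive_vmap docstring describes chunked batching but stops short of a usage example. A short sketch under the jax backend; the batch of 1000 inputs and chunk_size=100 are arbitrary illustrative choices, and vectorized_argnums=0 is assumed as the vectorized axis argument:

    import tensorcircuit as tc
    from tensorcircuit.experimental import adaptive_vmap

    K = tc.set_backend("jax")

    def single_run(x):
        # scalar function of one input row
        return K.sum(x**2)

    # vmap over axis 0 of the first argument, evaluated 100 rows at a time
    batched_run = adaptive_vmap(single_run, vectorized_argnums=0, chunk_size=100)
    xs = K.ones([1000, 4])
    ys = batched_run(xs)  # shape [1000]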
@@ -94,6 +113,29 @@ def qng(
     postprocess: Optional[str] = "qng",
     mode: str = "fwd",
 ) -> Callable[..., Tensor]:
+    """
+    Compute quantum natural gradient for quantum circuit optimization.
+
+    :param f: Function that takes parameters and returns quantum state
+    :param kernel: Type of kernel to use ("qng" or "dynamics"), the former has the second term
+    :param postprocess: Post-processing method ("qng" or None)
+    :param mode: Mode of differentiation ("fwd" or "rev")
+    :return: Function computing QNG matrix
+
+    :Example:
+
+    >>> import tensorcircuit as tc
+    >>> def ansatz(params):
+    ...     c = tc.Circuit(2)
+    ...     c.rx(0, theta=params[0])
+    ...     c.ry(1, theta=params[1])
+    ...     return c.state()
+    >>> qng_fn = tc.experimental.qng(ansatz)
+    >>> params = tc.array_to_tensor([0.5, 0.8])
+    >>> qng_matrix = qng_fn(params)
+    >>> print(qng_matrix.shape)  # (2, 2)
+    """
+
     # for both qng and qng2 calculation, we highly recommended complex-dtype but real valued inputs
     def wrapper(params: Tensor, **kws: Any) -> Tensor:
         params = backend.cast(params, dtype=dtypestr)  # R->C protection
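The qng doctest above ends at the metric itself. For context, the returned matrix is typically used as the metric in a natural-gradient update, roughly theta <- theta - eta * pinv(F) @ grad(L). A hedged sketch of such a loop; the energy function, step size, and pseudo-inverse regularization below are illustrative choices, not part of the diff:

    import jax.numpy as jnp
    import tensorcircuit as tc

    K = tc.set_backend("jax")

    def ansatz(params):
        c = tc.Circuit(2)
        c.rx(0, theta=params[0])
        c.ry(1, theta=params[1])
        return c.state()

    def energy(params):
        c = tc.Circuit(2, inputs=ansatz(params))
        return K.real(c.expectation_ps(z=[0]))

    qng_fn = tc.experimental.qng(ansatz)
    grad_fn = K.grad(energy)

    params = jnp.array([0.5, 0.8])
    for _ in range(10):
        fim = qng_fn(params)  # the QNG metric documented above
        g = grad_fn(params)
        # natural-gradient step; pinv guards against a near-singular metric
        params = params - 0.1 * jnp.real(jnp.linalg.pinv(fim) @ g)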
@@ -399,130 +441,6 @@ def finite_difference_differentiator(
     return tf_function  # type: ignore


-def hamiltonian_evol(
-    tlist: Tensor,
-    h: Tensor,
-    psi0: Tensor,
-    callback: Optional[Callable[..., Any]] = None,
-) -> Tensor:
-    """
-    Fast implementation of static full Hamiltonian evolution
-    (default as imaginary time)
-
-    :param tlist: _description_
-    :type tlist: Tensor
-    :param h: _description_
-    :type h: Tensor
-    :param psi0: _description_
-    :type psi0: Tensor
-    :param callback: _description_, defaults to None
-    :type callback: Optional[Callable[..., Any]], optional
-    :return: Tensor
-    :rtype: result dynamics on ``tlist``
-    """
-    es, u = backend.eigh(h)
-    utpsi0 = backend.reshape(
-        backend.transpose(u) @ backend.reshape(psi0, [-1, 1]), [-1]
-    )
-
-    @backend.jit
-    def _evol(t: Tensor) -> Tensor:
-        ebetah_utpsi0 = backend.exp(-t * es) * utpsi0
-        psi_exact = backend.conj(u) @ backend.reshape(ebetah_utpsi0, [-1, 1])
-        psi_exact = backend.reshape(psi_exact, [-1])
-        psi_exact = psi_exact / backend.norm(psi_exact)
-        if callback is None:
-            return psi_exact
-        return callback(psi_exact)
-
-    return backend.stack([_evol(t) for t in tlist])
-
-
-def evol_local(
-    c: Circuit,
-    index: Sequence[int],
-    h_fun: Callable[..., Tensor],
-    t: float,
-    *args: Any,
-    **solver_kws: Any
-) -> Circuit:
-    """
-    ode evolution of time dependent Hamiltonian on circuit of given indices
-    [only jax backend support for now]
-
-    :param c: _description_
-    :type c: Circuit
-    :param index: qubit sites to evolve
-    :type index: Sequence[int]
-    :param h_fun: h_fun should return a dense Hamiltonian matrix
-        with input arguments time and *args
-    :type h_fun: Callable[..., Tensor]
-    :param t: evolution time
-    :type t: float
-    :return: _description_
-    :rtype: Circuit
-    """
-    from jax.experimental.ode import odeint
-
-    s = c.state()
-    n = c._nqubits
-    l = len(index)
-
-    def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
-        y = backend.reshape2(y)
-        y = Gate(y)
-        h = -1.0j * h_fun(t, *args)
-        h = backend.reshape2(h)
-        h = Gate(h)
-        edges = []
-        for i in range(n):
-            if i not in index:
-                edges.append(y[i])
-            else:
-                j = index.index(i)
-                edges.append(h[j])
-                h[j + l] ^ y[i]
-        y = contractor([y, h], output_edge_order=edges)
-        return backend.reshape(y.tensor, [-1])
-
-    ts = backend.stack([0.0, t])
-    ts = backend.cast(ts, dtype=rdtypestr)
-    s1 = odeint(f, s, ts, *args, **solver_kws)
-    return type(c)(n, inputs=s1[-1])
-
-
-def evol_global(
-    c: Circuit, h_fun: Callable[..., Tensor], t: float, *args: Any, **solver_kws: Any
-) -> Circuit:
-    """
-    ode evolution of time dependent Hamiltonian on circuit of all qubits
-    [only jax backend support for now]
-
-    :param c: _description_
-    :type c: Circuit
-    :param h_fun: h_fun should return a **SPARSE** Hamiltonian matrix
-        with input arguments time and *args
-    :type h_fun: Callable[..., Tensor]
-    :param t: _description_
-    :type t: float
-    :return: _description_
-    :rtype: Circuit
-    """
-    from jax.experimental.ode import odeint
-
-    s = c.state()
-    n = c._nqubits
-
-    def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
-        h = -1.0j * h_fun(t, *args)
-        return backend.sparse_dense_matmul(h, y)
-
-    ts = backend.stack([0.0, t])
-    ts = backend.cast(ts, dtype=rdtypestr)
-    s1 = odeint(f, s, ts, *args, **solver_kws)
-    return type(c)(n, inputs=s1[-1])
-
-
 def jax_jitted_function_save(filename: str, f: Callable[..., Any], *args: Any) -> None:
     """
     save a jitted jax function as a file
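hamiltonian_evol, evol_local, and evol_global are removed here because they now live in tensorcircuit/timeevol.py (and remain importable from tensorcircuit.experimental, per the first hunk). As a reminder of the unchanged call signature, a small sketch of hamiltonian_evol, which performs imaginary-time evolution by default and renormalizes the state at each time point; the two-qubit Hamiltonian and time grid below are illustrative:

    import numpy as np
    import tensorcircuit as tc
    from tensorcircuit.timeevol import hamiltonian_evol  # new home of the removed function

    K = tc.set_backend("jax")
    tc.set_dtype("complex128")

    # dense two-qubit Hamiltonian H = Z Z - 0.5 X I
    z = np.array([[1.0, 0.0], [0.0, -1.0]])
    x = np.array([[0.0, 1.0], [1.0, 0.0]])
    h = tc.array_to_tensor(np.kron(z, z) - 0.5 * np.kron(x, np.eye(2)))

    psi0 = tc.array_to_tensor(np.array([1.0, 0.0, 0.0, 0.0]))
    tlist = K.convert_to_tensor(np.linspace(0.0, 2.0, 11))

    # renormalized imaginary-time evolution, sweeping toward low-energy components
    states = hamiltonian_evol(tlist, h, psi0)
    print(states.shape)  # (11, 4)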
@@ -534,7 +452,7 @@ def jax_jitted_function_save(filename: str, f: Callable[..., Any], *args: Any) -
     :param args: example function arguments for ``f``
     """

-    from jax import export
+    from jax import export  # type: ignore

     f_export = export.export(f)(*args)  # type: ignore
     barray = f_export.serialize()
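The `# type: ignore` added above touches the jax.export import used by jax_jitted_function_save (its counterpart jax_jitted_function_load follows in the next hunk). A round-trip sketch; the file path and the toy jitted function are illustrative, and the loaded function is only valid for the argument shapes and dtypes it was exported with:

    import jax
    import jax.numpy as jnp
    from tensorcircuit.experimental import (
        jax_jitted_function_save,
        jax_jitted_function_load,
    )

    @jax.jit
    def f(x):
        return jnp.sum(x**2)

    x = jnp.ones([8])
    jax_jitted_function_save("/tmp/f_exported.bin", f, x)  # serialize for this shape/dtype

    f_loaded = jax_jitted_function_load("/tmp/f_exported.bin")
    print(f_loaded(x))  # matches f(x)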
@@ -555,14 +473,354 @@ def jax_jitted_function_load(filename: str) -> Callable[..., Any]:
     :return: the loaded function
     :rtype: _type_
     """
-    from jax import export
+    from jax import export  # type: ignore

     with open(filename, "rb") as f:
         barray = f.read()

     f_load = export.deserialize(barray)  # type: ignore

-    return f_load.call
+    return f_load.call  # type: ignore


 jax_func_load = jax_jitted_function_load
+
+
+PADDING_VALUE = -1
+jaxlib: Any
+ctg: Any
+
+
+class DistributedContractor:
+    """
+    A distributed tensor network contractor that parallelizes computations across multiple devices.
+
+    This class uses cotengra to find optimal contraction paths and distributes the computational
+    load across multiple devices (e.g., GPUs) for efficient tensor network calculations.
+    Particularly useful for large-scale quantum circuit simulations and variational quantum algorithms.
+
+    Example:
+        >>> def nodes_fn(params):
+        ...     c = tc.Circuit(4)
+        ...     c.rx(0, theta=params[0])
+        ...     return c.expectation_before([tc.gates.z(), [0]], reuse=False)
+        >>> dc = DistributedContractor(nodes_fn, params)
+        >>> value, grad = dc.value_and_grad(params)
+
+    :param nodes_fn: Function that takes parameters and returns a list of tensor network nodes
+    :type nodes_fn: Callable[[Tensor], List[Gate]]
+    :param params: Initial parameters used to determine the tensor network structure
+    :type params: Tensor
+    :param cotengra_options: Configuration options passed to the cotengra optimizer. Defaults to None
+    :type cotengra_options: Optional[Dict[str, Any]], optional
+    :param devices: List of devices to use. If None, uses all available local devices
+    :type devices: Optional[List[Any]], optional
+    """
+
+    def __init__(
+        self,
+        nodes_fn: Callable[[Tensor], List[Gate]],
+        params: Tensor,
+        cotengra_options: Optional[Dict[str, Any]] = None,
+        devices: Optional[List[Any]] = None,
+    ) -> None:
+        global jaxlib
+        global ctg
+
+        logger.info("Initializing DistributedContractor...")
+        import cotengra as ctg
+        import jax as jaxlib
+
+        self.nodes_fn = nodes_fn
+        if devices is None:
+            self.num_devices = jaxlib.local_device_count()
+            self.devices = jaxlib.local_devices()
+            # TODO(@refraction-ray): multi host support
+        else:
+            self.devices = devices
+            self.num_devices = len(devices)
+
+        if self.num_devices <= 1:
+            logger.info("DistributedContractor is running on a single device.")
+
+        self._params_template = params
+        self._backend = "jax"
+        self._compiled_v_fns: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ] = {}
+        self._compiled_vg_fns: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ] = {}
+
+        logger.info("Running cotengra pathfinder... (This may take a while)")
+        nodes = self.nodes_fn(self._params_template)
+        tn_info, _ = get_tn_info(nodes)
+        default_cotengra_options = {
+            "slicing_reconf_opts": {"target_size": 2**28},
+            "max_repeats": 128,
+            "progbar": True,
+            "minimize": "write",
+            "parallel": "auto",
+        }
+        if cotengra_options:
+            default_cotengra_options = cotengra_options
+
+        opt = ctg.ReusableHyperOptimizer(**default_cotengra_options)
+        self.tree = opt.search(*tn_info)
+        actual_num_slices = self.tree.nslices
+
+        print("\n--- Contraction Path Info ---")
+        stats = self.tree.contract_stats()
+        print(f"Path found with {actual_num_slices} slices.")
+        print(
+            f"Arithmetic Intensity (higher is better): {self.tree.arithmetic_intensity():.2f}"
+        )
+        print("flops (TFlops):", stats["flops"] / 2**40 / self.num_devices)
+        print("write (GB):", stats["write"] / 2**27 / actual_num_slices)
+        print("size (GB):", stats["size"] / 2**27)
+        print("-----------------------------\n")
+
+        slices_per_device = int(np.ceil(actual_num_slices / self.num_devices))
+        padded_size = slices_per_device * self.num_devices
+        slice_indices = np.arange(actual_num_slices)
+        padded_slice_indices = np.full(padded_size, PADDING_VALUE, dtype=np.int32)
+        padded_slice_indices[:actual_num_slices] = slice_indices
+        self.batched_slice_indices = backend.convert_to_tensor(
+            padded_slice_indices.reshape(self.num_devices, slices_per_device)
+        )
+        print(
+            f"Distributing across {self.num_devices} devices. Each device will sequentially process "
+            f"up to {slices_per_device} slices."
+        )
+
+        self._compiled_vg_fn = None
+        self._compiled_v_fn = None
+
+        logger.info("Initialization complete.")
+
+    def _get_single_slice_contraction_fn(
+        self, op: Optional[Callable[[Tensor], Tensor]] = None
+    ) -> Callable[[Any, Tensor, int], Tensor]:
+        if op is None:
+            op = backend.sum
+
+        def single_slice_contraction(
+            tree: ctg.ContractionTree, params: Tensor, slice_idx: int
+        ) -> Tensor:
+            nodes = self.nodes_fn(params)
+            _, standardized_nodes = get_tn_info(nodes)
+            input_arrays = [node.tensor for node in standardized_nodes]
+            sliced_arrays = tree.slice_arrays(input_arrays, slice_idx)
+            result = tree.contract_core(sliced_arrays, backend=self._backend)
+            return op(result)
+
+        return single_slice_contraction
+
+    def _get_device_sum_vg_fn(
+        self,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Callable[[Any, Tensor, Tensor], Tuple[Tensor, Tensor]]:
+        post_processing = lambda x: backend.real(backend.sum(x))
+        if op is None:
+            op = post_processing
+        base_fn = self._get_single_slice_contraction_fn(op=op)
+        # to ensure the output is real so that can be differentiated
+        single_slice_vg_fn = jaxlib.value_and_grad(base_fn, argnums=1)
+
+        if output_dtype is None:
+            output_dtype = rdtypestr
+
+        def device_sum_fn(
+            tree: ctg.ContractionTree, params: Tensor, slice_indices_for_device: Tensor
+        ) -> Tuple[Tensor, Tensor]:
+            def scan_body(
+                carry: Tuple[Tensor, Tensor], slice_idx: Tensor
+            ) -> Tuple[Tuple[Tensor, Tensor], None]:
+                acc_value, acc_grads = carry
+
+                def compute_and_add() -> Tuple[Tensor, Tensor]:
+                    value_slice, grads_slice = single_slice_vg_fn(
+                        tree, params, slice_idx
+                    )
+                    new_value = acc_value + value_slice
+                    new_grads = jaxlib.tree_util.tree_map(
+                        jaxlib.numpy.add, acc_grads, grads_slice
+                    )
+                    return new_value, new_grads
+
+                def do_nothing() -> Tuple[Tensor, Tensor]:
+                    return acc_value, acc_grads
+
+                return (
+                    jaxlib.lax.cond(
+                        slice_idx == PADDING_VALUE, do_nothing, compute_and_add
+                    ),
+                    None,
+                )
+
+            initial_carry = (
+                backend.cast(backend.convert_to_tensor(0.0), dtype=output_dtype),
+                jaxlib.tree_util.tree_map(lambda x: jaxlib.numpy.zeros_like(x), params),
+            )
+            (final_value, final_grads), _ = jaxlib.lax.scan(
+                scan_body, initial_carry, slice_indices_for_device
+            )
+            return final_value, final_grads
+
+        return device_sum_fn
+
+    def _get_device_sum_v_fn(
+        self,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Callable[[Any, Tensor, Tensor], Tensor]:
+        base_fn = self._get_single_slice_contraction_fn(op=op)
+        if output_dtype is None:
+            output_dtype = dtypestr
+
+        def device_sum_fn(
+            tree: ctg.ContractionTree, params: Tensor, slice_indices_for_device: Tensor
+        ) -> Tensor:
+            def scan_body(
+                carry_value: Tensor, slice_idx: Tensor
+            ) -> Tuple[Tensor, None]:
+                def compute_and_add() -> Tensor:
+                    return carry_value + base_fn(tree, params, slice_idx)
+
+                return (
+                    jaxlib.lax.cond(
+                        slice_idx == PADDING_VALUE, lambda: carry_value, compute_and_add
+                    ),
+                    None,
+                )
+
+            initial_carry = backend.cast(
+                backend.convert_to_tensor(0.0), dtype=output_dtype
+            )
+            final_value, _ = jaxlib.lax.scan(
+                scan_body, initial_carry, slice_indices_for_device
+            )
+            return final_value
+
+        return device_sum_fn
+
+    def _get_or_compile_fn(
+        self,
+        cache: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ],
+        fn_getter: Callable[..., Any],
+        op: Optional[Callable[[Tensor], Tensor]],
+        output_dtype: Optional[str],
+    ) -> Callable[[Any, Tensor, Tensor], Tensor]:
+        """
+        Gets a compiled pmap-ed function from cache or compiles and caches it.
+
+        The cache key is a tuple of (op, output_dtype). Caution on lambda function!
+
+        Returns:
+            The compiled, pmap-ed JAX function.
+        """
+        cache_key = (op, output_dtype)
+        if cache_key not in cache:
+            device_fn = fn_getter(op=op, output_dtype=output_dtype)
+            compiled_fn = jaxlib.pmap(
+                device_fn,
+                in_axes=(
+                    None,
+                    None,
+                    0,
+                ),  # tree: broadcast, params: broadcast, indices: map
+                static_broadcasted_argnums=(0,),  # arg 0 (tree) is a static argument
+                devices=self.devices,
+            )
+            cache[cache_key] = compiled_fn  # type: ignore
+        return cache[cache_key]  # type: ignore
+
+    def value_and_grad(
+        self,
+        params: Tensor,
+        aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Calculates the value and gradient, compiling the pmap function if needed for the first call.
+
+        :param params: Parameters for the `nodes_fn` input
+        :type params: Tensor
+        :param aggregate: Whether to aggregate (sum) the results across devices, defaults to True
+        :type aggregate: bool, optional
+        :param op: Optional post-processing function for the output, defaults to None (corresponding to `backend.real`)
+            op is a cache key, so dont directly pass lambda function for op
+        :type op: Optional[Callable[[Tensor], Tensor]], optional
+        :param output_dtype: dtype str for the output of `nodes_fn`, defaults to None (corresponding to `rdtypestr`)
+        :type output_dtype: Optional[str], optional
+        """
+        compiled_vg_fn = self._get_or_compile_fn(
+            cache=self._compiled_vg_fns,
+            fn_getter=self._get_device_sum_vg_fn,
+            op=op,
+            output_dtype=output_dtype,
+        )
+
+        device_values, device_grads = compiled_vg_fn(
+            self.tree, params, self.batched_slice_indices
+        )
+
+        if aggregate:
+            total_value = backend.sum(device_values)
+            total_grad = jaxlib.tree_util.tree_map(
+                lambda x: backend.sum(x, axis=0), device_grads
+            )
+            return total_value, total_grad
+        return device_values, device_grads
+
+    def value(
+        self,
+        params: Tensor,
+        aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tensor:
+        """
+        Calculates the value, compiling the pmap function for the first call.
+
+        :param params: Parameters for the `nodes_fn` input
+        :type params: Tensor
+        :param aggregate: Whether to aggregate (sum) the results across devices, defaults to True
+        :type aggregate: bool, optional
+        :param op: Optional post-processing function for the output, defaults to None (corresponding to identity)
+            op is a cache key, so dont directly pass lambda function for op
+        :type op: Optional[Callable[[Tensor], Tensor]], optional
+        :param output_dtype: dtype str for the output of `nodes_fn`, defaults to None (corresponding to `dtypestr`)
+        :type output_dtype: Optional[str], optional
+        """
+        compiled_v_fn = self._get_or_compile_fn(
+            cache=self._compiled_v_fns,
+            fn_getter=self._get_device_sum_v_fn,
+            op=op,
+            output_dtype=output_dtype,
+        )
+
+        device_values = compiled_v_fn(self.tree, params, self.batched_slice_indices)
+
+        if aggregate:
+            return backend.sum(device_values)
+        return device_values
+
+    def grad(
+        self,
+        params: Tensor,
+        aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tensor:
+        _, grad = self.value_and_grad(
+            params, aggregate=aggregate, op=op, output_dtype=output_dtype
+        )
+        return grad
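To complement the class docstring, an end-to-end sketch of the new DistributedContractor on a small <Z0> network. It assumes cotengra is installed and the jax backend is active; the 6-qubit circuit, the reduced cotengra_options, and the named identity helper (used instead of a lambda because op and output_dtype form the compilation cache key, as the docstrings above warn) are illustrative choices:

    import tensorcircuit as tc
    from tensorcircuit.experimental import DistributedContractor

    K = tc.set_backend("jax")
    n = 6

    def nodes_fn(params):
        c = tc.Circuit(n)
        for i in range(n):
            c.rx(i, theta=params[i])
        for i in range(n - 1):
            c.cnot(i, i + 1)
        # uncontracted tensor network for <Z0>
        return c.expectation_before([tc.gates.z(), [0]], reuse=False)

    params = K.cast(K.ones([n]), "float32")
    dc = DistributedContractor(
        nodes_fn,
        params,
        cotengra_options={
            "slicing_reconf_opts": {"target_size": 2**20},  # small target for a toy network
            "max_repeats": 32,
            "progbar": False,
            "minimize": "write",
            "parallel": "auto",
        },
    )

    def identity(x):
        return x  # named function: `op` is part of the cache key, so avoid lambdas

    ev = dc.value(params, op=identity)     # complex <Z0>, summed over slices/devices
    val, grad = dc.value_and_grad(params)  # real value and per-parameter gradient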