tensorcircuit-nightly 1.2.0.dev20250326__py3-none-any.whl → 1.4.0.dev20251128__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tensorcircuit-nightly might be problematic.

Files changed (77)
  1. tensorcircuit/__init__.py +5 -1
  2. tensorcircuit/abstractcircuit.py +4 -0
  3. tensorcircuit/analogcircuit.py +413 -0
  4. tensorcircuit/applications/layers.py +1 -1
  5. tensorcircuit/applications/van.py +1 -1
  6. tensorcircuit/backends/abstract_backend.py +312 -5
  7. tensorcircuit/backends/cupy_backend.py +3 -1
  8. tensorcircuit/backends/jax_backend.py +100 -4
  9. tensorcircuit/backends/jax_ops.py +108 -0
  10. tensorcircuit/backends/numpy_backend.py +49 -3
  11. tensorcircuit/backends/pytorch_backend.py +92 -3
  12. tensorcircuit/backends/tensorflow_backend.py +102 -3
  13. tensorcircuit/basecircuit.py +157 -98
  14. tensorcircuit/circuit.py +115 -57
  15. tensorcircuit/cloud/local.py +1 -1
  16. tensorcircuit/cloud/quafu_provider.py +1 -1
  17. tensorcircuit/cloud/tencent.py +1 -1
  18. tensorcircuit/compiler/simple_compiler.py +2 -2
  19. tensorcircuit/cons.py +105 -23
  20. tensorcircuit/densitymatrix.py +16 -11
  21. tensorcircuit/experimental.py +733 -153
  22. tensorcircuit/fgs.py +254 -73
  23. tensorcircuit/gates.py +66 -22
  24. tensorcircuit/interfaces/jax.py +5 -3
  25. tensorcircuit/interfaces/tensortrans.py +6 -2
  26. tensorcircuit/interfaces/torch.py +14 -4
  27. tensorcircuit/keras.py +3 -3
  28. tensorcircuit/mpscircuit.py +154 -65
  29. tensorcircuit/quantum.py +698 -134
  30. tensorcircuit/quditcircuit.py +733 -0
  31. tensorcircuit/quditgates.py +618 -0
  32. tensorcircuit/results/counts.py +131 -18
  33. tensorcircuit/results/readout_mitigation.py +4 -1
  34. tensorcircuit/shadows.py +1 -1
  35. tensorcircuit/simplify.py +3 -1
  36. tensorcircuit/stabilizercircuit.py +29 -17
  37. tensorcircuit/templates/__init__.py +2 -0
  38. tensorcircuit/templates/blocks.py +2 -2
  39. tensorcircuit/templates/hamiltonians.py +174 -0
  40. tensorcircuit/templates/lattice.py +1789 -0
  41. tensorcircuit/timeevol.py +896 -0
  42. tensorcircuit/translation.py +10 -3
  43. tensorcircuit/utils.py +7 -0
  44. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/METADATA +66 -29
  45. tensorcircuit_nightly-1.4.0.dev20251128.dist-info/RECORD +96 -0
  46. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/WHEEL +1 -1
  47. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/top_level.txt +0 -1
  48. tensorcircuit_nightly-1.2.0.dev20250326.dist-info/RECORD +0 -118
  49. tests/__init__.py +0 -0
  50. tests/conftest.py +0 -67
  51. tests/test_backends.py +0 -1035
  52. tests/test_calibrating.py +0 -149
  53. tests/test_channels.py +0 -409
  54. tests/test_circuit.py +0 -1699
  55. tests/test_cloud.py +0 -219
  56. tests/test_compiler.py +0 -147
  57. tests/test_dmcircuit.py +0 -555
  58. tests/test_ensemble.py +0 -72
  59. tests/test_fgs.py +0 -310
  60. tests/test_gates.py +0 -156
  61. tests/test_interfaces.py +0 -562
  62. tests/test_keras.py +0 -160
  63. tests/test_miscs.py +0 -282
  64. tests/test_mpscircuit.py +0 -341
  65. tests/test_noisemodel.py +0 -156
  66. tests/test_qaoa.py +0 -86
  67. tests/test_qem.py +0 -152
  68. tests/test_quantum.py +0 -549
  69. tests/test_quantum_attr.py +0 -42
  70. tests/test_results.py +0 -380
  71. tests/test_shadows.py +0 -160
  72. tests/test_simplify.py +0 -46
  73. tests/test_stabilizer.py +0 -217
  74. tests/test_templates.py +0 -218
  75. tests/test_torchnn.py +0 -99
  76. tests/test_van.py +0 -102
  77. {tensorcircuit_nightly-1.2.0.dev20250326.dist-info → tensorcircuit_nightly-1.4.0.dev20251128.dist-info}/licenses/LICENSE +0 -0
@@ -2,17 +2,30 @@
 Experimental features
 """
 
+# pylint: disable=unused-import
+
 from functools import partial
-from typing import Any, Callable, Optional, Tuple, Sequence, Union
+import logging
+from typing import Any, Callable, Dict, Optional, Tuple, List, Sequence, Union
+import pickle
+import uuid
+import time
+import os
 
 import numpy as np
 
-from .cons import backend, dtypestr, contractor, rdtypestr
+from .cons import backend, dtypestr, rdtypestr, get_tn_info
 from .gates import Gate
+from .timeevol import hamiltonian_evol, evol_global, evol_local
+
+
+# for backward compatibility
 
 Tensor = Any
 Circuit = Any
 
+logger = logging.getLogger(__name__)
+
 
 def adaptive_vmap(
     f: Callable[..., Any],
@@ -432,157 +445,6 @@ def finite_difference_differentiator(
     return tf_function  # type: ignore
 
 
-def hamiltonian_evol(
-    tlist: Tensor,
-    h: Tensor,
-    psi0: Tensor,
-    callback: Optional[Callable[..., Any]] = None,
-) -> Tensor:
-    """
-    Fast implementation of time independent Hamiltonian evolution using eigendecomposition.
-    By default, performs imaginary time evolution.
-
-    :param tlist: Time points for evolution
-    :type tlist: Tensor
-    :param h: Time-independent Hamiltonian matrix
-    :type h: Tensor
-    :param psi0: Initial state vector
-    :type psi0: Tensor
-    :param callback: Optional function to process state at each time point
-    :type callback: Optional[Callable[..., Any]], optional
-    :return: Evolution results at each time point. If callback is None, returns state vectors;
-        otherwise returns callback results
-    :rtype: Tensor
-
-    :Example:
-
-    >>> import tensorcircuit as tc
-    >>> import numpy as np
-    >>> # Define a simple 2-qubit Hamiltonian
-    >>> h = tc.array_to_tensor([
-    ...     [1.0, 0.0, 0.0, 0.0],
-    ...     [0.0, -1.0, 2.0, 0.0],
-    ...     [0.0, 2.0, -1.0, 0.0],
-    ...     [0.0, 0.0, 0.0, 1.0]
-    ... ])
-    >>> # Initial state |00⟩
-    >>> psi0 = tc.array_to_tensor([1.0, 0.0, 0.0, 0.0])
-    >>> # Evolution times
-    >>> times = tc.array_to_tensor([0.0, 0.5, 1.0])
-    >>> # Evolve and get states
-    >>> states = tc.experimental.hamiltonian_evol(times, h, psi0)
-    >>> print(states.shape)  # (3, 4)
-
-
-    Note:
-        1. The Hamiltonian must be time-independent
-        2. For time-dependent Hamiltonians, use ``evol_local`` or ``evol_global`` instead
-        3. The evolution is performed in imaginary time by default (factor -t in exponential)
-        4. The state is automatically normalized at each time point
-    """
-    es, u = backend.eigh(h)
-    utpsi0 = backend.reshape(
-        backend.transpose(u) @ backend.reshape(psi0, [-1, 1]), [-1]
-    )
-
-    @backend.jit
-    def _evol(t: Tensor) -> Tensor:
-        ebetah_utpsi0 = backend.exp(-t * es) * utpsi0
-        psi_exact = backend.conj(u) @ backend.reshape(ebetah_utpsi0, [-1, 1])
-        psi_exact = backend.reshape(psi_exact, [-1])
-        psi_exact = psi_exact / backend.norm(psi_exact)
-        if callback is None:
-            return psi_exact
-        return callback(psi_exact)
-
-    return backend.stack([_evol(t) for t in tlist])
-
-
-def evol_local(
-    c: Circuit,
-    index: Sequence[int],
-    h_fun: Callable[..., Tensor],
-    t: float,
-    *args: Any,
-    **solver_kws: Any
-) -> Circuit:
-    """
-    ode evolution of time dependent Hamiltonian on circuit of given indices
-    [only jax backend support for now]
-
-    :param c: _description_
-    :type c: Circuit
-    :param index: qubit sites to evolve
-    :type index: Sequence[int]
-    :param h_fun: h_fun should return a dense Hamiltonian matrix
-        with input arguments time and *args
-    :type h_fun: Callable[..., Tensor]
-    :param t: evolution time
-    :type t: float
-    :return: _description_
-    :rtype: Circuit
-    """
-    from jax.experimental.ode import odeint
-
-    s = c.state()
-    n = c._nqubits
-    l = len(index)
-
-    def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
-        y = backend.reshape2(y)
-        y = Gate(y)
-        h = -1.0j * h_fun(t, *args)
-        h = backend.reshape2(h)
-        h = Gate(h)
-        edges = []
-        for i in range(n):
-            if i not in index:
-                edges.append(y[i])
-            else:
-                j = index.index(i)
-                edges.append(h[j])
-                h[j + l] ^ y[i]
-        y = contractor([y, h], output_edge_order=edges)
-        return backend.reshape(y.tensor, [-1])
-
-    ts = backend.stack([0.0, t])
-    ts = backend.cast(ts, dtype=rdtypestr)
-    s1 = odeint(f, s, ts, *args, **solver_kws)
-    return type(c)(n, inputs=s1[-1])
-
-
-def evol_global(
-    c: Circuit, h_fun: Callable[..., Tensor], t: float, *args: Any, **solver_kws: Any
-) -> Circuit:
-    """
-    ode evolution of time dependent Hamiltonian on circuit of all qubits
-    [only jax backend support for now]
-
-    :param c: _description_
-    :type c: Circuit
-    :param h_fun: h_fun should return a **SPARSE** Hamiltonian matrix
-        with input arguments time and *args
-    :type h_fun: Callable[..., Tensor]
-    :param t: _description_
-    :type t: float
-    :return: _description_
-    :rtype: Circuit
-    """
-    from jax.experimental.ode import odeint
-
-    s = c.state()
-    n = c._nqubits
-
-    def f(y: Tensor, t: Tensor, *args: Any) -> Tensor:
-        h = -1.0j * h_fun(t, *args)
-        return backend.sparse_dense_matmul(h, y)
-
-    ts = backend.stack([0.0, t])
-    ts = backend.cast(ts, dtype=rdtypestr)
-    s1 = odeint(f, s, ts, *args, **solver_kws)
-    return type(c)(n, inputs=s1[-1])
-
-
 def jax_jitted_function_save(filename: str, f: Callable[..., Any], *args: Any) -> None:
     """
     save a jitted jax function as a file
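Note: the three helpers removed in this hunk (`hamiltonian_evol`, `evol_local`, `evol_global`) are not dropped from the package; they move to the new `tensorcircuit/timeevol.py` module (added in this release, +896 lines) and are re-imported into `experimental.py` in the import hunk above, marked "for backward compatibility". A minimal sketch of the two import paths, assuming this nightly wheel is installed:

    # sketch: both paths resolve to the same functions after this release
    from tensorcircuit.timeevol import hamiltonian_evol, evol_local, evol_global  # new home
    from tensorcircuit.experimental import hamiltonian_evol                       # still works via re-import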
@@ -626,3 +488,721 @@ def jax_jitted_function_load(filename: str) -> Callable[..., Any]:
 
 
 jax_func_load = jax_jitted_function_load
+
+
+PADDING_VALUE = -1
+jaxlib: Any
+ctg: Any
+Mesh: Any
+NamedSharding: Any
+P: Any
+
+
+def broadcast_py_object(obj: Any, shared_dir: Optional[str] = None) -> Any:
+    """
+    Broadcast a picklable Python object from process 0 to all other processes,
+    with a fallback mechanism from gRPC to a file system based approach.
+
+    This function first attempts to use gRPC-based broadcast. If that fails due to
+    pickling issues, it falls back to a file system based approach that is more robust.
+
+    :param obj: The Python object to broadcast. It must be picklable.
+        This object should exist on process 0 and can be None on others.
+    :type obj: Any
+    :param shared_dir: Directory path for shared file system broadcast fallback.
+        If None, uses current directory. Only used in fallback mode.
+    :type shared_dir: Optional[str], optional
+    :return: The broadcasted object, now present on all processes.
+    :rtype: Any
+    """
+    import jax
+    from jax.experimental import multihost_utils
+
+    try:
+        result = broadcast_py_object_jax(obj)
+        return result
+
+    except pickle.UnpicklingError as e:
+        # This block is executed if any process fails during the gRPC attempt.
+
+        multihost_utils.sync_global_devices("grpc_broadcast_failed_fallback_sync")
+
+        if jax.process_index() == 0:
+            border = "=" * 80
+            logger.warning(
+                "\n%s\nJAX gRPC broadcast failed with error: %s\n"
+                "--> Falling back to robust Shared File System broadcast method.\n%s",
+                border,
+                e,
+                border,
+            )
+
+        return broadcast_py_object_fs(obj, shared_dir)
+
+
+def broadcast_py_object_jax(obj: Any) -> Any:
+    """
+    Broadcast a picklable Python object from process 0 to all other processes
+    within the jax distributed runtime.
+
+    This function uses a two-step broadcast: first the size, then the data.
+    This is necessary because `broadcast_one_to_all` requires the same
+    shaped array on all hosts.
+
+    :param obj: The Python object to broadcast. It must be picklable.
+        This object should exist on process 0 and can be None on others.
+
+    :return: The broadcasted object, now present on all processes.
+    """
+    import jax as jaxlib
+    import pickle
+    from jax.experimental import multihost_utils
+
+    # Serialize to bytes on process 0, empty bytes on others
+    if jaxlib.process_index() == 0:
+        if obj is None:
+            raise ValueError("Object to broadcast from process 0 cannot be None.")
+        data = pickle.dumps(obj)
+        logger.info(
+            f"--- Size of object to be broadcast: {len(data) / 1024**2:.3f} MB ---"
+        )
+
+    else:
+        data = b""
+
+    # Step 1: Broadcast the length of the serialized data.
+    # We send a single-element int32 array.
+    length = np.array([len(data)], dtype=np.int32)
+    length = multihost_utils.broadcast_one_to_all(length)
+
+    length = int(length[0])  # type: ignore
+
+    # Step 2: Broadcast the actual data.
+    # Convert byte string to a uint8 array for broadcasting.
+    send_arr_uint8 = np.frombuffer(data, dtype=np.uint8)
+    padded_length = (length + 3) // 4 * 4
+    if send_arr_uint8.size < padded_length:
+        send_arr_uint8 = np.pad(  # type: ignore
+            send_arr_uint8, (0, padded_length - send_arr_uint8.size), mode="constant"
+        )
+    send_arr_int32 = send_arr_uint8.astype(np.int32)
+    # send_arr_int32 = jaxlib.numpy.array(send_arr_int32, dtype=np.int32)
+    send_arr_int32 = jaxlib.device_put(send_arr_int32)
+
+    jaxlib.experimental.multihost_utils.sync_global_devices("bulk_before")
+
+    received_arr = multihost_utils.broadcast_one_to_all(send_arr_int32)
+
+    received_arr = np.array(received_arr)
+    received_arr_uint8 = received_arr.astype(np.uint8)
+
+    # Step 3: Reconstruct the object from the received bytes.
+    # Convert the NumPy array back to bytes, truncate any padding, and unpickle.
+    received_data = received_arr_uint8[:length].tobytes()
+    # if jaxlib.process_index() == 0:
+    #     logger.info(f"Broadcasted object {obj}")
+    return pickle.loads(received_data)
+
+
+def broadcast_py_object_fs(
+    obj: Any, shared_dir: Optional[str] = None, timeout_seconds: int = 300
+) -> Any:
+    """
+    Broadcast a picklable Python object from process 0 to all other processes
+    using a shared file system approach.
+
+    This is a fallback method when gRPC-based broadcast fails. It uses UUID-based
+    file communication to share objects between processes through a shared file system.
+
+    :param obj: The Python object to broadcast. Must be picklable.
+        Should exist on process 0, can be None on others.
+    :type obj: Any
+    :param shared_dir: Directory path for shared file system communication.
+        If None, uses current directory.
+    :type shared_dir: Optional[str], optional
+    :param timeout_seconds: Maximum time to wait for file operations before timing out.
+        Defaults to 300 seconds.
+    :type timeout_seconds: int, optional
+    :return: The broadcasted object, now present on all processes.
+    :rtype: Any
+    """
+    # to avoid very subtle bugs when broadcasting tree_data on A800 clusters
+    import jax
+    from jax.experimental import multihost_utils
+
+    if shared_dir is None:
+        shared_dir = "."
+    if jax.process_index() == 0:
+        os.makedirs(shared_dir, exist_ok=True)
+
+    id_comm_path = os.path.join(shared_dir, ".broadcast_temp_12318")
+    transfer_id = ""
+
+    if jax.process_index() == 0:
+        transfer_id = str(uuid.uuid4())
+        # print(f"[Process 0] Generated unique transfer ID: {transfer_id}", flush=True)
+        with open(id_comm_path, "w") as f:
+            f.write(transfer_id)
+
+    multihost_utils.sync_global_devices("fs_broadcast_id_written")
+
+    if jax.process_index() != 0:
+        start_time = time.time()
+        while not os.path.exists(id_comm_path):
+            time.sleep(0.1)
+            if time.time() - start_time > timeout_seconds:
+                raise TimeoutError(
+                    f"Process {jax.process_index()} timed out waiting for ID file: {id_comm_path}"
+                )
+        with open(id_comm_path, "r") as f:
+            transfer_id = f.read()
+
+    multihost_utils.sync_global_devices("fs_broadcast_id_read")
+    if jax.process_index() == 0:
+        try:
+            os.remove(id_comm_path)
+        except OSError:
+            pass  # ignore the error if another process has already cleaned up the file
+
+    # paths of the data file and the done-flag file used for this transfer
+    data_path = os.path.join(shared_dir, f"{transfer_id}.data")
+    done_path = os.path.join(shared_dir, f"{transfer_id}.done")
+
+    result_obj = None
+
+    if jax.process_index() == 0:
+        if obj is None:
+            raise ValueError("None cannot be broadcasted.")
+
+        # print(f"[Process 0] Pickling object...", flush=True)
+        pickled_data = pickle.dumps(obj)
+        logger.info(
+            f"[Process 0] Writing {len(pickled_data) / 1024**2:.3f} MB to {data_path}"
+        )
+        with open(data_path, "wb") as f:
+            f.write(pickled_data)
+
+        with open(done_path, "w") as f:
+            pass
+        logger.info("[Process 0] Write complete.")
+        result_obj = obj
+    else:
+        # print(f"[Process {jax.process_index()}] Waiting for done file: {done_path}", flush=True)
+        start_time = time.time()
+        while not os.path.exists(done_path):
+            time.sleep(0.1)
+            if time.time() - start_time > timeout_seconds:
+                raise TimeoutError(
+                    f"Process {jax.process_index()} timed out waiting for done file: {done_path}"
+                )
+
+        # print(f"[Process {jax.process_index()}] Done file found. Reading data from {data_path}", flush=True)
+        with open(data_path, "rb") as f:
+            pickled_data = f.read()
+
+        result_obj = pickle.loads(pickled_data)
+        logger.info(f"[Process {jax.process_index()}] Object successfully loaded.")
+
+    multihost_utils.sync_global_devices("fs_broadcast_read_complete")
+
+    if jax.process_index() == 0:
+        try:
+            os.remove(data_path)
+            os.remove(done_path)
+            # print(f"[Process 0] Cleaned up temporary files for transfer {transfer_id}.", flush=True)
+        except OSError as e:
+            logger.info(
+                f"[Process 0]: Failed to clean up temporary files: {e}",
+            )
+
+    return result_obj
+
+
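A minimal usage sketch of the broadcast helpers above, assuming a multi-host JAX job that has already called `jax.distributed.initialize()`; the payload dict and the shared directory are illustrative, not taken from the diff:

    import jax
    from tensorcircuit.experimental import broadcast_py_object

    # only process 0 needs to hold the real object; the others may pass None
    payload = {"seed": 42} if jax.process_index() == 0 else None
    payload = broadcast_py_object(payload, shared_dir="/mnt/shared/tmp")
    # every process now holds an identical copy; shared_dir is only touched
    # if the gRPC path fails and the file-system fallback kicks in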
+class DistributedContractor:
+    """
+    A distributed tensor network contractor that parallelizes computations across multiple devices.
+
+    This class uses cotengra to find optimal contraction paths and distributes the computational
+    load across multiple devices (e.g., GPUs) for efficient tensor network calculations.
+    Particularly useful for large-scale quantum circuit simulations and variational quantum algorithms.
+
+    Example:
+        >>> def nodes_fn(params):
+        ...     c = tc.Circuit(4)
+        ...     c.rx(0, theta=params[0])
+        ...     return c.expectation_before([tc.gates.z(), [0]], reuse=False)
+        >>> dc = DistributedContractor(nodes_fn, params)
+        >>> value, grad = dc.value_and_grad(params)
+
+    :param nodes_fn: Function that takes parameters and returns a list of tensor network nodes
+    :type nodes_fn: Callable[[Tensor], List[Gate]]
+    :param params: Initial parameters used to determine the tensor network structure
+    :type params: Tensor
+    :param cotengra_options: Configuration options passed to the cotengra optimizer. Defaults to None
+    :type cotengra_options: Optional[Dict[str, Any]], optional
+    :param devices: List of devices to use. If None, uses all available devices
+    :type devices: Optional[List[Any]], optional
+    :param mesh: Mesh object to use for distributed computation. If None, uses all available devices
+    :type mesh: Optional[Any], optional
+    """
+
+    def __init__(
+        self,
+        nodes_fn: Callable[[Tensor], List[Gate]],
+        params: Tensor,
+        cotengra_options: Optional[Dict[str, Any]] = None,
+        devices: Optional[List[Any]] = None,  # backward compatibility
+        mesh: Optional[Any] = None,
+        tree_data: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        global jaxlib
+        global ctg
+        global Mesh
+        global NamedSharding
+        global P
+
+        logger.info("Initializing DistributedContractor...")
+        import cotengra as ctg
+        from cotengra import ContractionTree
+        import jax as jaxlib
+        from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+
+        self.nodes_fn = nodes_fn
+        if mesh is not None:
+            self.mesh = mesh
+        elif devices is not None:
+            self.mesh = Mesh(devices, axis_names=("devices",))
+        else:
+            self.mesh = Mesh(jaxlib.devices(), axis_names=("devices",))
+        self.num_devices = len(self.mesh.devices)
+
+        if self.num_devices <= 1:
+            logger.info("DistributedContractor is running on a single device.")
+
+        self._params_template = params
+        self.params_sharding = jaxlib.tree_util.tree_map(
+            lambda x: NamedSharding(self.mesh, P(*((None,) * x.ndim))),
+            self._params_template,
+        )
+        self._backend = "jax"
+        self._compiled_v_fns: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ] = {}
+        self._compiled_vg_fns: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ] = {}
+
+        logger.info("Running cotengra pathfinder... (This may take a while)")
+        if tree_data is None:
+            if params is None:
+                raise ValueError("Please provide specific circuit parameters array.")
+            if jaxlib.process_index() == 0:
+                logger.info("Process 0: Running cotengra pathfinder...")
+                tree_data = self._get_tree_data(
+                    self.nodes_fn, self._params_template, cotengra_options  # type: ignore
+                )
+
+            # Step 2: Use the robust helper function to broadcast the tree object.
+            # Process 0 sends its computed `tree_object`.
+            # Other processes send `None`, but receive the object from process 0.
+
+            if jaxlib.process_count() > 1:
+                # self.tree = broadcast_py_object(tree_object)
+                jaxlib.experimental.multihost_utils.sync_global_devices("tree_before")
+                logger.info(
+                    f"Process {jaxlib.process_index()}: Synchronizing contraction path..."
+                )
+                tree_data = broadcast_py_object(tree_data)
+                jaxlib.experimental.multihost_utils.sync_global_devices("tree_after")
+        else:
+            logger.info("Using pre-computed contraction path.")
+        if tree_data is None:
+            raise ValueError("Contraction path data is missing.")
+
+        self.tree = ContractionTree.from_path(
+            inputs=tree_data["inputs"],
+            output=tree_data["output"],
+            size_dict=tree_data["size_dict"],
+            path=tree_data["path"],
+        )
+
+        # Restore slicing information
+        for ind, _ in tree_data["sliced_inds"].items():
+            self.tree.remove_ind_(ind)
+
+        logger.info(
+            f"Process {jaxlib.process_index()}: Contraction path successfully synchronized."
+        )
+        actual_num_slices = self.tree.nslices
+
+        self._report_tree_info()
+
+        slices_per_device = int(np.ceil(actual_num_slices / self.num_devices))
+        padded_size = slices_per_device * self.num_devices
+        slice_indices = np.arange(actual_num_slices)
+        padded_slice_indices = np.full(padded_size, PADDING_VALUE, dtype=np.int32)
+        padded_slice_indices[:actual_num_slices] = slice_indices
+
+        # Reshape for distribution and define the sharding rule
+        batched_indices = padded_slice_indices.reshape(
+            self.num_devices, slices_per_device
+        )
+        # Sharding rule: split the first axis (the one for devices) across the 'devices' mesh axis
+        self.sharding = NamedSharding(self.mesh, P("devices", None))
+        # Place the tensor on devices according to the rule
+        self.batched_slice_indices = jaxlib.device_put(batched_indices, self.sharding)
+
+        # self.batched_slice_indices = backend.convert_to_tensor(
+        #     padded_slice_indices.reshape(self.num_devices, slices_per_device)
+        # )
+        print(
+            f"Distributing across {self.num_devices} devices. Each device will sequentially process "
+            f"up to {slices_per_device} slices."
+        )
+
+        self._compiled_vg_fn = None
+        self._compiled_v_fn = None
+
+        logger.info("Initialization complete.")
+
+    def _report_tree_info(self) -> None:
+        print("\n--- Contraction Path Info ---")
+        actual_num_slices = self.tree.nslices
+        stats = self.tree.contract_stats()
+        print(f"Path found with {actual_num_slices} slices.")
+        print(
+            f"Arithmetic Intensity (higher is better): {self.tree.arithmetic_intensity():.2f}"
+        )
+        print("flops (TFlops):", stats["flops"] / 2**40 / self.num_devices)
+        print("write (GB):", stats["write"] / 2**27 / actual_num_slices)
+        print("size (GB):", stats["size"] / 2**27)
+        print("-----------------------------\n")
+
+    @staticmethod
+    def _get_tree_data(
+        nodes_fn: Callable[[Tensor], List[Gate]],
+        params: Tensor,
+        cotengra_options: Optional[Dict[str, Any]] = None,
+    ) -> Dict[str, Any]:
+        global ctg
+
+        import cotengra as ctg
+
+        local_cotengra_options = (cotengra_options or {}).copy()
+
+        nodes = nodes_fn(params)
+        tn_info, _ = get_tn_info(nodes)
+        default_cotengra_options = {
+            "slicing_reconf_opts": {"target_size": 2**28},
+            "max_repeats": 128,
+            "minimize": "write",
+            "parallel": "auto",
+            "progbar": True,
+        }
+        default_cotengra_options.update(local_cotengra_options)
+
+        opt = ctg.ReusableHyperOptimizer(**default_cotengra_options)
+        tree_object = opt.search(*tn_info)
+        tree_data = {
+            "inputs": tree_object.inputs,
+            "output": tree_object.output,
+            "size_dict": tree_object.size_dict,
+            "path": tree_object.get_path(),
+            "sliced_inds": tree_object.sliced_inds,
+        }
+        return tree_data
+
+    @staticmethod
+    def find_path(
+        nodes_fn: Callable[[Tensor], Tensor],
+        params: Tensor,
+        cotengra_options: Optional[Dict[str, Any]] = None,
+        filepath: Optional[str] = None,
+    ) -> None:
+        tree_data = DistributedContractor._get_tree_data(
+            nodes_fn, params, cotengra_options
+        )
+        if filepath is not None:
+            with open(filepath, "wb") as f:
+                pickle.dump(tree_data, f)
+            logger.info(f"Contraction path data successfully saved to '{filepath}'.")
+
+    @classmethod
+    def from_path(
+        cls,
+        filepath: str,
+        nodes_fn: Callable[[Tensor], List[Gate]],
+        devices: Optional[List[Any]] = None,  # backward compatibility
+        mesh: Optional[Any] = None,
+        params: Any = None,
+    ) -> "DistributedContractor":
+        with open(filepath, "rb") as f:
+            tree_data = pickle.load(f)
+
+        # Each process loads the file independently. No broadcast is needed.
+        # We pass the loaded `tree_data` directly to __init__ to trigger the second workflow.
+        return cls(
+            nodes_fn=nodes_fn,
+            params=params,
+            mesh=mesh,
+            devices=devices,
+            tree_data=tree_data,
+        )
+
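`find_path` and `from_path` together enable an offline pathfinding workflow: run the (potentially slow) cotengra search once, pickle the resulting path data, and rebuild contractors from that file later without repeating the search. A hedged sketch, where `nodes_fn`, `params`, and the file name are placeholders:

    # one-off, e.g. on a single node: search for a contraction path and save it
    DistributedContractor.find_path(
        nodes_fn, params, cotengra_options={"max_repeats": 32}, filepath="path.pkl"
    )
    # later, on every process of the multi-host job: reload the path, skipping the search
    dc = DistributedContractor.from_path("path.pkl", nodes_fn=nodes_fn, params=params)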
+    def _get_single_slice_contraction_fn(
+        self, op: Optional[Callable[[Tensor], Tensor]] = None
+    ) -> Callable[[Any, Tensor, int], Tensor]:
+        if op is None:
+            op = backend.sum
+
+        def single_slice_contraction(
+            tree: ctg.ContractionTree, params: Tensor, slice_idx: int
+        ) -> Tensor:
+            nodes = self.nodes_fn(params)
+            _, standardized_nodes = get_tn_info(nodes)
+            input_arrays = [node.tensor for node in standardized_nodes]
+            sliced_arrays = tree.slice_arrays(input_arrays, slice_idx)
+            result = tree.contract_core(sliced_arrays, backend=self._backend)
+            return op(result)
+
+        return single_slice_contraction
+
+    def _get_device_sum_vg_fn(
+        self,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Callable[[Any, Tensor, Tensor], Tuple[Tensor, Tensor]]:
+        post_processing = lambda x: backend.real(backend.sum(x))
+        if op is None:
+            op = post_processing
+        base_fn = self._get_single_slice_contraction_fn(op=op)
+        # to ensure the output is real so that it can be differentiated
+        single_slice_vg_fn = jaxlib.value_and_grad(base_fn, argnums=1)
+
+        if output_dtype is None:
+            output_dtype = rdtypestr
+
+        def device_sum_fn(
+            tree: ctg.ContractionTree, params: Tensor, slice_indices_for_device: Tensor
+        ) -> Tuple[Tensor, Tensor]:
+            def scan_body(
+                carry: Tuple[Tensor, Tensor], slice_idx: Tensor
+            ) -> Tuple[Tuple[Tensor, Tensor], None]:
+                acc_value, acc_grads = carry
+
+                def compute_and_add() -> Tuple[Tensor, Tensor]:
+                    value_slice, grads_slice = single_slice_vg_fn(
+                        tree, params, slice_idx
+                    )
+                    new_value = acc_value + value_slice
+                    new_grads = jaxlib.tree_util.tree_map(
+                        jaxlib.numpy.add, acc_grads, grads_slice
+                    )
+                    return new_value, new_grads
+
+                def do_nothing() -> Tuple[Tensor, Tensor]:
+                    return acc_value, acc_grads
+
+                return (
+                    jaxlib.lax.cond(
+                        slice_idx == PADDING_VALUE, do_nothing, compute_and_add
+                    ),
+                    None,
+                )
+
+            initial_carry = (
+                backend.cast(backend.convert_to_tensor(0.0), dtype=output_dtype),
+                jaxlib.tree_util.tree_map(lambda x: jaxlib.numpy.zeros_like(x), params),
+            )
+            (final_value, final_grads), _ = jaxlib.lax.scan(
+                scan_body, initial_carry, slice_indices_for_device
+            )
+            return final_value, final_grads
+
+        return device_sum_fn
+
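The per-device accumulation above uses a standard JAX masking pattern: slice indices are padded with the sentinel `PADDING_VALUE = -1` so that every device scans over an index array of identical shape, and `lax.cond` skips the padded entries. A self-contained toy version of the same pattern (names are illustrative and unrelated to tensor networks):

    import jax
    import jax.numpy as jnp

    PAD = -1  # sentinel marking padded (inactive) slots

    def masked_sum(values, indices):
        # accumulate values[idx] over the given indices, ignoring PAD entries
        def body(carry, idx):
            take = lambda: carry + values[idx]
            skip = lambda: carry
            return jax.lax.cond(idx == PAD, skip, take), None

        total, _ = jax.lax.scan(body, jnp.asarray(0.0), indices)
        return total

    vals = jnp.arange(5.0)                            # [0., 1., 2., 3., 4.]
    idx = jnp.array([0, 2, 4, PAD], dtype=jnp.int32)  # padded to a fixed length
    print(masked_sum(vals, idx))                      # 6.0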
+    def _get_device_sum_v_fn(
+        self,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Callable[[Any, Tensor, Tensor], Tensor]:
+        base_fn = self._get_single_slice_contraction_fn(op=op)
+        if output_dtype is None:
+            output_dtype = dtypestr
+
+        def device_sum_fn(
+            tree: ctg.ContractionTree, params: Tensor, slice_indices_for_device: Tensor
+        ) -> Tensor:
+            def scan_body(
+                carry_value: Tensor, slice_idx: Tensor
+            ) -> Tuple[Tensor, None]:
+                def compute_and_add() -> Tensor:
+                    return carry_value + base_fn(tree, params, slice_idx)
+
+                return (
+                    jaxlib.lax.cond(
+                        slice_idx == PADDING_VALUE, lambda: carry_value, compute_and_add
+                    ),
+                    None,
+                )
+
+            initial_carry = backend.cast(
+                backend.convert_to_tensor(0.0), dtype=output_dtype
+            )
+            final_value, _ = jaxlib.lax.scan(
+                scan_body, initial_carry, slice_indices_for_device
+            )
+            return final_value
+
+        return device_sum_fn
+
+    def _get_or_compile_fn(
+        self,
+        cache: Dict[
+            Tuple[Callable[[Tensor], Tensor], str],
+            Callable[[Any, Tensor, Tensor], Tensor],
+        ],
+        fn_getter: Callable[..., Any],
+        op: Optional[Callable[[Tensor], Tensor]],
+        output_dtype: Optional[str],
+        is_grad_fn: bool,
+    ) -> Callable[[Any, Tensor, Tensor], Tensor]:
+        """
+        Gets a compiled, device-parallel function from the cache, or compiles and caches it.
+
+        The cache key is a tuple of (op, output_dtype); be careful with lambda functions!
+
+        Returns:
+            The compiled, distributed JAX function.
+        """
+        cache_key = (op, output_dtype)
+        if cache_key not in cache:
+            device_fn = fn_getter(op=op, output_dtype=output_dtype)
+
+            def global_aggregated_fn(
+                tree: Any, params: Any, batched_slice_indices: Tensor
+            ) -> Any:
+                # Use jax.vmap to apply the per-device function across the sharded data.
+                # vmap maps `device_fn` over the first axis (0) of `batched_slice_indices`.
+                # `tree` and `params` are broadcasted (in_axes=None) to each call.
+                vmapped_device_fn = jaxlib.vmap(
+                    device_fn, in_axes=(None, None, 0), out_axes=0
+                )
+                device_results = vmapped_device_fn(tree, params, batched_slice_indices)
+
+                # Now, `device_results` is a sharded PyTree (one result per device).
+                # We aggregate them using jnp.sum, which JAX automatically compiles
+                # into a cross-device AllReduce operation.
+
+                if is_grad_fn:
+                    # `device_results` is a (value, grad) tuple of sharded arrays
+                    device_values, device_grads = device_results
+
+                    # Replace psum with jnp.sum
+                    global_value = jaxlib.numpy.sum(device_values, axis=0)
+                    global_grad = jaxlib.tree_util.tree_map(
+                        lambda g: jaxlib.numpy.sum(g, axis=0), device_grads
+                    )
+                    return global_value, global_grad
+                else:
+                    # `device_results` is just the sharded values
+                    return jaxlib.numpy.sum(device_results, axis=0)
+
+            # Compile the global function with jax.jit and specify shardings.
+            # `params` are replicated (available everywhere).
+
+            in_shardings = (self.params_sharding, self.sharding)
+
+            if is_grad_fn:
+                # Returns (value, grad), so out_sharding must be a 2-tuple.
+                # `value` is a replicated scalar -> P()
+                sharding_for_value = NamedSharding(self.mesh, P())
+                # `grad` is a replicated PyTree with the same structure as params.
+                sharding_for_grad = self.params_sharding
+                out_shardings = (sharding_for_value, sharding_for_grad)
+            else:
+                # Returns a single scalar value -> P()
+                out_shardings = NamedSharding(self.mesh, P())
+
+            compiled_fn = jaxlib.jit(
+                global_aggregated_fn,
+                # `tree` is a static argument, its value is compiled into the function.
+                static_argnums=(0,),
+                # Specify how inputs are sharded.
+                in_shardings=in_shardings,
+                # Specify how the output should be sharded.
+                out_shardings=out_shardings,
+            )
+            cache[cache_key] = compiled_fn  # type: ignore
+        return cache[cache_key]  # type: ignore
+
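The helper above expresses multi-device parallelism as `jax.vmap` over the sharded leading axis of the slice-index array, followed by an ordinary `jnp.sum` inside `jax.jit`; with the inputs sharded across the mesh, XLA lowers that sum into a cross-device reduction. A stripped-down sketch of the same structure on a 1D device mesh (all names are illustrative; here the sharding is attached to the input with `device_put` and propagated by `jit` rather than passed via `in_shardings`):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
    sharding = NamedSharding(mesh, P("devices", None))  # shard rows across devices

    def per_device(params, indices):
        # each device reduces its own chunk of indices
        return params * jnp.sum(indices)

    @jax.jit
    def global_sum(params, batched_indices):
        per_dev = jax.vmap(per_device, in_axes=(None, 0))(params, batched_indices)
        return jnp.sum(per_dev)  # lowered to a cross-device reduction when sharded

    n = len(jax.devices())
    batched = jax.device_put(
        jnp.arange(4 * n, dtype=jnp.float32).reshape(n, 4), sharding
    )
    print(global_sum(jnp.float32(1.0), batched))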
+    def value_and_grad(
+        self,
+        params: Tensor,
+        # aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Calculates the value and gradient, compiling the distributed function on the first call if needed.
+
+        :param params: Parameters for the `nodes_fn` input
+        :type params: Tensor
+        :param op: Optional post-processing function for the output, defaults to None (corresponding to `backend.real`);
+            op is part of the cache key, so don't pass a lambda function directly for op
+        :type op: Optional[Callable[[Tensor], Tensor]], optional
+        :param output_dtype: dtype str for the output of `nodes_fn`, defaults to None (corresponding to `rdtypestr`)
+        :type output_dtype: Optional[str], optional
+        """
+        compiled_vg_fn = self._get_or_compile_fn(
+            cache=self._compiled_vg_fns,
+            fn_getter=self._get_device_sum_vg_fn,
+            op=op,
+            output_dtype=output_dtype,
+            is_grad_fn=True,
+        )
+
+        total_value, total_grad = compiled_vg_fn(
+            self.tree, params, self.batched_slice_indices
+        )
+        return total_value, total_grad
+
+    def value(
+        self,
+        params: Tensor,
+        # aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tensor:
+        """
+        Calculates the value, compiling the distributed function on the first call if needed.
+
+        :param params: Parameters for the `nodes_fn` input
+        :type params: Tensor
+        :param op: Optional post-processing function for the output, defaults to None (corresponding to identity);
+            op is part of the cache key, so don't pass a lambda function directly for op
+        :type op: Optional[Callable[[Tensor], Tensor]], optional
+        :param output_dtype: dtype str for the output of `nodes_fn`, defaults to None (corresponding to `dtypestr`)
+        :type output_dtype: Optional[str], optional
+        """
+        compiled_v_fn = self._get_or_compile_fn(
+            cache=self._compiled_v_fns,
+            fn_getter=self._get_device_sum_v_fn,
+            op=op,
+            output_dtype=output_dtype,
+            is_grad_fn=False,
+        )
+
+        total_value = compiled_v_fn(self.tree, params, self.batched_slice_indices)
+        return total_value
+
+    def grad(
+        self,
+        params: Tensor,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tensor:
+        _, grad = self.value_and_grad(params, op=op, output_dtype=output_dtype)
+        return grad
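Taken together, a hedged end-to-end sketch of the new class in the spirit of its docstring example (the circuit, observable, and parameter shape are illustrative; the jax backend is required):

    import jax.numpy as jnp
    import tensorcircuit as tc
    from tensorcircuit.experimental import DistributedContractor

    tc.set_backend("jax")

    def nodes_fn(params):
        c = tc.Circuit(6)
        for i in range(6):
            c.rx(i, theta=params[i])
        for i in range(5):
            c.cnot(i, i + 1)
        # return the uncontracted node list for <Z_0>
        return c.expectation_before([tc.gates.z(), [0]], reuse=False)

    params = jnp.ones(6, dtype=jnp.float32)
    dc = DistributedContractor(
        nodes_fn, params, cotengra_options={"max_repeats": 32, "progbar": False}
    )
    ev, grads = dc.value_and_grad(params)  # real scalar and gradient w.r.t. params
    val = dc.value(params)                 # complex contraction value, no gradient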