tensorcircuit-nightly 1.4.0.dev20251010__tar.gz → 1.4.0.dev20251107__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tensorcircuit-nightly might be problematic.
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/CHANGELOG.md +8 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/PKG-INFO +1 -1
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/__init__.py +1 -1
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/experimental.py +184 -57
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/interfaces/tensortrans.py +6 -2
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/interfaces/torch.py +14 -4
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit_nightly.egg-info/PKG-INFO +1 -1
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/HISTORY.md +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/LICENSE +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/MANIFEST.in +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/README.md +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/README_cn.md +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/pyproject.toml +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/setup.cfg +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/setup.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/about.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/abstractcircuit.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/analogcircuit.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/applications/__init__.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/applications/ai/__init__.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/applications/ai/ensemble.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/applications/dqas.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/applications/finance/__init__.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/applications/finance/portfolio.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/applications/graphdata.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/applications/layers.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/applications/optimization.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/applications/physics/__init__.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/applications/physics/baseline.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/applications/physics/fss.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/applications/utils.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/applications/vags.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/applications/van.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/applications/vqes.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/asciiart.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/backends/__init__.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/backends/abstract_backend.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/backends/backend_factory.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/backends/cupy_backend.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/backends/jax_backend.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/backends/jax_ops.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/backends/numpy_backend.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/backends/pytorch_backend.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/backends/pytorch_ops.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/backends/tensorflow_backend.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/backends/tf_ops.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/basecircuit.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/channels.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/circuit.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/cloud/__init__.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/cloud/abstraction.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/cloud/apis.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/cloud/config.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/cloud/local.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/cloud/quafu_provider.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/cloud/tencent.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/cloud/utils.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/cloud/wrapper.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/compiler/__init__.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/compiler/composed_compiler.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/compiler/qiskit_compiler.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/compiler/simple_compiler.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/cons.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/densitymatrix.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/fgs.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/gates.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/interfaces/__init__.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/interfaces/jax.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/interfaces/numpy.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/interfaces/scipy.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/interfaces/tensorflow.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/keras.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/mps_base.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/mpscircuit.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/noisemodel.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/quantum.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/quditcircuit.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/quditgates.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/results/__init__.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/results/counts.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/results/qem/__init__.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/results/qem/benchmark_circuits.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/results/qem/qem_methods.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/results/readout_mitigation.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/shadows.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/simplify.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/stabilizercircuit.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/templates/__init__.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/templates/ansatz.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/templates/blocks.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/templates/chems.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/templates/conversions.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/templates/dataset.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/templates/graphs.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/templates/hamiltonians.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/templates/lattice.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/templates/measurements.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/timeevol.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/torchnn.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/translation.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/utils.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/vis.py +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit_nightly.egg-info/SOURCES.txt +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit_nightly.egg-info/dependency_links.txt +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit_nightly.egg-info/requires.txt +0 -0
- {tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit_nightly.egg-info/top_level.txt +0 -0
{tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/CHANGELOG.md
RENAMED
@@ -12,6 +12,14 @@
 
 - Add `su4` as a generic parameterized two-qubit gate.
 
+- Add multi-controller jax support for distributed contraction.
+
+### Fixed
+
+- Fix the breaking logic change in the jax dlpack API (dlcapsule -> tensor).
+
+- Better torch interface for dlpack translation.
+
 ## v1.4.0
 
 ### Added
{tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tensorcircuit-nightly
-Version: 1.4.0.dev20251010
+Version: 1.4.0.dev20251107
 Summary: High performance unified quantum computing framework for the NISQ era
 Author-email: TensorCircuit Authors <znfesnpbh@gmail.com>
 License-Expression: Apache-2.0
{tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/experimental.py
RENAMED
@@ -489,6 +489,62 @@ jax_func_load = jax_jitted_function_load
 PADDING_VALUE = -1
 jaxlib: Any
 ctg: Any
+Mesh: Any
+NamedSharding: Any
+P: Any
+
+
+def broadcast_py_object(obj: Any) -> Any:
+    """
+    Broadcast a picklable Python object from process 0 to all other processes
+    within the jax distribution system.
+
+    This function uses a two-step broadcast: first the size, then the data.
+    This is necessary because `broadcast_one_to_all` requires the same
+    shaped array on all hosts.
+
+    :param obj: The Python object to broadcast. It must be picklable.
+        This object should exist on process 0 and can be None on others.
+
+    :return: The broadcasted object, now present on all processes.
+    """
+    import jax as jaxlib
+    import pickle
+    from jax.experimental import multihost_utils
+
+    # Serialize to bytes on process 0, empty bytes on others
+    if jaxlib.process_index() == 0:
+        if obj is None:
+            raise ValueError("Object to broadcast from process 0 cannot be None.")
+        data = pickle.dumps(obj)
+    else:
+        data = b""
+
+    # Step 1: Broadcast the length of the serialized data.
+    # We send a single-element int32 array.
+    length = np.array([len(data)], dtype=np.int32)
+    length = multihost_utils.broadcast_one_to_all(length)
+    length = int(length[0])  # type: ignore
+
+    # Step 2: Broadcast the actual data.
+    # Convert byte string to a uint8 array for broadcasting.
+    send_arr = np.frombuffer(data, dtype=np.uint8)
+
+    # Pad the array on the source process if necessary, although it's unlikely
+    # to be smaller than `length`. More importantly, other processes create an
+    # empty buffer which must be padded to the correct receiving size.
+    if send_arr.size < length:
+        send_arr = np.pad(send_arr, (0, length - send_arr.size), mode="constant")  # type: ignore
+
+    # Broadcast the uint8 array. Process 0 sends, others receive into `send_arr`.
+    received_arr = multihost_utils.broadcast_one_to_all(send_arr)
+
+    # Step 3: Reconstruct the object from the received bytes.
+    # Convert the NumPy array back to bytes, truncate any padding, and unpickle.
+    received_data = received_arr[:length].tobytes()
+    if jaxlib.process_index() == 0:
+        logger.info(f"Broadcasted object {obj}")
+    return pickle.loads(received_data)
 
 
 class DistributedContractor:
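For orientation, a minimal sketch of how this helper would be exercised in a multi-process run (not part of the diff; it assumes `broadcast_py_object` is importable from `tensorcircuit.experimental` and that the launch environment lets `jax.distributed.initialize()` auto-detect the cluster, e.g. under SLURM):

```python
# Hypothetical multi-host driver: launch one copy of this script per process.
import jax
from tensorcircuit.experimental import broadcast_py_object  # added in this release

# Assumes cluster auto-detection; otherwise pass coordinator_address /
# num_processes / process_id explicitly.
jax.distributed.initialize()

# Only process 0 owns the object; every other process starts with None.
obj = {"tree": [0, 1, 2], "cost": 1.5} if jax.process_index() == 0 else None
obj = broadcast_py_object(obj)

# After the two-step broadcast every process holds an identical copy.
print(jax.process_index(), obj)
```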
@@ -513,8 +569,10 @@ class DistributedContractor:
     :type params: Tensor
     :param cotengra_options: Configuration options passed to the cotengra optimizer. Defaults to None
     :type cotengra_options: Optional[Dict[str, Any]], optional
-    :param devices: List of devices to use. If None, uses all available
+    :param devices: List of devices to use. If None, uses all available devices
     :type devices: Optional[List[Any]], optional
+    :param mesh: Mesh object to use for distributed computation. If None, uses all available devices
+    :type mesh: Optional[Any], optional
     """
 
     def __init__(
@@ -522,23 +580,28 @@ class DistributedContractor:
         nodes_fn: Callable[[Tensor], List[Gate]],
         params: Tensor,
         cotengra_options: Optional[Dict[str, Any]] = None,
-        devices: Optional[List[Any]] = None,
+        devices: Optional[List[Any]] = None,  # backward compatibility
+        mesh: Optional[Any] = None,
     ) -> None:
         global jaxlib
         global ctg
+        global Mesh
+        global NamedSharding
+        global P
 
         logger.info("Initializing DistributedContractor...")
         import cotengra as ctg
         import jax as jaxlib
+        from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
 
         self.nodes_fn = nodes_fn
-        if
-        self.
-
-
+        if mesh is not None:
+            self.mesh = mesh
+        elif devices is not None:
+            self.mesh = Mesh(devices, axis_names=("devices",))
         else:
-            self.
-
+            self.mesh = Mesh(jaxlib.devices(), axis_names=("devices",))
+        self.num_devices = len(self.mesh.devices)
 
         if self.num_devices <= 1:
             logger.info("DistributedContractor is running on a single device.")
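The constructor now accepts either a plain `devices` list (kept for backward compatibility) or a prebuilt `mesh`; both paths end up as a 1-D `jax.sharding.Mesh` with a single "devices" axis. A small sketch of the equivalent construction paths (the `DistributedContractor(...)` calls in the comments are illustrative and assume a suitable `nodes_fn`/`params` pair):

```python
# Build the same 1-D device mesh that __init__ creates when given only `devices`.
import jax
from jax.sharding import Mesh

all_devices = jax.devices()
explicit_mesh = Mesh(all_devices, axis_names=("devices",))
print(explicit_mesh.devices.shape, explicit_mesh.axis_names)

# Equivalent construction paths (illustrative):
# DistributedContractor(nodes_fn, params)                      -> mesh over jax.devices()
# DistributedContractor(nodes_fn, params, devices=all_devices) -> mesh over the given list
# DistributedContractor(nodes_fn, params, mesh=explicit_mesh)  -> use the mesh as passed in
```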
@@ -555,20 +618,39 @@ class DistributedContractor:
         ] = {}
 
         logger.info("Running cotengra pathfinder... (This may take a while)")
-
-
-
-
-
-            "progbar"
-
-
-
-
-
-
-
-
+        tree_object = None
+        if jaxlib.process_index() == 0:
+            logger.info("Process 0: Running cotengra pathfinder...")
+
+            local_cotengra_options = (cotengra_options or {}).copy()
+            local_cotengra_options["progbar"] = True
+
+            nodes = self.nodes_fn(self._params_template)
+            tn_info, _ = get_tn_info(nodes)
+            default_cotengra_options = {
+                "slicing_reconf_opts": {"target_size": 2**28},
+                "max_repeats": 128,
+                "minimize": "write",
+                "parallel": "auto",
+            }
+            default_cotengra_options.update(local_cotengra_options)
+
+            opt = ctg.ReusableHyperOptimizer(**default_cotengra_options)
+            tree_object = opt.search(*tn_info)
+
+        # Step 2: Use the robust helper function to broadcast the tree object.
+        # Process 0 sends its computed `tree_object`.
+        # Other processes send `None`, but receive the object from process 0.
+        logger.info(
+            f"Process {jaxlib.process_index()}: Synchronizing contraction path..."
+        )
+        if jaxlib.process_count() > 1:
+            self.tree = broadcast_py_object(tree_object)
+        else:
+            self.tree = tree_object
+        logger.info(
+            f"Process {jaxlib.process_index()}: Contraction path successfully synchronized."
+        )
         actual_num_slices = self.tree.nslices
 
         print("\n--- Contraction Path Info ---")
@@ -587,9 +669,19 @@ class DistributedContractor:
         slice_indices = np.arange(actual_num_slices)
         padded_slice_indices = np.full(padded_size, PADDING_VALUE, dtype=np.int32)
         padded_slice_indices[:actual_num_slices] = slice_indices
-
-
+
+        # Reshape for distribution and define the sharding rule
+        batched_indices = padded_slice_indices.reshape(
+            self.num_devices, slices_per_device
         )
+        # Sharding rule: split the first axis (the one for devices) across the 'devices' mesh axis
+        self.sharding = NamedSharding(self.mesh, P("devices", None))
+        # Place the tensor on devices according to the rule
+        self.batched_slice_indices = jaxlib.device_put(batched_indices, self.sharding)
+
+        # self.batched_slice_indices = backend.convert_to_tensor(
+        #     padded_slice_indices.reshape(self.num_devices, slices_per_device)
+        # )
         print(
             f"Distributing across {self.num_devices} devices. Each device will sequentially process "
             f"up to {slices_per_device} slices."
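A standalone sketch of the placement step above: pad the slice indices to a rectangular `(num_devices, slices_per_device)` grid, then let `NamedSharding(mesh, P("devices", None))` give each device one row. The slice count and ceiling division here are illustrative, not taken from the diff:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

PADDING_VALUE = -1
devices = jax.devices()
num_devices = len(devices)

actual_num_slices = 10                                     # illustrative
slices_per_device = -(-actual_num_slices // num_devices)   # ceiling division

# Pad so every device gets the same number of (possibly dummy) slices.
padded = np.full(num_devices * slices_per_device, PADDING_VALUE, dtype=np.int32)
padded[:actual_num_slices] = np.arange(actual_num_slices)

mesh = Mesh(devices, axis_names=("devices",))
sharding = NamedSharding(mesh, P("devices", None))         # rows split across devices
batched = jax.device_put(padded.reshape(num_devices, slices_per_device), sharding)
print(batched.shape, batched.sharding)
```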
@@ -716,6 +808,7 @@ class DistributedContractor:
         fn_getter: Callable[..., Any],
         op: Optional[Callable[[Tensor], Tensor]],
         output_dtype: Optional[str],
+        is_grad_fn: bool,
     ) -> Callable[[Any, Tensor, Tensor], Tensor]:
         """
         Gets a compiled pmap-ed function from cache or compiles and caches it.
@@ -728,15 +821,64 @@ class DistributedContractor:
         cache_key = (op, output_dtype)
         if cache_key not in cache:
             device_fn = fn_getter(op=op, output_dtype=output_dtype)
-
-
-
-
-
-
-
-
-
+
+            def global_aggregated_fn(
+                tree: Any, params: Any, batched_slice_indices: Tensor
+            ) -> Any:
+                # Use jax.vmap to apply the per-device function across the sharded data.
+                # vmap maps `device_fn` over the first axis (0) of `batched_slice_indices`.
+                # `tree` and `params` are broadcasted (in_axes=None) to each call.
+                vmapped_device_fn = jaxlib.vmap(
+                    device_fn, in_axes=(None, None, 0), out_axes=0
+                )
+                device_results = vmapped_device_fn(tree, params, batched_slice_indices)
+
+                # Now, `device_results` is a sharded PyTree (one result per device).
+                # We aggregate them using jnp.sum, which JAX automatically compiles
+                # into a cross-device AllReduce operation.
+
+                if is_grad_fn:
+                    # `device_results` is a (value, grad) tuple of sharded arrays
+                    device_values, device_grads = device_results
+
+                    # Replace psum with jnp.sum
+                    global_value = jaxlib.numpy.sum(device_values, axis=0)
+                    global_grad = jaxlib.tree_util.tree_map(
+                        lambda g: jaxlib.numpy.sum(g, axis=0), device_grads
+                    )
+                    return global_value, global_grad
+                else:
+                    # `device_results` is just the sharded values
+                    return jaxlib.numpy.sum(device_results, axis=0)
+
+            # Compile the global function with jax.jit and specify shardings.
+            # `params` are replicated (available everywhere).
+            params_sharding = jaxlib.tree_util.tree_map(
+                lambda x: NamedSharding(self.mesh, P(*((None,) * x.ndim))),
+                self._params_template,
+            )
+
+            in_shardings = (params_sharding, self.sharding)
+
+            if is_grad_fn:
+                # Returns (value, grad), so out_sharding must be a 2-tuple.
+                # `value` is a replicated scalar -> P()
+                sharding_for_value = NamedSharding(self.mesh, P())
+                # `grad` is a replicated PyTree with the same structure as params.
+                sharding_for_grad = params_sharding
+                out_shardings = (sharding_for_value, sharding_for_grad)
+            else:
+                # Returns a single scalar value -> P()
+                out_shardings = NamedSharding(self.mesh, P())
+
+            compiled_fn = jaxlib.jit(
+                global_aggregated_fn,
+                # `tree` is a static argument, its value is compiled into the function.
+                static_argnums=(0,),
+                # Specify how inputs are sharded.
+                in_shardings=in_shardings,
+                # Specify how the output should be sharded.
+                out_shardings=out_shardings,
             )
             cache[cache_key] = compiled_fn  # type: ignore
         return cache[cache_key]  # type: ignore
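The same compile-and-aggregate pattern, reduced to a toy stand-in for `device_fn` and without the static `tree` argument (everything below is a sketch, not the library code): vmap the per-slice work over the sharded leading axis, sum the per-device partials, and let `jax.jit` with `in_shardings`/`out_shardings` turn that sum into a cross-device reduction.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
mesh = Mesh(devices, axis_names=("devices",))
data_sharding = NamedSharding(mesh, P("devices", None))   # one row per device
replicated = NamedSharding(mesh, P())                      # same value everywhere

def device_fn(params, slice_ids):
    # Toy per-device workload: sum a gather, ignoring PADDING_VALUE (-1) slots.
    return jnp.sum(jnp.where(slice_ids >= 0, params[slice_ids], 0.0))

def global_fn(params, batched_slice_ids):
    per_device = jax.vmap(device_fn, in_axes=(None, 0), out_axes=0)(
        params, batched_slice_ids
    )
    # Summing the sharded per-device results compiles into an AllReduce.
    return jnp.sum(per_device, axis=0)

compiled = jax.jit(
    global_fn,
    in_shardings=(replicated, data_sharding),
    out_shardings=replicated,
)

params = jnp.arange(8.0)
ids = (np.arange(len(devices) * 2, dtype=np.int32) % 8).reshape(len(devices), 2)
print(compiled(params, jax.device_put(ids, data_sharding)))
```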
@@ -744,7 +886,7 @@ class DistributedContractor:
     def value_and_grad(
         self,
         params: Tensor,
-        aggregate: bool = True,
+        # aggregate: bool = True,
         op: Optional[Callable[[Tensor], Tensor]] = None,
         output_dtype: Optional[str] = None,
     ) -> Tuple[Tensor, Tensor]:
@@ -753,8 +895,6 @@ class DistributedContractor:
 
         :param params: Parameters for the `nodes_fn` input
         :type params: Tensor
-        :param aggregate: Whether to aggregate (sum) the results across devices, defaults to True
-        :type aggregate: bool, optional
         :param op: Optional post-processing function for the output, defaults to None (corresponding to `backend.real`)
             op is a cache key, so dont directly pass lambda function for op
         :type op: Optional[Callable[[Tensor], Tensor]], optional
@@ -766,24 +906,18 @@ class DistributedContractor:
             fn_getter=self._get_device_sum_vg_fn,
             op=op,
             output_dtype=output_dtype,
+            is_grad_fn=True,
         )
 
-
+        total_value, total_grad = compiled_vg_fn(
             self.tree, params, self.batched_slice_indices
         )
-
-        if aggregate:
-            total_value = backend.sum(device_values)
-            total_grad = jaxlib.tree_util.tree_map(
-                lambda x: backend.sum(x, axis=0), device_grads
-            )
-            return total_value, total_grad
-        return device_values, device_grads
+        return total_value, total_grad
 
     def value(
         self,
         params: Tensor,
-        aggregate: bool = True,
+        # aggregate: bool = True,
         op: Optional[Callable[[Tensor], Tensor]] = None,
         output_dtype: Optional[str] = None,
     ) -> Tensor:
@@ -792,8 +926,6 @@ class DistributedContractor:
 
         :param params: Parameters for the `nodes_fn` input
         :type params: Tensor
-        :param aggregate: Whether to aggregate (sum) the results across devices, defaults to True
-        :type aggregate: bool, optional
         :param op: Optional post-processing function for the output, defaults to None (corresponding to identity)
             op is a cache key, so dont directly pass lambda function for op
         :type op: Optional[Callable[[Tensor], Tensor]], optional
@@ -805,22 +937,17 @@ class DistributedContractor:
             fn_getter=self._get_device_sum_v_fn,
             op=op,
             output_dtype=output_dtype,
+            is_grad_fn=False,
         )
 
-
-
-        if aggregate:
-            return backend.sum(device_values)
-        return device_values
+        total_value = compiled_v_fn(self.tree, params, self.batched_slice_indices)
+        return total_value
 
     def grad(
         self,
         params: Tensor,
-        aggregate: bool = True,
        op: Optional[Callable[[Tensor], Tensor]] = None,
        output_dtype: Optional[str] = None,
    ) -> Tensor:
-        _, grad = self.value_and_grad(
-            params, aggregate=aggregate, op=op, output_dtype=output_dtype
-        )
+        _, grad = self.value_and_grad(params, op=op, output_dtype=output_dtype)
         return grad
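From the caller's side, the visible change is that the `aggregate` flag is gone: `value`, `grad`, and `value_and_grad` always return the globally summed result. A hedged end-to-end sketch follows; the circuit, observable, and parameter shapes are illustrative and not taken from the diff, and it assumes the jax backend with cotengra installed:

```python
import tensorcircuit as tc
from tensorcircuit.experimental import DistributedContractor

tc.set_backend("jax")
n = 6

def nodes_fn(params):
    # Build the tensor-network nodes for a simple <Z0 Z1> expectation.
    c = tc.Circuit(n)
    for i in range(n):
        c.rx(i, theta=params[i])
    for i in range(n - 1):
        c.cnot(i, i + 1)
    return c.expectation_before([tc.gates.z(), [0]], [tc.gates.z(), [1]], reuse=False)

params = tc.backend.ones([n])
dc = DistributedContractor(nodes_fn, params)

value, grad = dc.value_and_grad(params)   # always aggregated across devices/processes
print(value, grad.shape)
```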
{tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/interfaces/tensortrans.py
RENAMED
@@ -132,13 +132,17 @@ def general_args_to_backend
        target_backend = backend
    elif isinstance(target_backend, str):
        target_backend = get_backend(target_backend)
+    try:
+        t = backend.tree_map(target_backend.from_dlpack, caps)
+    except TypeError:
+        t = backend.tree_map(target_backend.from_dlpack, args)
+
    if dtype is None:
-        return
+        return t
    if isinstance(dtype, str):
        leaves, treedef = backend.tree_flatten(args)
        dtype = [dtype for _ in range(len(leaves))]
        dtype = backend.tree_unflatten(treedef, dtype)
-    t = backend.tree_map(target_backend.from_dlpack, caps)
    t = backend.tree_map(target_backend.cast, t, dtype)
    return t
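The try/except mirrors the jax dlpack change noted in the changelog: older jax expected a DLPack capsule, newer jax expects the producing array itself and raises `TypeError` on a capsule. A minimal, self-contained illustration of the same fallback idea (assuming both torch and jax are installed; this is not the library code path itself):

```python
import torch
import jax.dlpack
from torch.utils.dlpack import to_dlpack

x = torch.arange(4, dtype=torch.float32)

try:
    # Legacy path: hand jax a DLPack capsule.
    y = jax.dlpack.from_dlpack(to_dlpack(x))
except TypeError:
    # Newer jax dropped capsule support: pass the producer tensor directly.
    y = jax.dlpack.from_dlpack(x)

print(type(y), y)
```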
{tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit/interfaces/torch.py
RENAMED
@@ -69,12 +69,14 @@ def torch_interface(
         @staticmethod
         def forward(ctx: Any, *x: Any) -> Any:  # type: ignore
             # ctx.xdtype = [xi.dtype for xi in x]
-            ctx.
+            ctx.save_for_backward(*x)
+            x_detached = backend.tree_map(lambda s: s.detach(), x)
+            ctx.xdtype = backend.tree_map(lambda s: s.dtype, x_detached)
             # (x, )
             if len(ctx.xdtype) == 1:
                 ctx.xdtype = ctx.xdtype[0]
-            ctx.device = (backend.tree_flatten(
-            x = general_args_to_backend(
+            ctx.device = (backend.tree_flatten(x_detached)[0][0]).device
+            x = general_args_to_backend(x_detached, enable_dlpack=enable_dlpack)
             y = fun(*x)
             ctx.ydtype = backend.tree_map(lambda s: s.dtype, y)
             if len(x) == 1:
@@ -88,6 +90,9 @@ def torch_interface(
 
         @staticmethod
         def backward(ctx: Any, *grad_y: Any) -> Any:
+            x = ctx.saved_tensors
+            x_detached = backend.tree_map(lambda s: s.detach(), x)
+            x_backend = general_args_to_backend(x_detached, enable_dlpack=enable_dlpack)
             if len(grad_y) == 1:
                 grad_y = grad_y[0]
             grad_y = backend.tree_map(lambda s: s.contiguous(), grad_y)
@@ -96,7 +101,12 @@ def torch_interface(
             )
             # grad_y = general_args_to_numpy(grad_y)
             # grad_y = numpy_args_to_backend(grad_y, dtype=ctx.ydtype)  # backend.dtype
-
+            if len(x_backend) == 1:
+                x_backend_for_vjp = x_backend[0]
+            else:
+                x_backend_for_vjp = x_backend
+
+            _, g = vjp_fun(x_backend_for_vjp, grad_y)
             # a redundency due to current vjp API
 
             r = general_args_to_backend(
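The torch changes switch the interface to the standard `ctx.save_for_backward` / `ctx.saved_tensors` round trip: `forward` stores the raw torch tensors, and `backward` re-derives the backend-side inputs from them before calling the vjp. A self-contained sketch of that save/restore pattern (plain torch only, with a trivial function standing in for the wrapped backend computation):

```python
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Keep the original torch tensor for backward, like the new torch interface does.
        ctx.save_for_backward(x)
        return x.detach() ** 2

    @staticmethod
    def backward(ctx, grad_y):
        # Recover the saved input and rebuild whatever backward needs from it.
        (x,) = ctx.saved_tensors
        return 2.0 * x.detach() * grad_y

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
Square.apply(x).sum().backward()
print(x.grad)  # tensor([2., 4., 6.])
```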
{tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/tensorcircuit_nightly.egg-info/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tensorcircuit-nightly
-Version: 1.4.0.dev20251010
+Version: 1.4.0.dev20251107
 Summary: High performance unified quantum computing framework for the NISQ era
 Author-email: TensorCircuit Authors <znfesnpbh@gmail.com>
 License-Expression: Apache-2.0
{tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/HISTORY.md
RENAMED
File without changes
{tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/LICENSE
RENAMED
File without changes
{tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/MANIFEST.in
RENAMED
File without changes
{tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/README.md
RENAMED
File without changes
{tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/README_cn.md
RENAMED
File without changes
{tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/pyproject.toml
RENAMED
File without changes
{tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/setup.cfg
RENAMED
File without changes
{tensorcircuit_nightly-1.4.0.dev20251010 → tensorcircuit_nightly-1.4.0.dev20251107}/setup.py
RENAMED
File without changes