qadence 1.6.3__tar.gz → 1.7.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {qadence-1.6.3 → qadence-1.7.1}/PKG-INFO +7 -6
- {qadence-1.6.3 → qadence-1.7.1}/mkdocs.yml +4 -1
- {qadence-1.6.3 → qadence-1.7.1}/pyproject.toml +67 -65
- {qadence-1.6.3 → qadence-1.7.1}/qadence/__init__.py +2 -2
- qadence-1.7.1/qadence/backends/api.py +67 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/gpsr.py +1 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/pyqtorch/backend.py +1 -2
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/pyqtorch/config.py +5 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/pyqtorch/convert_ops.py +83 -10
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/utils.py +62 -7
- {qadence-1.6.3 → qadence-1.7.1}/qadence/blocks/abstract.py +7 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/blocks/embedding.py +17 -12
- {qadence-1.6.3 → qadence-1.7.1}/qadence/blocks/matrix.py +1 -1
- {qadence-1.6.3 → qadence-1.7.1}/qadence/blocks/primitive.py +1 -1
- {qadence-1.6.3 → qadence-1.7.1}/qadence/constructors/__init__.py +2 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/constructors/hamiltonians.py +38 -1
- {qadence-1.6.3 → qadence-1.7.1}/qadence/draw/utils.py +1 -1
- {qadence-1.6.3 → qadence-1.7.1}/qadence/execution.py +11 -3
- qadence-1.7.1/qadence/extensions.py +141 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/ml_tools/__init__.py +11 -3
- qadence-1.7.1/qadence/ml_tools/config.py +353 -0
- qadence-1.7.1/qadence/ml_tools/constructors.py +796 -0
- qadence-1.6.3/qadence/models/qnn.py → qadence-1.7.1/qadence/ml_tools/models.py +217 -40
- {qadence-1.6.3 → qadence-1.7.1}/qadence/ml_tools/printing.py +5 -2
- {qadence-1.6.3 → qadence-1.7.1}/qadence/ml_tools/saveload.py +42 -18
- {qadence-1.6.3 → qadence-1.7.1}/qadence/ml_tools/train_grad.py +48 -14
- {qadence-1.6.3 → qadence-1.7.1}/qadence/ml_tools/utils.py +2 -8
- qadence-1.6.3/qadence/models/quantum_model.py → qadence-1.7.1/qadence/model.py +178 -10
- {qadence-1.6.3 → qadence-1.7.1}/qadence/operations/ham_evo.py +10 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/overlap.py +1 -1
- {qadence-1.6.3 → qadence-1.7.1}/qadence/parameters.py +10 -1
- {qadence-1.6.3 → qadence-1.7.1}/qadence/register.py +98 -22
- {qadence-1.6.3 → qadence-1.7.1}/qadence/serialization.py +6 -6
- {qadence-1.6.3 → qadence-1.7.1}/qadence/types.py +44 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/utils.py +2 -8
- qadence-1.6.3/qadence/backends/api.py +0 -80
- qadence-1.6.3/qadence/extensions.py +0 -115
- qadence-1.6.3/qadence/finitediff.py +0 -47
- qadence-1.6.3/qadence/ml_tools/config.py +0 -72
- qadence-1.6.3/qadence/ml_tools/models.py +0 -320
- qadence-1.6.3/qadence/models/__init__.py +0 -7
- {qadence-1.6.3 → qadence-1.7.1}/.coveragerc +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/.github/dependabot.yml +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/.github/workflows/build_docs.yml +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/.github/workflows/dependabot.yml +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/.github/workflows/lint.yml +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/.github/workflows/test_all.yml +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/.github/workflows/test_examples.yml +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/.github/workflows/test_fast.yml +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/.gitignore +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/.pre-commit-config.yaml +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/LICENSE +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/MANIFEST.in +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/README.md +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/analog/__init__.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/analog/addressing.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/analog/constants.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/analog/device.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/analog/hamiltonian_terms.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/analog/parse_analog.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backend.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/__init__.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/braket/__init__.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/braket/backend.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/braket/config.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/braket/convert_ops.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/horqrux/__init__.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/horqrux/backend.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/horqrux/config.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/horqrux/convert_ops.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/jax_utils.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/pulser/__init__.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/pulser/backend.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/pulser/channels.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/pulser/cloud.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/pulser/config.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/pulser/convert_ops.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/pulser/devices.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/pulser/pulses.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/pulser/waveforms.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/backends/pyqtorch/__init__.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/blocks/__init__.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/blocks/analog.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/blocks/block_to_tensor.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/blocks/composite.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/blocks/manipulate.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/blocks/utils.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/circuit.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/constructors/ansatze.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/constructors/daqc/__init__.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/constructors/daqc/daqc.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/constructors/daqc/gen_parser.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/constructors/daqc/utils.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/constructors/feature_maps.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/constructors/iia.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/constructors/qft.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/constructors/rydberg_feature_maps.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/constructors/rydberg_hea.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/constructors/utils.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/decompose.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/divergences.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/draw/__init__.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/draw/assets/dark/measurement.png +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/draw/assets/dark/measurement.svg +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/draw/assets/light/measurement.png +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/draw/assets/light/measurement.svg +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/draw/themes.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/draw/vizbackend.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/engines/__init__.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/engines/differentiable_backend.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/engines/jax/__init__.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/engines/jax/differentiable_backend.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/engines/jax/differentiable_expectation.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/engines/torch/__init__.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/engines/torch/differentiable_backend.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/engines/torch/differentiable_expectation.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/exceptions/__init__.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/exceptions/exceptions.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/libs.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/log_config.yaml +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/logger.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/measurements/__init__.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/measurements/protocols.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/measurements/samples.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/measurements/shadow.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/measurements/tomography.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/measurements/utils.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/mitigations/__init__.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/mitigations/analog_zne.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/mitigations/protocols.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/mitigations/readout.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/ml_tools/data.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/ml_tools/optimize_step.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/ml_tools/parameters.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/ml_tools/tensors.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/ml_tools/train_no_grad.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/noise/__init__.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/noise/protocols.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/noise/readout.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/operations/__init__.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/operations/analog.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/operations/control_ops.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/operations/parametric.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/operations/primitive.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/protocols.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/py.typed +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/qubit_support.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/serial_expr_grammar.peg +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/states.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/transpile/__init__.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/transpile/apply_fn.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/transpile/block.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/transpile/circuit.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/transpile/digitalize.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/transpile/flatten.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/transpile/invert.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/qadence/transpile/transpile.py +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/renovate.json +0 -0
- {qadence-1.6.3 → qadence-1.7.1}/setup.py +0 -0
@@ -1,8 +1,8 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: qadence
|
3
|
-
Version: 1.
|
3
|
+
Version: 1.7.1
|
4
4
|
Summary: Pasqal interface for circuit-based quantum computing SDKs
|
5
|
-
Author-email: Aleksander Wennersteen <aleksander.wennersteen@pasqal.com>, Gert-Jan Both <gert-jan.both@pasqal.com>, Niklas Heim <niklas.heim@pasqal.com>, Mario Dagrada <mario.dagrada@pasqal.com>, Vincent Elfving <vincent.elfving@pasqal.com>, Dominik Seitz <dominik.seitz@pasqal.com>, Roland Guichard <roland.guichard@pasqal.com>, "Joao P. Moutinho" <joao.moutinho@pasqal.com>, Vytautas Abramavicius <vytautas.abramavicius@pasqal.com>, Gergana Velikova <gergana.velikova@pasqal.com>, Eduardo Maschio <eduardo.maschio@pasqal.com>
|
5
|
+
Author-email: Aleksander Wennersteen <aleksander.wennersteen@pasqal.com>, Gert-Jan Both <gert-jan.both@pasqal.com>, Niklas Heim <niklas.heim@pasqal.com>, Mario Dagrada <mario.dagrada@pasqal.com>, Vincent Elfving <vincent.elfving@pasqal.com>, Dominik Seitz <dominik.seitz@pasqal.com>, Roland Guichard <roland.guichard@pasqal.com>, "Joao P. Moutinho" <joao.moutinho@pasqal.com>, Vytautas Abramavicius <vytautas.abramavicius@pasqal.com>, Gergana Velikova <gergana.velikova@pasqal.com>, Eduardo Maschio <eduardo.maschio@pasqal.com>, Smit Chaudhary <smit.chaudhary@pasqal.com>, Ignacio Fernández Graña <ignacio.fernandez-grana@pasqal.com>, Charles Moussa <charles.moussa@pasqal.com>
|
6
6
|
License: Apache 2.0
|
7
7
|
License-File: LICENSE
|
8
8
|
Classifier: License :: OSI Approved :: Apache Software License
|
@@ -22,10 +22,11 @@ Requires-Dist: matplotlib
|
|
22
22
|
Requires-Dist: nevergrad
|
23
23
|
Requires-Dist: numpy
|
24
24
|
Requires-Dist: openfermion
|
25
|
-
Requires-Dist: pyqtorch==1.2.
|
25
|
+
Requires-Dist: pyqtorch==1.2.5
|
26
26
|
Requires-Dist: pyyaml
|
27
27
|
Requires-Dist: rich
|
28
28
|
Requires-Dist: scipy
|
29
|
+
Requires-Dist: sympy<1.13
|
29
30
|
Requires-Dist: sympytorch>=0.1.2
|
30
31
|
Requires-Dist: tensorboard>=2.12.0
|
31
32
|
Requires-Dist: torch
|
@@ -53,9 +54,9 @@ Requires-Dist: qadence-libs; extra == 'libs'
|
|
53
54
|
Provides-Extra: protocols
|
54
55
|
Requires-Dist: qadence-protocols; extra == 'protocols'
|
55
56
|
Provides-Extra: pulser
|
56
|
-
Requires-Dist: pasqal-cloud==0.
|
57
|
-
Requires-Dist: pulser-core==0.
|
58
|
-
Requires-Dist: pulser-simulation==0.
|
57
|
+
Requires-Dist: pasqal-cloud==0.11.1; extra == 'pulser'
|
58
|
+
Requires-Dist: pulser-core==0.19.0; extra == 'pulser'
|
59
|
+
Requires-Dist: pulser-simulation==0.19.0; extra == 'pulser'
|
59
60
|
Provides-Extra: visualization
|
60
61
|
Requires-Dist: graphviz; extra == 'visualization'
|
61
62
|
Description-Content-Type: text/markdown
|
@@ -16,6 +16,7 @@ nav:
|
|
16
16
|
- Contents:
|
17
17
|
- Block system: content/block_system.md
|
18
18
|
- Parametric programs: content/parameters.md
|
19
|
+
- Time-dependent generators: content/time_dependent.md
|
19
20
|
- Quantum models: content/quantummodels.md
|
20
21
|
- Quantum registers: content/register.md
|
21
22
|
- State initialization: content/state_init.md
|
@@ -32,15 +33,17 @@ nav:
|
|
32
33
|
- Digital-analog quantum computing:
|
33
34
|
- tutorials/digital_analog_qc/index.md
|
34
35
|
- Basic operations on neutral-atoms: tutorials/digital_analog_qc/analog-basics.md
|
35
|
-
- Fitting a
|
36
|
+
- Fitting a function with analog blocks: tutorials/digital_analog_qc/analog-blocks-qcl.md
|
36
37
|
- Restricted local addressability: tutorials/digital_analog_qc/semi-local-addressing.md
|
37
38
|
- Pulse-level programming with Pulser: tutorials/digital_analog_qc/pulser-basic.md
|
39
|
+
- Fitting a function with a Hamiltonian ansatz: tutorials/digital_analog_qc/digital-analog-qcl.md
|
38
40
|
- Solve a QUBO problem: tutorials/digital_analog_qc/analog-qubo.md
|
39
41
|
- CNOT with interacting qubits: tutorials/digital_analog_qc/daqc-cnot.md
|
40
42
|
|
41
43
|
- Variational quantum algorithms:
|
42
44
|
- tutorials/qml/index.md
|
43
45
|
- Training tools: tutorials/qml/ml_tools.md
|
46
|
+
- Configuring a QNN: tutorials/qml/config_qnn.md
|
44
47
|
- Quantum circuit learning: tutorials/qml/qcl.md
|
45
48
|
- Solving MaxCut with QAOA: tutorials/qml/qaoa.md
|
46
49
|
- Solving a 1D ODE: tutorials/qml/dqc_1d.md
|
@@ -14,15 +14,18 @@ authors = [
|
|
14
14
|
{ name = "Vincent Elfving", email = "vincent.elfving@pasqal.com" },
|
15
15
|
{ name = "Dominik Seitz", email = "dominik.seitz@pasqal.com" },
|
16
16
|
{ name = "Roland Guichard", email = "roland.guichard@pasqal.com" },
|
17
|
-
{ name = "Joao P. Moutinho", email = "joao.moutinho@pasqal.com"},
|
17
|
+
{ name = "Joao P. Moutinho", email = "joao.moutinho@pasqal.com" },
|
18
18
|
{ name = "Vytautas Abramavicius", email = "vytautas.abramavicius@pasqal.com" },
|
19
19
|
{ name = "Gergana Velikova", email = "gergana.velikova@pasqal.com" },
|
20
20
|
{ name = "Eduardo Maschio", email = "eduardo.maschio@pasqal.com" },
|
21
|
+
{ name = "Smit Chaudhary", email = "smit.chaudhary@pasqal.com" },
|
22
|
+
{ name = "Ignacio Fernández Graña", email = "ignacio.fernandez-grana@pasqal.com" },
|
23
|
+
{ name = "Charles Moussa", email = "charles.moussa@pasqal.com" },
|
21
24
|
]
|
22
25
|
requires-python = ">=3.9"
|
23
|
-
license = {text = "Apache 2.0"}
|
24
|
-
version = "1.
|
25
|
-
classifiers=[
|
26
|
+
license = { text = "Apache 2.0" }
|
27
|
+
version = "1.7.1"
|
28
|
+
classifiers = [
|
26
29
|
"License :: OSI Approved :: Apache Software License",
|
27
30
|
"Programming Language :: Python",
|
28
31
|
"Programming Language :: Python :: 3",
|
@@ -37,6 +40,7 @@ dependencies = [
|
|
37
40
|
"numpy",
|
38
41
|
"torch",
|
39
42
|
"openfermion",
|
43
|
+
"sympy<1.13",
|
40
44
|
"sympytorch>=0.1.2",
|
41
45
|
"rich",
|
42
46
|
"tensorboard>=2.12.0",
|
@@ -44,7 +48,7 @@ dependencies = [
|
|
44
48
|
"jsonschema",
|
45
49
|
"nevergrad",
|
46
50
|
"scipy",
|
47
|
-
"pyqtorch==1.2.
|
51
|
+
"pyqtorch==1.2.5",
|
48
52
|
"pyyaml",
|
49
53
|
"matplotlib",
|
50
54
|
"Arpeggio==2.0.2",
|
@@ -55,13 +59,17 @@ allow-direct-references = true
|
|
55
59
|
allow-ambiguous-features = true
|
56
60
|
|
57
61
|
[project.optional-dependencies]
|
58
|
-
pulser = [
|
62
|
+
pulser = [
|
63
|
+
"pulser-core==0.19.0",
|
64
|
+
"pulser-simulation==0.19.0",
|
65
|
+
"pasqal-cloud==0.11.1",
|
66
|
+
]
|
59
67
|
braket = ["amazon-braket-sdk<1.71.2"]
|
60
68
|
visualization = [
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
69
|
+
"graphviz",
|
70
|
+
# FIXME: will be needed once we support latex labels
|
71
|
+
# "latex2svg @ git+https://github.com/Moonbase59/latex2svg.git#egg=latex2svg",
|
72
|
+
# "scour",
|
65
73
|
]
|
66
74
|
horqrux = [
|
67
75
|
"horqrux==0.6.0",
|
@@ -70,35 +78,31 @@ horqrux = [
|
|
70
78
|
"optax",
|
71
79
|
"jaxopt",
|
72
80
|
"einops",
|
73
|
-
"sympy2jax"
|
81
|
+
"sympy2jax",
|
82
|
+
]
|
74
83
|
protocols = ["qadence-protocols"]
|
75
84
|
libs = ["qadence-libs"]
|
76
85
|
dlprof = ["nvidia-pyindex", "nvidia-dlprof[pytorch]"]
|
77
|
-
all = [
|
78
|
-
|
79
|
-
"braket",
|
80
|
-
"visualization",
|
81
|
-
"protocols",
|
82
|
-
"libs",
|
83
|
-
]
|
86
|
+
all = ["pulser", "braket", "visualization", "protocols", "libs"]
|
87
|
+
|
84
88
|
|
85
89
|
[tool.hatch.envs.default]
|
86
90
|
dependencies = [
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
91
|
+
"flaky",
|
92
|
+
"hypothesis",
|
93
|
+
"pytest",
|
94
|
+
"pytest-cov",
|
95
|
+
"pytest-mypy",
|
96
|
+
"pytest-xdist",
|
97
|
+
"types-PyYAML",
|
98
|
+
"ipykernel",
|
99
|
+
"pre-commit",
|
100
|
+
"black",
|
101
|
+
"isort",
|
102
|
+
"ruff",
|
103
|
+
"pydocstringformatter",
|
100
104
|
]
|
101
|
-
features = ["pulser", "braket","visualization", "horqrux"]
|
105
|
+
features = ["pulser", "braket", "visualization", "horqrux"]
|
102
106
|
|
103
107
|
[tool.hatch.envs.default.scripts]
|
104
108
|
test = "pytest -n auto --cov-report lcov --cov-config=pyproject.toml --cov=qadence --cov=tests --ignore=./tests/test_examples.py {args}"
|
@@ -108,34 +112,32 @@ test-docs = "mkdocs build --clean --strict"
|
|
108
112
|
test-all = "pytest -n auto {args} && mkdocs build --clean --strict"
|
109
113
|
|
110
114
|
[tool.pytest.ini_options]
|
111
|
-
markers = [
|
112
|
-
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
113
|
-
]
|
115
|
+
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
|
114
116
|
testpaths = ["tests"]
|
115
117
|
addopts = """-vvv"""
|
116
118
|
xfail_strict = true
|
117
119
|
filterwarnings = [
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
120
|
+
"ignore:Call to deprecated create function FieldDescriptor",
|
121
|
+
"ignore:Call to deprecated create function Descriptor",
|
122
|
+
"ignore:Call to deprecated create function EnumDescriptor",
|
123
|
+
"ignore:Call to deprecated create function EnumValueDescriptor",
|
124
|
+
"ignore:Call to deprecated create function FileDescriptor",
|
125
|
+
"ignore:Call to deprecated create function OneofDescriptor",
|
126
|
+
"ignore:distutils Version classes are deprecated.",
|
127
|
+
"ignore::DeprecationWarning",
|
126
128
|
]
|
127
129
|
|
128
130
|
|
129
131
|
[tool.hatch.envs.docs]
|
130
132
|
dependencies = [
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
133
|
+
"mkdocs",
|
134
|
+
"mkdocs-material",
|
135
|
+
"mkdocstrings",
|
136
|
+
"mkdocstrings-python",
|
137
|
+
"mkdocs-section-index",
|
138
|
+
"mkdocs-exclude",
|
139
|
+
"markdown-exec",
|
140
|
+
"mike",
|
139
141
|
]
|
140
142
|
features = ["pulser", "braket", "horqrux", "visualization"]
|
141
143
|
|
@@ -151,12 +153,12 @@ features = ["all"]
|
|
151
153
|
|
152
154
|
[tool.hatch.build.targets.sdist]
|
153
155
|
exclude = [
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
156
|
+
"/.gitignore",
|
157
|
+
"/.gitlab-ci-yml",
|
158
|
+
"/.pre-commit-config.yml",
|
159
|
+
"/tests",
|
160
|
+
"/docs",
|
161
|
+
"/examples",
|
160
162
|
]
|
161
163
|
|
162
164
|
[tool.hatch.build.targets.wheel]
|
@@ -167,15 +169,11 @@ branch = true
|
|
167
169
|
parallel = true
|
168
170
|
|
169
171
|
[tool.coverage.report]
|
170
|
-
exclude_lines = [
|
171
|
-
"no cov",
|
172
|
-
"if __name__ == .__main__.:",
|
173
|
-
"if TYPE_CHECKING:",
|
174
|
-
]
|
172
|
+
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]
|
175
173
|
|
176
174
|
[tool.ruff]
|
177
175
|
select = ["E", "F", "I", "Q"]
|
178
|
-
extend-ignore = ["F841","F403"]
|
176
|
+
extend-ignore = ["F841", "F403"]
|
179
177
|
line-length = 100
|
180
178
|
|
181
179
|
[tool.ruff.isort]
|
@@ -183,8 +181,12 @@ required-imports = ["from __future__ import annotations"]
|
|
183
181
|
|
184
182
|
[tool.ruff.per-file-ignores]
|
185
183
|
"__init__.py" = ["F401", "E402"]
|
186
|
-
"qadence/operations/primitive.py" = [
|
187
|
-
"
|
184
|
+
"qadence/operations/primitive.py" = [
|
185
|
+
"E742",
|
186
|
+
] # Avoid ambiguous class name warning for identity.
|
187
|
+
"qadence/backends/horqrux/convert_ops.py" = [
|
188
|
+
"E741",
|
189
|
+
] # Avoid ambiguous class name warning for 0.
|
188
190
|
"examples/*" = ["E402"] # Allow torch seed to be set before qadence imports
|
189
191
|
|
190
192
|
[tool.ruff.mccabe]
|
@@ -49,7 +49,7 @@ from .exceptions import *
|
|
49
49
|
from .execution import *
|
50
50
|
from .measurements import *
|
51
51
|
from .ml_tools import *
|
52
|
-
from .
|
52
|
+
from .model import *
|
53
53
|
from .noise import *
|
54
54
|
from .operations import *
|
55
55
|
from .overlap import *
|
@@ -82,7 +82,7 @@ list_of_submodules = [
|
|
82
82
|
".execution",
|
83
83
|
".measurements",
|
84
84
|
".ml_tools",
|
85
|
-
".
|
85
|
+
".model",
|
86
86
|
".operations",
|
87
87
|
".overlap",
|
88
88
|
".parameters",
|
@@ -0,0 +1,67 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from qadence.backend import Backend, BackendConfiguration
|
4
|
+
from qadence.engines.differentiable_backend import DifferentiableBackend
|
5
|
+
from qadence.extensions import (
|
6
|
+
import_backend,
|
7
|
+
import_config,
|
8
|
+
import_engine,
|
9
|
+
set_backend_config,
|
10
|
+
)
|
11
|
+
from qadence.logger import get_logger
|
12
|
+
from qadence.types import BackendName, DiffMode
|
13
|
+
|
14
|
+
__all__ = ["backend_factory", "config_factory"]
|
15
|
+
logger = get_logger(__name__)
|
16
|
+
|
17
|
+
|
18
|
+
def backend_factory(
    backend: BackendName | str,
    diff_mode: DiffMode | str | None = None,
    configuration: BackendConfiguration | dict | None = None,
) -> Backend | DifferentiableBackend:
    """Instantiate a quantum backend, optionally wrapped in a differentiable engine.

    Arguments:
        backend: Name of the backend to import.
        diff_mode: Differentiation mode. When given, the backend instance is wrapped
            in the differentiable backend class of its engine.
        configuration: Either an instance of the backend's configuration class, a
            dict of options forwarded to `config_factory`, or `None` to use the
            backend's default configuration.

    Returns:
        The backend instance, or a `DifferentiableBackend` wrapping it when
        `diff_mode` is not `None`.

    Raises:
        Exception: if anything fails while importing, configuring or instantiating
            the backend. The original error is chained as `__cause__`.
    """
    backend_inst: Backend | DifferentiableBackend
    try:
        BackendCls = import_backend(backend)
        default_config = BackendCls.default_configuration()
        if configuration is None:
            configuration = default_config
        elif isinstance(configuration, dict):
            configuration = config_factory(backend, configuration)
        elif not isinstance(configuration, type(default_config)):
            # The given config must be an instance of the backend's own config class
            # (or a subclass of it, since `isinstance` is used).
            raise ValueError(
                f"Given config class '{type(configuration)}' does not match the backend"
                f" class: '{BackendCls}'. Expected: '{type(default_config)}.'"
            )

        # Instantiate the backend; `configuration` is always set at this point.
        backend_inst = BackendCls(config=configuration)  # type: ignore[operator]
        # Set backend configurations which depend on the differentiation mode.
        set_backend_config(backend_inst, diff_mode)
        # Wrap the quantum Backend in a DifferentiableBackend if a diff_mode is passed.
        if diff_mode is not None:
            diff_backend_cls = import_engine(backend_inst.engine)
            backend_inst = diff_backend_cls(backend=backend_inst, diff_mode=DiffMode(diff_mode))  # type: ignore[operator]
        return backend_inst
    except Exception as e:
        # Implicit string concatenation avoids embedding source indentation in the message.
        msg = (
            f"The requested backend '{backend}' is either not installed"
            f" or could not be imported due to {e}."
        )
        logger.error(msg)
        raise Exception(msg) from e
|
58
|
+
|
59
|
+
|
60
|
+
def config_factory(backend_name: BackendName | str, config: dict) -> BackendConfiguration:
    """Build a backend configuration object from a dict of options.

    Arguments:
        backend_name: Name of the backend whose configuration class is imported.
        config: Keyword arguments forwarded to the configuration class constructor.

    Returns:
        An instance of the backend's configuration class.

    Raises:
        Exception: whatever `import_config` or the configuration constructor raises
            is logged at debug level and re-raised unchanged (previously the failure
            surfaced as an unrelated `UnboundLocalError` on the return statement).
    """
    try:
        BackendConfigCls = import_config(backend_name)
        return BackendConfigCls(**config)  # type: ignore[operator]
    except Exception as e:
        logger.debug(f"Unable to import config for backend {backend_name} due to {e}.")
        raise
|
@@ -12,6 +12,7 @@ from torch import Tensor
|
|
12
12
|
from qadence.backend import Backend as BackendInterface
|
13
13
|
from qadence.backend import ConvertedCircuit, ConvertedObservable
|
14
14
|
from qadence.backends.utils import (
|
15
|
+
infer_batchsize,
|
15
16
|
pyqify,
|
16
17
|
to_list_of_dicts,
|
17
18
|
unpyqify,
|
@@ -31,7 +32,6 @@ from qadence.transpile import (
|
|
31
32
|
transpile,
|
32
33
|
)
|
33
34
|
from qadence.types import BackendName, Endianness, Engine
|
34
|
-
from qadence.utils import infer_batchsize
|
35
35
|
|
36
36
|
from .config import Configuration, default_passes
|
37
37
|
from .convert_ops import convert_block
|
@@ -165,7 +165,6 @@ class Backend(BackendInterface):
|
|
165
165
|
"Looping expectation does not make sense with batched initial state. "
|
166
166
|
"Define your initial state with `batch_size=1`"
|
167
167
|
)
|
168
|
-
|
169
168
|
list_expvals = []
|
170
169
|
observables = observable if isinstance(observable, list) else [observable]
|
171
170
|
for vals in to_list_of_dicts(param_values):
|
@@ -4,6 +4,8 @@ from dataclasses import dataclass
|
|
4
4
|
from logging import getLogger
|
5
5
|
from typing import Callable
|
6
6
|
|
7
|
+
from pyqtorch.utils import SolverType
|
8
|
+
|
7
9
|
from qadence.analog import add_background_hamiltonian
|
8
10
|
from qadence.backend import BackendConfiguration
|
9
11
|
from qadence.transpile import (
|
@@ -41,6 +43,9 @@ class Configuration(BackendConfiguration):
|
|
41
43
|
algo_hevo: AlgoHEvo = AlgoHEvo.EXP
|
42
44
|
"""Determine which kind of Hamiltonian evolution algorithm to use."""
|
43
45
|
|
46
|
+
ode_solver: SolverType = SolverType.DP5_SE
|
47
|
+
"""Determine which ODE solver to use for time-dependent blocks."""
|
48
|
+
|
44
49
|
n_steps_hevo: int = 100
|
45
50
|
"""Default number of steps for the Hamiltonian evolution."""
|
46
51
|
|
@@ -6,8 +6,10 @@ from typing import Any, Sequence, Tuple
|
|
6
6
|
|
7
7
|
import pyqtorch as pyq
|
8
8
|
import sympy
|
9
|
+
import torch
|
9
10
|
from pyqtorch.apply import apply_operator
|
10
11
|
from pyqtorch.matrices import _dagger
|
12
|
+
from pyqtorch.time_dependent.sesolve import sesolve
|
11
13
|
from pyqtorch.utils import is_diag
|
12
14
|
from torch import (
|
13
15
|
Tensor,
|
@@ -26,6 +28,8 @@ from torch.nn import Module
|
|
26
28
|
|
27
29
|
from qadence.backends.utils import (
|
28
30
|
finitediff,
|
31
|
+
pyqify,
|
32
|
+
unpyqify,
|
29
33
|
)
|
30
34
|
from qadence.blocks import (
|
31
35
|
AbstractBlock,
|
@@ -38,8 +42,12 @@ from qadence.blocks import (
|
|
38
42
|
ScaleBlock,
|
39
43
|
TimeEvolutionBlock,
|
40
44
|
)
|
41
|
-
from qadence.blocks.block_to_tensor import
|
45
|
+
from qadence.blocks.block_to_tensor import (
|
46
|
+
_block_to_tensor_embedded,
|
47
|
+
block_to_tensor,
|
48
|
+
)
|
42
49
|
from qadence.blocks.primitive import ProjectorBlock
|
50
|
+
from qadence.blocks.utils import parameters
|
43
51
|
from qadence.operations import (
|
44
52
|
U,
|
45
53
|
multi_qubit_gateset,
|
@@ -177,6 +185,7 @@ class PyQHamiltonianEvolution(Module):
|
|
177
185
|
self.param_names = config.get_param_name(block)
|
178
186
|
self.block = block
|
179
187
|
self.hmat: Tensor
|
188
|
+
self.config = config
|
180
189
|
|
181
190
|
if isinstance(block.generator, AbstractBlock) and not block.generator.is_parametric:
|
182
191
|
hmat = block_to_tensor(
|
@@ -253,7 +262,8 @@ class PyQHamiltonianEvolution(Module):
|
|
253
262
|
"""Approximate jacobian of the evolved operator with respect to time evolution."""
|
254
263
|
return finitediff(
|
255
264
|
lambda t: self._unitary(time_evolution=t, hamiltonian=self._hamiltonian(self, values)),
|
256
|
-
values[self.param_names[0]],
|
265
|
+
values[self.param_names[0]].reshape(-1, 1),
|
266
|
+
(0,),
|
257
267
|
)
|
258
268
|
|
259
269
|
def jacobian_generator(self, values: dict[str, Tensor]) -> Tensor:
|
@@ -280,25 +290,88 @@ class PyQHamiltonianEvolution(Module):
|
|
280
290
|
lambda v: self._unitary(
|
281
291
|
time_evolution=self._time_evolution(values), hamiltonian=_generator(v)
|
282
292
|
),
|
283
|
-
values[self.param_names[1]],
|
293
|
+
values[self.param_names[1]].reshape(-1, 1),
|
294
|
+
(0,),
|
284
295
|
)
|
285
296
|
|
286
297
|
def dagger(self, values: dict[str, Tensor]) -> Tensor:
    """Return the conjugate transpose of the evolved operator at the given parameter values."""
    evolved_operator = self.unitary(values)
    return _dagger(evolved_operator)
|
289
300
|
|
301
|
+
def _get_time_parameter(self) -> str:
    """Return the name of the single time parameter used by the generator block.

    Collects `str(p)` for every parameter `p` of `self.block.generator` whose
    `is_time` attribute is truthy.

    Returns:
        The name of the unique time parameter.

    Raises:
        Exception: if the generator contains more than one distinct time parameter,
            or none at all.
    """
    unique_time_params = {
        str(p)
        for p in parameters(self.block.generator)  # type: ignore [arg-type]
        if getattr(p, "is_time", False)
    }

    if len(unique_time_params) > 1:
        raise Exception("Only a single time parameter is supported.")
    if not unique_time_params:
        # Without this check `set.pop()` would raise an opaque
        # `KeyError: 'pop from an empty set'`.
        raise Exception("No time parameter found in the generator block.")

    return unique_time_params.pop()
|
312
|
+
|
290
313
|
def forward(
    self,
    state: Tensor,
    values: dict[str, Tensor],
) -> Tensor:
    """Evolve `state` under the block's Hamiltonian.

    If the generator reports `is_time_dependent`, the evolution is obtained by integrating
    the time-dependent generator matrix `Ht(t)` with `sesolve`. The integration uses
    `self.config.n_steps_hevo` time points spanning `[0, self.block.duration]`, and only
    the final state is kept. Otherwise the operator returned by `self.unitary(values)` is
    applied to `state` with `apply_operator`.

    Args:
        state: The input state.
        values: Mapping from parameter-expression strings to their values. When a
            time-dependent expression also contains feature parameters, their values
            are read from `values["orig_param_values"]`, indexed by symbol name.

    Returns:
        The evolved state.
    """
    if getattr(self.block.generator, "is_time_dependent", False):  # type: ignore [union-attr]
        # `Ht` is called by the ODE solver at every integration step. The time symbol and
        # the parsed parameter expressions do not depend on `t`, so compute them once here
        # rather than on every call and for every key.
        t_symb = sympy.Symbol(self._get_time_parameter())
        parsed_exprs = {str_expr: sympy.sympify(str_expr) for str_expr in values}

        def Ht(t: Tensor | float) -> Tensor:
            """Return the generator matrix evaluated at time `t`."""
            # Rebuild the values dict for this `t`. Expressions containing the time
            # parameter are re-evaluated; all other entries are copied unchanged.
            new_vals = dict()
            for str_expr, val in values.items():
                expr = parsed_exprs[str_expr]
                free_symbols = expr.free_symbols
                if t_symb not in free_symbols:
                    new_vals[str_expr] = val
                    continue

                subs_list = [(t_symb, t)]
                if len(free_symbols) > 1:
                    # The remaining symbols are feature parameters. Substitute their
                    # values taken from the "orig_param_values" entry.
                    feat_vals = values["orig_param_values"]
                    for fs in free_symbols - {t_symb}:
                        subs_list.append((fs, feat_vals[str(fs)]))

                # Evaluate the expression at the new time value.
                new_vals[str_expr] = torch.tensor(float(expr.subs(subs_list)))

            # Matrix form of the generator for this time step.
            return _block_to_tensor_embedded(
                self.block.generator,  # type: ignore[arg-type]
                values=new_vals,
                qubit_support=self.qubit_support,
                use_full_support=False,
                device=self.device,
            ).squeeze(0)

        tsave = torch.linspace(0, self.block.duration, self.config.n_steps_hevo)  # type: ignore [attr-defined]
        # NOTE(review): only column `[:, 0:1]` of the transposed state is integrated.
        # This looks like batched input states beyond the first are dropped; confirm.
        result = pyqify(
            sesolve(Ht, unpyqify(state).T[:, 0:1], tsave, self.config.ode_solver).states[-1].T
        )
    else:
        result = apply_operator(
            state,
            self.unitary(values),
            self.qubit_support,
            self.n_qubits,
            self.batch_size,
        )

    return result
|
302
375
|
|
303
376
|
@property
|
304
377
|
def device(self) -> torch_device:
|
@@ -20,8 +20,8 @@ from torch import (
|
|
20
20
|
rand,
|
21
21
|
)
|
22
22
|
|
23
|
-
from qadence.types import ParamDictType
|
24
|
-
from qadence.utils import
|
23
|
+
from qadence.types import Endianness, ParamDictType
|
24
|
+
from qadence.utils import int_to_basis, is_qadence_shape
|
25
25
|
|
26
26
|
FINITE_DIFF_EPS = 1e-06
|
27
27
|
# Dict of NumPy dtype -> torch dtype (when the correspondence exists)
|
@@ -98,10 +98,11 @@ def to_list_of_dicts(param_values: ParamDictType) -> list[ParamDictType]:
|
|
98
98
|
if not param_values:
|
99
99
|
return [param_values]
|
100
100
|
|
101
|
-
max_batch_size = max(p.size()[0] for p in param_values.values())
|
101
|
+
max_batch_size = max(p.size()[0] for p in param_values.values() if isinstance(p, Tensor))
|
102
102
|
batched_values = {
|
103
103
|
k: (v if v.size()[0] == max_batch_size else v.repeat(max_batch_size, 1))
|
104
104
|
for k, v in param_values.items()
|
105
|
+
if isinstance(v, Tensor)
|
105
106
|
}
|
106
107
|
|
107
108
|
return [{k: v[i] for k, v in batched_values.items()} for i in range(max_batch_size)]
|
@@ -143,17 +144,71 @@ def validate_state(state: Tensor, n_qubits: int) -> None:
|
|
143
144
|
)
|
144
145
|
|
145
146
|
|
146
|
-
def infer_batchsize(param_values:
|
147
|
+
def infer_batchsize(param_values: dict[str, Tensor] = None) -> int:
    """Infer the batch_size through the length of the parameter tensors.

    Only `Tensor` values are considered; non-tensor entries (e.g. nested dicts) are
    ignored. The batch size of a tensor is the size of its first dimension. A 0-d
    (scalar) tensor counts as batch size 1 instead of aborting the whole inference.

    Args:
        param_values: Mapping from parameter names to values. May be `None` or empty.

    Returns:
        The largest first-dimension size among the tensor values, or 1 when there
        are no tensor values.
    """
    if not param_values:
        return 1

    batch_sizes = (
        value.shape[0] if value.ndim > 0 else 1
        for value in param_values.values()
        if isinstance(value, Tensor)
    )
    # `default` covers dicts that contain no tensors at all.
    return max(batch_sizes, default=1)
|
149
163
|
|
150
164
|
|
151
165
|
# The following functions can be used to compute potentially higher order gradients using pyqtorch's
|
152
166
|
# native 'jacobian' methods.
|
153
167
|
|
154
168
|
|
155
|
-
def finitediff(
|
156
|
-
|
169
|
+
def finitediff(
    f: Callable,
    x: Tensor,
    derivative_indices: tuple[int, ...],
    eps: float = None,
) -> Tensor:
    """
    Compute the finite-difference derivative of a function at a point.

    Args:
        f: The function to differentiate. It is evaluated on tensors shaped like `x`.
        x: Input of size `(batch_size, input_size)`.
        derivative_indices: Which *inputs* to differentiate with respect to, i.e. which
            columns `x[:, i]`, with one entry per derivative order. For example, `(0,)` is
            d/dx0, `(0, 0)` is d²/dx0² and `(0, 1)` is the mixed derivative d²/dx0dx1.
        eps: finite difference spacing. The default is
            `torch.finfo(x.dtype).eps ** (1 / (2 + order))` with
            `order = len(derivative_indices)`. An explicit value is used at every level
            of a recursive (mixed or higher-than-third-order) derivative.

    Returns:
        (Tensor): The finite difference of the function at the point `x`.

    Raises:
        ValueError: if `derivative_indices` is empty.
    """
    if not derivative_indices:
        raise ValueError("`derivative_indices` must contain at least one index.")

    order = len(derivative_indices)
    if eps is None:
        step = torch.finfo(x.dtype).eps ** (1 / (2 + order))
    else:
        step = eps

    # Displacement vector along the first column being differentiated.
    step = torch.as_tensor(step, dtype=x.dtype)
    inv_step = 1 / step  # type: ignore[operator]
    ev = torch.zeros_like(x)
    ev[:, derivative_indices[0]] += step

    # Mixed derivatives and orders above 3 are built recursively: a central difference
    # along the first index, applied to the derivative over the remaining indices.
    # The caller's `eps` (possibly None) is forwarded so an explicit spacing is honoured
    # at every level. With None, each level keeps choosing its own default.
    if order > 3 or len(set(derivative_indices)) > 1:
        rest = derivative_indices[1:]
        forward_diff = finitediff(f, x + ev, rest, eps)
        backward_diff = finitediff(f, x - ev, rest, eps)
        return (forward_diff - backward_diff) * inv_step / 2
    if order == 3:
        return (f(x + 2 * ev) - 2 * f(x + ev) + 2 * f(x - ev) - f(x - 2 * ev)) * inv_step**3 / 2
    if order == 2:
        return (f(x + ev) + f(x - ev) - 2 * f(x)) * inv_step**2
    # order == 1: first-order central difference.
    return (f(x + ev) - f(x - ev)) * inv_step / 2
|
157
212
|
|
158
213
|
|
159
214
|
def finitediff_sampling(
|