qadence 1.10.3__tar.gz → 1.11.0__tar.gz
This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
- {qadence-1.10.3 → qadence-1.11.0}/PKG-INFO +14 -11
- {qadence-1.10.3 → qadence-1.11.0}/README.md +9 -6
- {qadence-1.10.3 → qadence-1.11.0}/mkdocs.yml +9 -3
- {qadence-1.10.3 → qadence-1.11.0}/pyproject.toml +5 -5
- {qadence-1.10.3 → qadence-1.11.0}/qadence/blocks/block_to_tensor.py +21 -24
- {qadence-1.10.3 → qadence-1.11.0}/qadence/constructors/__init__.py +7 -1
- {qadence-1.10.3 → qadence-1.11.0}/qadence/constructors/hamiltonians.py +96 -9
- {qadence-1.10.3 → qadence-1.11.0}/qadence/mitigations/analog_zne.py +6 -2
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/__init__.py +2 -2
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/callbacks/callback.py +80 -50
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/callbacks/callbackmanager.py +3 -2
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/callbacks/writer_registry.py +3 -2
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/config.py +66 -5
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/constructors.py +9 -62
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/data.py +4 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/models.py +69 -4
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/optimize_step.py +1 -2
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/train_utils/__init__.py +3 -1
- qadence-1.11.0/qadence/ml_tools/train_utils/accelerator.py +480 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/train_utils/config_manager.py +7 -7
- qadence-1.11.0/qadence/ml_tools/train_utils/distribution.py +209 -0
- qadence-1.11.0/qadence/ml_tools/train_utils/execution.py +421 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/trainer.py +188 -100
- {qadence-1.10.3 → qadence-1.11.0}/qadence/types.py +7 -11
- {qadence-1.10.3 → qadence-1.11.0}/qadence/utils.py +45 -0
- {qadence-1.10.3 → qadence-1.11.0}/renovate.json +0 -1
- {qadence-1.10.3 → qadence-1.11.0}/.coveragerc +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/.github/ISSUE_TEMPLATE/bug-report.yml +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/.github/ISSUE_TEMPLATE/config.yml +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/.github/ISSUE_TEMPLATE/new-feature.yml +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/.github/workflows/build_docs.yml +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/.github/workflows/lint.yml +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/.github/workflows/test_all.yml +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/.github/workflows/test_examples.yml +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/.github/workflows/test_fast.yml +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/.gitignore +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/.pre-commit-config.yaml +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/LICENSE +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/MANIFEST.in +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/analog/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/analog/addressing.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/analog/constants.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/analog/device.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/analog/hamiltonian_terms.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/analog/parse_analog.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backend.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/api.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/gpsr.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/horqrux/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/horqrux/backend.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/horqrux/config.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/horqrux/convert_ops.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/jax_utils.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/pulser/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/pulser/backend.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/pulser/channels.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/pulser/cloud.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/pulser/config.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/pulser/convert_ops.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/pulser/devices.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/pulser/pulses.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/pulser/waveforms.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/pyqtorch/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/pyqtorch/backend.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/pyqtorch/config.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/pyqtorch/convert_ops.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/backends/utils.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/blocks/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/blocks/abstract.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/blocks/analog.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/blocks/composite.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/blocks/embedding.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/blocks/manipulate.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/blocks/matrix.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/blocks/primitive.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/blocks/utils.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/circuit.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/constructors/ala.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/constructors/daqc/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/constructors/daqc/daqc.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/constructors/daqc/gen_parser.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/constructors/daqc/utils.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/constructors/feature_maps.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/constructors/hea.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/constructors/iia.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/constructors/qft.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/constructors/rydberg_feature_maps.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/constructors/rydberg_hea.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/constructors/utils.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/decompose.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/divergences.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/draw/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/draw/assets/dark/measurement.png +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/draw/assets/dark/measurement.svg +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/draw/assets/light/measurement.png +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/draw/assets/light/measurement.svg +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/draw/themes.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/draw/utils.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/draw/vizbackend.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/engines/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/engines/differentiable_backend.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/engines/jax/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/engines/jax/differentiable_backend.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/engines/jax/differentiable_expectation.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/engines/torch/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/engines/torch/differentiable_backend.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/engines/torch/differentiable_expectation.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/exceptions/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/exceptions/exceptions.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/execution.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/extensions.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/libs.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/log_config.yaml +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/logger.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/measurements/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/measurements/protocols.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/measurements/samples.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/measurements/shadow.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/measurements/tomography.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/measurements/utils.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/mitigations/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/mitigations/protocols.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/mitigations/readout.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/callbacks/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/callbacks/saveload.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/information/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/information/information_content.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/loss/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/loss/loss.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/parameters.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/stages.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/tensors.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/train_utils/base_trainer.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/utils.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/model.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/noise/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/noise/protocols.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/operations/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/operations/analog.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/operations/control_ops.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/operations/ham_evo.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/operations/parametric.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/operations/primitive.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/overlap.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/parameters.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/pasqal_cloud_connection.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/protocols.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/py.typed +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/qubit_support.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/register.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/serial_expr_grammar.peg +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/serialization.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/states.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/transpile/__init__.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/transpile/apply_fn.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/transpile/block.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/transpile/circuit.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/transpile/digitalize.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/transpile/flatten.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/transpile/invert.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/transpile/noise.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/qadence/transpile/transpile.py +0 -0
- {qadence-1.10.3 → qadence-1.11.0}/setup.py +0 -0
{qadence-1.10.3 → qadence-1.11.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: qadence
-Version: 1.10.3
+Version: 1.11.0
 Summary: Pasqal interface for circuit-based quantum computing SDKs
 Author-email: Aleksander Wennersteen <aleksander.wennersteen@pasqal.com>, Gert-Jan Both <gert-jan.both@pasqal.com>, Niklas Heim <niklas.heim@pasqal.com>, Mario Dagrada <mario.dagrada@pasqal.com>, Vincent Elfving <vincent.elfving@pasqal.com>, Dominik Seitz <dominik.seitz@pasqal.com>, Roland Guichard <roland.guichard@pasqal.com>, "Joao P. Moutinho" <joao.moutinho@pasqal.com>, Vytautas Abramavicius <vytautas.abramavicius@pasqal.com>, Gergana Velikova <gergana.velikova@pasqal.com>, Eduardo Maschio <eduardo.maschio@pasqal.com>, Smit Chaudhary <smit.chaudhary@pasqal.com>, Ignacio Fernández Graña <ignacio.fernandez-grana@pasqal.com>, Charles Moussa <charles.moussa@pasqal.com>, Giorgio Tosti Balducci <giorgio.tosti-balducci@pasqal.com>, Daniele Cucurachi <daniele.cucurachi@pasqal.com>, Pim Venderbosch <pim.venderbosch@pasqal.com>, Manu Lahariya <manu.lahariya@pasqal.com>
 License: Apache 2.0
@@ -23,7 +23,7 @@ Requires-Dist: nevergrad
 Requires-Dist: numpy
 Requires-Dist: openfermion
 Requires-Dist: pasqal-cloud
-Requires-Dist: pyqtorch==1.7.0
+Requires-Dist: pyqtorch==1.7.1
 Requires-Dist: pyyaml
 Requires-Dist: rich
 Requires-Dist: scipy
@@ -55,9 +55,9 @@ Requires-Dist: mlflow; extra == 'mlflow'
 Provides-Extra: protocols
 Requires-Dist: qadence-protocols; extra == 'protocols'
 Provides-Extra: pulser
-Requires-Dist: pasqal-cloud==0.
-Requires-Dist: pulser-core==1.
-Requires-Dist: pulser-simulation==1.
+Requires-Dist: pasqal-cloud==0.13.0; extra == 'pulser'
+Requires-Dist: pulser-core==1.3.0; extra == 'pulser'
+Requires-Dist: pulser-simulation==1.3.0; extra == 'pulser'
 Provides-Extra: visualization
 Requires-Dist: graphviz; extra == 'visualization'
 Description-Content-Type: text/markdown
@@ -202,12 +202,15 @@ Users also report problems running Hatch on Windows, we suggest using WSL2.
 If you use Qadence for a publication, we kindly ask you to cite our work using the following BibTex entry:
 
 ```latex
-@article{
-
-
-
-
-
+@article{qadence2025,
+author = {Seitz, Dominik and Heim, Niklas and Moutinho, João and Guichard, Roland and Abramavicius, Vytautas and Wennersteen, Aleksander and Both, Gert-Jan and Quelle, Anton and Groot, Caroline and Velikova, Gergana and Elfving, Vincent and Dagrada, Mario},
+year = {2025},
+month = {01},
+pages = {1-14},
+title = {Qadence: a differentiable interface for digital and analog programs},
+volume = {PP},
+journal = {IEEE Software},
+doi = {10.1109/MS.2025.3536607}
 }
 ```
 
{qadence-1.10.3 → qadence-1.11.0}/README.md

@@ -138,12 +138,15 @@ Users also report problems running Hatch on Windows, we suggest using WSL2.
 If you use Qadence for a publication, we kindly ask you to cite our work using the following BibTex entry:
 
 ```latex
-@article{
-
-
-
-
-
+@article{qadence2025,
+author = {Seitz, Dominik and Heim, Niklas and Moutinho, João and Guichard, Roland and Abramavicius, Vytautas and Wennersteen, Aleksander and Both, Gert-Jan and Quelle, Anton and Groot, Caroline and Velikova, Gergana and Elfving, Vincent and Dagrada, Mario},
+year = {2025},
+month = {01},
+pages = {1-14},
+title = {Qadence: a differentiable interface for digital and analog programs},
+volume = {PP},
+journal = {IEEE Software},
+doi = {10.1109/MS.2025.3536607}
 }
 ```
 
{qadence-1.10.3 → qadence-1.11.0}/mkdocs.yml

@@ -43,14 +43,20 @@ nav:
 
   - Variational quantum algorithms:
     - tutorials/qml/index.md
-    - Training: tutorials/qml/ml_tools/trainer.md
-    - Training Callbacks: tutorials/qml/ml_tools/callbacks.md
-    - Data and Configurations: tutorials/qml/ml_tools/data_and_config.md
     - Configuring a QNN: tutorials/qml/config_qnn.md
     - Quantum circuit learning: tutorials/qml/qcl.md
     - Solving MaxCut with QAOA: tutorials/qml/qaoa.md
     - Solving a 1D ODE: tutorials/qml/dqc_1d.md
 
+  - ML Tools:
+    - tutorials/qml/ml_tools/intro.md
+    - Training: tutorials/qml/ml_tools/trainer.md
+    - Data and Configurations: tutorials/qml/ml_tools/data_and_config.md
+    - Training Callbacks: tutorials/qml/ml_tools/callbacks.md
+    - Accelerator: tutorials/qml/ml_tools/accelerator_doc.md
+    - CPU Training: tutorials/qml/ml_tools/CPU.md
+    - GPU Training: tutorials/qml/ml_tools/GPU.md
+
   - Advanced Tutorials:
     - tutorials/advanced_tutorials/index.md
     - Quantum circuits differentiation: tutorials/advanced_tutorials/differentiability.md
{qadence-1.10.3 → qadence-1.11.0}/pyproject.toml

@@ -28,7 +28,7 @@ authors = [
 ]
 requires-python = ">=3.9"
 license = { text = "Apache 2.0" }
-version = "1.10.3"
+version = "1.11.0"
 classifiers = [
     "License :: OSI Approved :: Apache Software License",
     "Programming Language :: Python",
@@ -52,7 +52,7 @@ dependencies = [
     "jsonschema",
     "nevergrad",
     "scipy",
-    "pyqtorch==1.7.0",
+    "pyqtorch==1.7.1",
    "pyyaml",
    "matplotlib",
    "Arpeggio==2.0.2",
@@ -65,9 +65,9 @@ allow-ambiguous-features = true
 
 [project.optional-dependencies]
 pulser = [
-    "pulser-core==1.
-    "pulser-simulation==1.
-    "pasqal-cloud==0.
+    "pulser-core==1.3.0",
+    "pulser-simulation==1.3.0",
+    "pasqal-cloud==0.13.0",
 ]
 visualization = [
     "graphviz",
{qadence-1.10.3 → qadence-1.11.0}/qadence/blocks/block_to_tensor.py

@@ -82,18 +82,20 @@ def _fill_identities(
     full_qubit_support = tuple(sorted(full_qubit_support))
     qubit_support = tuple(sorted(qubit_support))
     block_mat = block_mat.to(device)
-
+    identity_mat = IMAT.to(device)
     if diag_only:
-
+        block_mat = torch.diag(block_mat.squeeze(0))
+        identity_mat = torch.diag(identity_mat.squeeze(0))
+    mat = identity_mat if qubit_support[0] != full_qubit_support[0] else block_mat
     for i in full_qubit_support[1:]:
         if i == qubit_support[0]:
-            other =
+            other = block_mat
             if endianness == Endianness.LITTLE:
                 mat = torch.kron(other, mat)
             else:
                 mat = torch.kron(mat.contiguous(), other.contiguous())
         elif i not in qubit_support:
-            other =
+            other = identity_mat
             if endianness == Endianness.LITTLE:
                 mat = torch.kron(other.contiguous(), mat.contiguous())
             else:
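For context, a minimal sketch (not from the package) of what the kron chain above computes: identities pad the qubits outside an operator's support.

import torch

# Illustrative only: Z acting on qubit 0 of the support (0, 1) with BIG
# endianness becomes Z (x) I, exactly the torch.kron step in the loop above.
ZMAT = torch.tensor([[1, 0], [0, -1]], dtype=torch.cdouble)
IMAT = torch.eye(2, dtype=torch.cdouble)
mat = torch.kron(ZMAT.contiguous(), IMAT.contiguous())
print(torch.diag(mat))  # tensor([ 1.+0.j,  1.+0.j, -1.+0.j, -1.+0.j])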
@@ -264,13 +266,12 @@ def _gate_parameters(b: AbstractBlock, values: dict[str, torch.Tensor]) -> tuple
 
 def block_to_diagonal(
     block: AbstractBlock,
+    values: dict[str, TNumber | torch.Tensor] = dict(),
     qubit_support: tuple | list | None = None,
-    use_full_support: bool =
+    use_full_support: bool = False,
     endianness: Endianness = Endianness.BIG,
     device: torch.device = None,
 ) -> torch.Tensor:
-    if block.is_parametric:
-        raise TypeError("Sparse observables cant be parametric.")
     if not block._is_diag_pauli:
         raise TypeError("Sparse observables can only be used on paulis which are diagonal.")
     if qubit_support is None:
@@ -282,17 +283,16 @@ def block_to_diagonal(
     if isinstance(block, (ChainBlock, KronBlock)):
         v = torch.ones(2**nqubits, dtype=torch.cdouble)
         for b in block.blocks:
-            v *= block_to_diagonal(b, qubit_support)
+            v *= block_to_diagonal(b, values, qubit_support, device=device)
     if isinstance(block, AddBlock):
         t = torch.zeros(2**nqubits, dtype=torch.cdouble)
         for b in block.blocks:
-            t += block_to_diagonal(b, qubit_support)
+            t += block_to_diagonal(b, values, qubit_support, device=device)
         v = t
     elif isinstance(block, ScaleBlock):
-        _s = evaluate(block.scale,
-        _s = _s.detach()  # type: ignore[union-attr]
-        v = _s * block_to_diagonal(block.block, qubit_support)
-
+        _s = evaluate(block.scale, values, as_torch=True)  # type: ignore[attr-defined]
+        _s = _s.detach().squeeze(0)  # type: ignore[union-attr]
+        v = _s * block_to_diagonal(block.block, values, qubit_support, device=device)
     elif isinstance(block, PrimitiveBlock):
         v = _fill_identities(
             OPERATIONS_DICT[block.name],
@@ -300,6 +300,7 @@ def block_to_diagonal(
             qubit_support,  # type: ignore [arg-type]
             diag_only=True,
             endianness=endianness,
+            device=device,
         )
     return v
 
@@ -309,7 +310,7 @@ def block_to_tensor(
     block: AbstractBlock,
     values: dict[str, TNumber | torch.Tensor] = {},
     qubit_support: tuple | None = None,
-    use_full_support: bool =
+    use_full_support: bool = False,
     tensor_type: TensorType = TensorType.DENSE,
     endianness: Endianness = Endianness.BIG,
     device: torch.device = None,
@@ -339,18 +340,14 @@ def block_to_tensor(
     print(block_to_tensor(obs, tensor_type="SparseDiagonal"))
     ```
     """
+    from qadence.blocks import embedding
 
-
-
-    # as observables only do the matmul of the size of the qubit support.
-
+    (ps, embed) = embedding(block)
+    values = embed(ps, values)
     if tensor_type == TensorType.DENSE:
-        from qadence.blocks import embedding
-
-        (ps, embed) = embedding(block)
         return _block_to_tensor_embedded(
             block,
-
+            values,
             qubit_support,
             use_full_support,
             endianness=endianness,
@@ -358,7 +355,7 @@ def block_to_tensor(
         )
 
     elif tensor_type == TensorType.SPARSEDIAGONAL:
-        t = block_to_diagonal(block, endianness=endianness)
+        t = block_to_diagonal(block, values, endianness=endianness)
         indices, values, size = torch.nonzero(t), t[t != 0], len(t)
         indices = torch.stack((indices.flatten(), indices.flatten()))
         return torch.sparse_coo_tensor(indices, values, (size, size))
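Since the `is_parametric` guard was removed from block_to_diagonal and `values` is now embedded up front in block_to_tensor, a parametrized diagonal observable can be converted to a sparse tensor. A hedged usage sketch (import paths and call style are assumptions, not taken from the diff):

import torch
from qadence import block_to_tensor
from qadence.operations import Z
from qadence.parameters import Parameter
from qadence.types import TensorType

# A diagonal Pauli observable with a parametric scale "theta". Before this
# release, the sparse-diagonal path raised "Sparse observables cant be
# parametric."; now the values dict is threaded through block_to_diagonal.
obs = Parameter("theta", trainable=False) * (Z(0) + Z(1))
t = block_to_tensor(
    obs,
    values={"theta": torch.tensor([0.5])},
    tensor_type=TensorType.SPARSEDIAGONAL,
)
print(t.to_dense().diagonal())  # 0.5 * diag of Z(0) + Z(1)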
@@ -369,7 +366,7 @@ def _block_to_tensor_embedded(
     block: AbstractBlock,
     values: dict[str, TNumber | torch.Tensor] = {},
     qubit_support: tuple | None = None,
-    use_full_support: bool =
+    use_full_support: bool = False,
     endianness: Endianness = Endianness.BIG,
     device: torch.device = None,
 ) -> torch.Tensor:
{qadence-1.10.3 → qadence-1.11.0}/qadence/constructors/__init__.py

@@ -17,6 +17,9 @@ from .hamiltonians import (
     ObservableConfig,
     total_magnetization,
     zz_hamiltonian,
+    total_magnetization_config,
+    zz_hamiltonian_config,
+    ising_hamiltonian_config,
 )
 
 from .rydberg_hea import rydberg_hea, rydberg_hea_layer
@@ -34,9 +37,12 @@ __all__ = [
     "iia",
     "hamiltonian_factory",
     "ising_hamiltonian",
-    "ObservableConfig",
     "total_magnetization",
     "zz_hamiltonian",
+    "ObservableConfig",
+    "total_magnetization_config",
+    "zz_hamiltonian_config",
+    "ising_hamiltonian_config",
     "qft",
     "daqc_transform",
     "rydberg_hea",
{qadence-1.10.3 → qadence-1.11.0}/qadence/constructors/hamiltonians.py

@@ -7,11 +7,12 @@ from typing import Callable, List, Type, Union
 import numpy as np
 from torch import Tensor, double, ones, rand
 from typing_extensions import Any
+from qadence.parameters import Parameter
 
 from qadence.blocks import AbstractBlock, add, block_is_qubit_hamiltonian
-from qadence.operations import N, X, Y, Z
+from qadence.operations import N, X, Y, Z, H
 from qadence.register import Register
-from qadence.types import Interaction,
+from qadence.types import Interaction, TArray, TParameter
 
 logger = getLogger(__name__)
@@ -239,7 +240,30 @@ def is_numeric(x: Any) -> bool:
 
 @dataclass
 class ObservableConfig:
-
+    """ObservableConfig is a configuration class for defining the parameters of an observable Hamiltonian."""
+
+    interaction: Interaction | Callable | None = None
+    """
+    The type of interaction.
+
+    Available options from the Interaction enum are:
+        - Interaction.ZZ
+        - Interaction.NN
+        - Interaction.XY
+        - Interaction.XYZ
+
+    Alternatively, a custom interaction function can be defined.
+    Example:
+
+        def custom_int(i: int, j: int):
+            return X(i) @ X(j) + Y(i) @ Y(j)
+
+        n_qubits = 2
+
+        observable_config = ObservableConfig(interaction=custom_int, scale = 1.0, shift = 0.0)
+        observable = create_observable(register=4, config=observable_config)
+    """
+    detuning: TDetuning | None = None
     """
     Single qubit detuning of the observable Hamiltonian.
 
@@ -249,8 +273,6 @@ class ObservableConfig:
     """The scale by which to multiply the output of the observable."""
     shift: TParameter = 0.0
     """The shift to add to the output of the observable."""
-    transformation_type: ObservableTransform = ObservableTransform.NONE  # type: ignore[assignment]
-    """The type of transformation."""
     trainable_transform: bool | None = None
     """
     Whether to have a trainable transformation on the output of the observable.
@@ -261,8 +283,73 @@ class ObservableConfig:
     """
 
     def __post_init__(self) -> None:
+        if self.interaction is None and self.detuning is None:
+            raise ValueError(
+                "Please provide an interaction and/or detuning for the Observable Hamiltonian."
+            )
+
         if is_numeric(self.scale) and is_numeric(self.shift):
-            assert (
-
-
-
+            assert self.trainable_transform is None, (
+                "If scale and shift are numbers, trainable_transform must be None."
+                f"But got: {self.trainable_transform}"
+            )
+
+        # trasform the scale and shift into parameters
+        if self.trainable_transform is not None:
+            self.shift = Parameter(name=self.shift, trainable=self.trainable_transform)
+            self.scale = Parameter(name=self.scale, trainable=self.trainable_transform)
+        else:
+            self.shift = Parameter(self.shift)
+            self.scale = Parameter(self.scale)
+
+
+def total_magnetization_config(
+    scale: TParameter = 1.0,
+    shift: TParameter = 0.0,
+    trainable_transform: bool | None = None,
+) -> ObservableConfig:
+    return ObservableConfig(
+        detuning=Z,
+        scale=scale,
+        shift=shift,
+        trainable_transform=trainable_transform,
+    )
+
+
+def zz_hamiltonian_config(
+    scale: TParameter = 1.0,
+    shift: TParameter = 0.0,
+    trainable_transform: bool | None = None,
+) -> ObservableConfig:
+    return ObservableConfig(
+        interaction=Interaction.ZZ,
+        detuning=Z,
+        scale=scale,
+        shift=shift,
+        trainable_transform=trainable_transform,
+    )
+
+
+def ising_hamiltonian_config(
+    scale: TParameter = 1.0,
+    shift: TParameter = 0.0,
+    trainable_transform: bool | None = None,
+) -> ObservableConfig:
+
+    def ZZ_Z_hamiltonian(i: int, j: int) -> AbstractBlock:
+        result = Z(i) @ Z(j)
+
+        if i == 0:
+            result += Z(j)
+        elif i == 1 and j == 2:
+            result += Z(0)
+
+        return result
+
+    return ObservableConfig(
+        interaction=ZZ_Z_hamiltonian,
+        detuning=Z,
+        scale=scale,
+        shift=shift,
+        trainable_transform=trainable_transform,
+    )
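A hedged sketch of the new helpers in use; create_observable is re-exported from qadence.ml_tools per the __init__ diff further below, and the register/config argument names follow the ObservableConfig docstring above (exact signatures may differ):

from qadence.constructors import ObservableConfig, total_magnetization_config
from qadence.ml_tools import create_observable
from qadence.operations import X, Y

# Custom interaction, as in the ObservableConfig docstring above.
def custom_int(i: int, j: int):
    return X(i) @ X(j) + Y(i) @ Y(j)

custom_obs = create_observable(register=4, config=ObservableConfig(interaction=custom_int))

# Ready-made config: total magnetization with a fixed scale and shift.
magnetization_obs = create_observable(
    register=4, config=total_magnetization_config(scale=2.0, shift=0.0)
)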
{qadence-1.10.3 → qadence-1.11.0}/qadence/mitigations/analog_zne.py

@@ -92,7 +92,9 @@ def pulse_experiment(
     )
     # Convert observable to Numpy types compatible with QuTip simulations.
     # Matrices are flipped to match QuTip conventions.
-    converted_observable = [
+    converted_observable = [
+        np.flip(block_to_tensor(obs, use_full_support=True).numpy()) for obs in observable
+    ]
     # Create ZNE datasets by looping over batches.
     for observable in converted_observable:
         # Get expectation values at the end of the time serie [0,t]
@@ -130,7 +132,9 @@ def noise_level_experiment(
     )
     # Convert observable to Numpy types compatible with QuTip simulations.
     # Matrices are flipped to match QuTip conventions.
-    converted_observable = [
+    converted_observable = [
+        np.flip(block_to_tensor(obs, use_full_support=True).numpy()) for obs in observable
+    ]
     # Create ZNE datasets by looping over batches.
     for observable in converted_observable:
         # Get expectation values at the end of the time serie [0,t]
{qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/__init__.py

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from .callbacks.saveload import load_checkpoint, load_model, write_checkpoint
 from .config import AnsatzConfig, FeatureMapConfig, TrainConfig
-from .constructors import create_ansatz, create_fm_blocks,
+from .constructors import create_ansatz, create_fm_blocks, create_observable
 from .data import DictDataLoader, InfiniteTensorDataset, OptimizeResult, to_dataloader
 from .information import InformationContent
 from .models import QNN
@@ -19,7 +19,7 @@ __all__ = [
     "DictDataLoader",
     "FeatureMapConfig",
     "load_checkpoint",
-    "
+    "create_observable",
     "QNN",
     "TrainConfig",
     "OptimizeResult",
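The new export in use, a minimal sketch assuming the qadence.ml_tools import path shown in this hunk:

from qadence.constructors import zz_hamiltonian_config
from qadence.ml_tools import create_observable

# Builds a ZZ-interaction observable with Z detuning via the new config helper.
obs = create_observable(register=4, config=zz_hamiltonian_config())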
{qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/callbacks/callback.py

@@ -95,14 +95,36 @@ class Callback:
         self.callback: CallbackFunction | None = callback
         self.on: str | TrainingStage = on
         self.called_every: int = called_every
-        self.callback_condition =
+        self.callback_condition = (
+            callback_condition if callback_condition else Callback.default_callback
+        )
 
         if isinstance(modify_optimize_result, dict):
-            self.modify_optimize_result = (
-
+            self.modify_optimize_result = lambda opt_res: Callback.modify_opt_res_dict(
+                opt_res, modify_optimize_result
             )
         else:
-            self.modify_optimize_result =
+            self.modify_optimize_result = (
+                modify_optimize_result
+                if modify_optimize_result
+                else Callback.modify_opt_res_default
+            )
+
+    @staticmethod
+    def default_callback(_: Any) -> bool:
+        return True
+
+    @staticmethod
+    def modify_opt_res_dict(
+        opt_res: OptimizeResult,
+        modify_optimize_result: dict[str, Any] = {},
+    ) -> OptimizeResult:
+        opt_res.extra.update(modify_optimize_result)
+        return opt_res
+
+    @staticmethod
+    def modify_opt_res_default(opt_res: OptimizeResult) -> OptimizeResult:
+        return opt_res
 
     @property
     def on(self) -> TrainingStage | str:
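The refactor replaces inline defaults with named static methods; in particular, a dict passed as modify_optimize_result is merged into OptimizeResult.extra. A hedged usage sketch (the import path and the callback signature are assumptions, not shown in this diff):

from qadence.ml_tools.callbacks import Callback

# Tags every 10th epoch's OptimizeResult via Callback.modify_opt_res_dict,
# which updates opt_res.extra with the given dict.
tag_cb = Callback(
    on="train_epoch_end",
    called_every=10,
    callback=lambda trainer, config, writer: None,  # hypothetical no-op hook
    modify_optimize_result={"experiment": "baseline"},
)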
@@ -261,8 +283,9 @@ class WriteMetrics(Callback):
             config (TrainConfig): The configuration object.
             writer (BaseWriter ): The writer object for logging.
         """
-
-
+        if trainer.accelerator.rank == 0:
+            opt_result = trainer.opt_result
+            writer.write(opt_result.iteration, opt_result.metrics)
 
 
 class PlotMetrics(Callback):
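The same rank-0 guard is applied to every writer- and checkpoint-facing callback below (PlotMetrics, LogHyperparameters, SaveCheckpoint, SaveBestCheckpoint, LoadCheckpoint, LogModelTracker, GradientMonitoring) and to the CallbacksManager: with the new Accelerator layer (accelerator.py, distribution.py and execution.py in the file list), training can run in several processes, and only the main process should perform I/O. Schematically (a sketch of the pattern, not package code):

# In a multi-process run, only rank 0 writes metrics, plots, hyperparameters
# and checkpoints; the other ranks skip I/O entirely.
if trainer.accelerator.rank == 0:
    writer.write(trainer.opt_result.iteration, trainer.opt_result.metrics)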
@@ -299,9 +322,10 @@ class PlotMetrics(Callback):
             config (TrainConfig): The configuration object.
             writer (BaseWriter ): The writer object for logging.
         """
-
-
-
+        if trainer.accelerator.rank == 0:
+            opt_result = trainer.opt_result
+            plotting_functions = config.plotting_functions
+            writer.plot(trainer.model, opt_result.iteration, plotting_functions)
 
 
 class LogHyperparameters(Callback):
@@ -338,8 +362,9 @@ class LogHyperparameters(Callback):
             config (TrainConfig): The configuration object.
             writer (BaseWriter ): The writer object for logging.
         """
-
-
+        if trainer.accelerator.rank == 0:
+            hyperparams = config.hyperparams
+            writer.log_hyperparams(hyperparams)
 
 
 class SaveCheckpoint(Callback):

@@ -376,11 +401,12 @@
             config (TrainConfig): The configuration object.
             writer (BaseWriter ): The writer object for logging.
         """
-
-
-
-
-
+        if trainer.accelerator.rank == 0:
+            folder = config.log_folder
+            model = trainer.model
+            optimizer = trainer.optimizer
+            opt_result = trainer.opt_result
+            write_checkpoint(folder, model, optimizer, opt_result.iteration)
 
 
 class SaveBestCheckpoint(SaveCheckpoint):

@@ -404,17 +430,18 @@ class SaveBestCheckpoint(SaveCheckpoint):
             config (TrainConfig): The configuration object.
             writer (BaseWriter ): The writer object for logging.
         """
-
-        if config.validation_criterion and config.validation_criterion(
-            opt_result.loss, self.best_loss, config.val_epsilon
-        ):
-            self.best_loss = opt_result.loss
-
-            folder = config.log_folder
-            model = trainer.model
-            optimizer = trainer.optimizer
+        if trainer.accelerator.rank == 0:
             opt_result = trainer.opt_result
-
+            if config.validation_criterion and config.validation_criterion(
+                opt_result.loss, self.best_loss, config.val_epsilon
+            ):
+                self.best_loss = opt_result.loss
+
+                folder = config.log_folder
+                model = trainer.model
+                optimizer = trainer.optimizer
+                opt_result = trainer.opt_result
+                write_checkpoint(folder, model, optimizer, "best")
 
 
 class LoadCheckpoint(Callback):

@@ -431,11 +458,12 @@ class LoadCheckpoint(Callback):
         Returns:
             Any: The result of loading the checkpoint.
         """
-
-
-
-
-
+        if trainer.accelerator.rank == 0:
+            folder = config.log_folder
+            model = trainer.model
+            optimizer = trainer.optimizer
+            device = trainer.accelerator.execution.log_device
+            return load_checkpoint(folder, model, optimizer, device=device)
 
 
 class LogModelTracker(Callback):

@@ -449,10 +477,11 @@ class LogModelTracker(Callback):
             config (TrainConfig): The configuration object.
             writer (BaseWriter ): The writer object for logging.
         """
-
-
-
-
+        if trainer.accelerator.rank == 0:
+            model = trainer.model
+            writer.log_model(
+                model, trainer.train_dataloader, trainer.val_dataloader, trainer.test_dataloader
+            )
 
 
 class LRSchedulerStepDecay(Callback):

@@ -713,7 +742,7 @@ class EarlyStopping(Callback):
                 f"EarlyStopping: No improvement in '{self.monitor}' for {self.patience} epochs. "
                 "Stopping training."
             )
-            trainer.
+            trainer._stop_training.fill_(1)
 
 
 class GradientMonitoring(Callback):

@@ -759,17 +788,18 @@ class GradientMonitoring(Callback):
             config (TrainConfig): The configuration object.
             writer (BaseWriter): The writer object for logging.
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if trainer.accelerator.rank == 0:
+            gradient_stats = {}
+            for name, param in trainer.model.named_parameters():
+                if param.grad is not None:
+                    grad = param.grad
+                    gradient_stats.update(
+                        {
+                            name + "_mean": grad.mean().item(),
+                            name + "_std": grad.std().item(),
+                            name + "_max": grad.max().item(),
+                            name + "_min": grad.min().item(),
+                        }
+                    )
+
+            writer.write(trainer.opt_result.iteration, gradient_stats)

{qadence-1.10.3 → qadence-1.11.0}/qadence/ml_tools/callbacks/callbackmanager.py

@@ -201,7 +201,8 @@ class CallbacksManager:
         logger.debug(f"Loaded model and optimizer from {self.config.log_folder}")
 
         # Setup writer
-
+        if trainer.accelerator.rank == 0:
+            self.writer.open(self.config, iteration=trainer.global_step)
 
     def end_training(self, trainer: Any) -> None:
         """

@@ -210,5 +211,5 @@ class CallbacksManager:
         Args:
             trainer (Any): The training object managing the training process.
         """
-        if self.writer:
+        if trainer.accelerator.rank == 0 and self.writer:
             self.writer.close()