qadence 1.7.1__py3-none-any.whl → 1.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,6 +8,7 @@ import pyqtorch as pyq
8
8
  import sympy
9
9
  import torch
10
10
  from pyqtorch.apply import apply_operator
11
+ from pyqtorch.embed import Embedding
11
12
  from pyqtorch.matrices import _dagger
12
13
  from pyqtorch.time_dependent.sesolve import sesolve
13
14
  from pyqtorch.utils import is_diag
@@ -24,7 +25,7 @@ from torch import (
24
25
  )
25
26
  from torch import device as torch_device
26
27
  from torch import dtype as torch_dtype
27
- from torch.nn import Module
28
+ from torch.nn import Module, ParameterDict
28
29
 
29
30
  from qadence.backends.utils import (
30
31
  finitediff,
@@ -200,7 +201,12 @@ class PyQHamiltonianEvolution(Module):
200
201
  elif isinstance(block.generator, Tensor):
201
202
  m = block.generator.to(dtype=cdouble)
202
203
  hmat = block_to_tensor(
203
- MatrixBlock(m, qubit_support=block.qubit_support),
204
+ MatrixBlock(
205
+ m,
206
+ qubit_support=block.qubit_support,
207
+ check_unitary=False,
208
+ check_hermitian=True,
209
+ ),
204
210
  qubit_support=self.qubit_support,
205
211
  use_full_support=False,
206
212
  )
@@ -313,7 +319,8 @@ class PyQHamiltonianEvolution(Module):
313
319
  def forward(
314
320
  self,
315
321
  state: Tensor,
316
- values: dict[str, Tensor],
322
+ values: dict[str, Tensor] | ParameterDict = dict(),
323
+ embedding: Embedding | None = None,
317
324
  ) -> Tensor:
318
325
  if getattr(self.block.generator, "is_time_dependent", False): # type: ignore [union-attr]
319
326
 
qadence/blocks/matrix.py CHANGED
@@ -60,7 +60,13 @@ class MatrixBlock(PrimitiveBlock):
60
60
  name = "MatrixBlock"
61
61
  matrix: torch.Tensor
62
62
 
63
- def __init__(self, matrix: torch.Tensor | np.ndarray, qubit_support: tuple[int, ...]) -> None:
63
+ def __init__(
64
+ self,
65
+ matrix: torch.Tensor | np.ndarray,
66
+ qubit_support: tuple[int, ...],
67
+ check_unitary: bool = True,
68
+ check_hermitian: bool = False,
69
+ ) -> None:
64
70
  if isinstance(matrix, np.ndarray):
65
71
  matrix = torch.tensor(matrix)
66
72
  if matrix.ndim == 3 and matrix.size(0) == 1:
@@ -69,10 +75,12 @@ class MatrixBlock(PrimitiveBlock):
69
75
  raise TypeError("Please provide a 2D matrix.")
70
76
  if not self.is_square(matrix):
71
77
  raise TypeError("Please provide a square matrix.")
72
- if not self.is_hermitian(matrix):
73
- logger.warning("Provided matrix is not hermitian.")
74
- if not self.is_unitary(matrix):
75
- logger.warning("Provided matrix is not unitary.")
78
+ if check_hermitian:
79
+ if not self.is_hermitian(matrix):
80
+ logger.warning("Provided matrix is not hermitian.")
81
+ if check_unitary:
82
+ if not self.is_unitary(matrix):
83
+ logger.warning("Provided matrix is not unitary.")
76
84
  self.matrix = matrix.clone()
77
85
  super().__init__(qubit_support)
78
86
 
@@ -151,8 +151,9 @@ class DifferentiableExpectation:
151
151
  return (
152
152
  AdjointExpectation.apply(
153
153
  self.circuit.native,
154
- self.observable[0].native, # Currently, adjoint only supports a single observable.
155
154
  self.state,
155
+ self.observable[0].native, # Currently, adjoint only supports a single observable.
156
+ None,
156
157
  self.param_values.keys(),
157
158
  *self.param_values.values(),
158
159
  )
qadence/extensions.py CHANGED
@@ -31,9 +31,14 @@ def import_backend(backend_name: str | BackendName) -> Backend:
31
31
  backend: Backend
32
32
  try:
33
33
  module = importlib.import_module(module_path)
34
- backend = getattr(module, "Backend")
35
34
  except (ModuleNotFoundError, ImportError) as e:
36
- raise type(e)
35
+ # If backend is not in Qadence, search in extensions.
36
+ module_path = f"qadence_extensions.backends.{backend_name}.backend"
37
+ try:
38
+ module = importlib.import_module(module_path)
39
+ except (ModuleNotFoundError, ImportError) as e:
40
+ raise type(e)
41
+ backend = getattr(module, "Backend")
37
42
  return backend
38
43
 
39
44
 
@@ -5,15 +5,25 @@ import os
5
5
  from dataclasses import dataclass, field, fields
6
6
  from logging import getLogger
7
7
  from pathlib import Path
8
- from typing import Callable, Optional, Type
8
+ from typing import Callable, Type
9
+ from uuid import uuid4
9
10
 
10
11
  from sympy import Basic
12
+ from torch import Tensor
11
13
 
12
14
  from qadence.blocks.analog import AnalogBlock
13
15
  from qadence.blocks.primitive import ParametricBlock
14
16
  from qadence.operations import RX, AnalogRX
15
17
  from qadence.parameters import Parameter
16
- from qadence.types import AnsatzType, BasisSet, MultivariateStrategy, ReuploadScaling, Strategy
18
+ from qadence.types import (
19
+ AnsatzType,
20
+ BasisSet,
21
+ ExperimentTrackingTool,
22
+ LoggablePlotFunction,
23
+ MultivariateStrategy,
24
+ ReuploadScaling,
25
+ Strategy,
26
+ )
17
27
 
18
28
  logger = getLogger(__file__)
19
29
 
@@ -37,10 +47,14 @@ class TrainConfig:
37
47
  print_every: int = 1000
38
48
  """Print loss/metrics."""
39
49
  write_every: int = 50
40
- """Write tensorboard logs."""
50
+ """Write loss and metrics with the tracking tool."""
41
51
  checkpoint_every: int = 5000
42
52
  """Write model/optimizer checkpoint."""
43
- folder: Optional[Path] = None
53
+ plot_every: int = 5000
54
+ """Write figures."""
55
+ log_model: bool = False
56
+ """Logs a serialised version of the model."""
57
+ folder: Path | None = None
44
58
  """Checkpoint/tensorboard logs folder."""
45
59
  create_subfolder_per_run: bool = False
46
60
  """Checkpoint/tensorboard logs stored in subfolder with name `<timestamp>_<PID>`.
@@ -59,14 +73,38 @@ class TrainConfig:
59
73
 
60
74
  validation loss across previous iterations.
61
75
  """
62
- validation_criterion: Optional[Callable] = None
76
+ validation_criterion: Callable | None = None
63
77
  """A boolean function which evaluates a given validation metric is satisfied."""
64
- trainstop_criterion: Optional[Callable] = None
78
+ trainstop_criterion: Callable | None = None
65
79
  """A boolean function which evaluates a given training stopping metric is satisfied."""
66
80
  batch_size: int = 1
67
81
  """The batch_size to use when passing a list/tuple of torch.Tensors."""
68
82
  verbose: bool = True
69
83
  """Whether or not to print out metrics values during training."""
84
+ tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD
85
+ """The tracking tool of choice."""
86
+ hyperparams: dict = field(default_factory=dict)
87
+ """Hyperparameters to track."""
88
+ plotting_functions: tuple[LoggablePlotFunction, ...] = field(default_factory=tuple) # type: ignore
89
+ """Functions for in-train plotting."""
90
+
91
+ # tensorboard only allows for certain types as hyperparameters
92
+ _tb_allowed_hyperparams_types: tuple = field(
93
+ default=(int, float, str, bool, Tensor), init=False, repr=False
94
+ )
95
+
96
+ def _filter_tb_hyperparams(self) -> None:
97
+ keys_to_remove = [
98
+ key
99
+ for key, value in self.hyperparams.items()
100
+ if not isinstance(value, TrainConfig._tb_allowed_hyperparams_types)
101
+ ]
102
+ if keys_to_remove:
103
+ logger.warning(
104
+ f"Tensorboard cannot log the following hyperparameters: {keys_to_remove}."
105
+ )
106
+ for key in keys_to_remove:
107
+ self.hyperparams.pop(key)
70
108
 
71
109
  def __post_init__(self) -> None:
72
110
  if self.folder:
@@ -81,6 +119,64 @@ class TrainConfig:
81
119
  self.trainstop_criterion = lambda x: x <= self.max_iter
82
120
  if self.validation_criterion is None:
83
121
  self.validation_criterion = lambda *x: False
122
+ if self.hyperparams and self.tracking_tool == ExperimentTrackingTool.TENSORBOARD:
123
+ self._filter_tb_hyperparams()
124
+ if self.tracking_tool == ExperimentTrackingTool.MLFLOW:
125
+ self._mlflow_config = MLFlowConfig()
126
+ if self.plotting_functions and self.tracking_tool != ExperimentTrackingTool.MLFLOW:
127
+ logger.warning("In-training plots are only available with mlflow tracking.")
128
+ if not self.plotting_functions and self.tracking_tool == ExperimentTrackingTool.MLFLOW:
129
+ logger.warning("Tracking with mlflow, but no plotting functions provided.")
130
+
131
+ @property
132
+ def mlflow_config(self) -> MLFlowConfig:
133
+ if self.tracking_tool == ExperimentTrackingTool.MLFLOW:
134
+ return self._mlflow_config
135
+ else:
136
+ raise AttributeError(
137
+ "mlflow_config is available only for with the mlflow tracking tool."
138
+ )
139
+
140
+
141
+ class MLFlowConfig:
142
+ """
143
+ Configuration for mlflow tracking.
144
+
145
+ Example:
146
+
147
+ export MLFLOW_TRACKING_URI=tracking_uri
148
+ export MLFLOW_EXPERIMENT=experiment_name
149
+ export MLFLOW_RUN_NAME=run_name
150
+ """
151
+
152
+ def __init__(self) -> None:
153
+ import mlflow
154
+
155
+ self.tracking_uri: str = os.getenv("MLFLOW_TRACKING_URI", "")
156
+ """The URI of the mlflow tracking server.
157
+
158
+ An empty string, or a local file path, prefixed with file:/.
159
+ Data is stored locally at the provided file (or ./mlruns if empty).
160
+ """
161
+
162
+ self.experiment_name: str = os.getenv("MLFLOW_EXPERIMENT", str(uuid4()))
163
+ """The name of the experiment.
164
+
165
+ If None or empty, a new experiment is created with a random UUID.
166
+ """
167
+
168
+ self.run_name: str = os.getenv("MLFLOW_RUN_NAME", str(uuid4()))
169
+ """The name of the run."""
170
+
171
+ mlflow.set_tracking_uri(self.tracking_uri)
172
+
173
+ # activate existing or create experiment
174
+ exp_filter_string = f"name = '{self.experiment_name}'"
175
+ if not mlflow.search_experiments(filter_string=exp_filter_string):
176
+ mlflow.create_experiment(name=self.experiment_name)
177
+
178
+ self.experiment = mlflow.set_experiment(self.experiment_name)
179
+ self.run = mlflow.start_run(run_name=self.run_name, nested=False)
84
180
 
85
181
 
86
182
  @dataclass
@@ -1,7 +1,23 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from logging import getLogger
4
+ from typing import Any, Callable, Union
5
+
6
+ from matplotlib.figure import Figure
7
+ from mlflow.models import infer_signature
8
+ from torch import Tensor
9
+ from torch.nn import Module
10
+ from torch.utils.data import DataLoader
3
11
  from torch.utils.tensorboard import SummaryWriter
4
12
 
13
+ from qadence.ml_tools.data import DictDataLoader
14
+ from qadence.types import ExperimentTrackingTool
15
+
16
+ logger = getLogger(__name__)
17
+
18
+ PlottingFunction = Callable[[Module, int], tuple[str, Figure]]
19
+ InputData = Union[Tensor, dict[str, Tensor]]
20
+
5
21
 
6
22
  def print_metrics(loss: float | None, metrics: dict, iteration: int) -> None:
7
23
  msg = " ".join(
@@ -20,5 +36,110 @@ def write_tensorboard(
20
36
  writer.add_scalar(key, arg, iteration)
21
37
 
22
38
 
23
- def log_hyperparams(writer: SummaryWriter, hyperparams: dict, metrics: dict) -> None:
39
+ def log_hyperparams_tensorboard(writer: SummaryWriter, hyperparams: dict, metrics: dict) -> None:
24
40
  writer.add_hparams(hyperparams, metrics)
41
+
42
+
43
+ def plot_tensorboard(
44
+ writer: SummaryWriter,
45
+ model: Module,
46
+ iteration: int,
47
+ plotting_functions: tuple[PlottingFunction],
48
+ ) -> None:
49
+ for pf in plotting_functions:
50
+ descr, fig = pf(model, iteration)
51
+ writer.add_figure(descr, fig, global_step=iteration)
52
+
53
+
54
+ def log_model_tensorboard(
55
+ writer: SummaryWriter,
56
+ model: Module,
57
+ dataloader: Union[None, DataLoader, DictDataLoader],
58
+ ) -> None:
59
+ logger.warning("Model logging is not supported by tensorboard. No model will be logged.")
60
+
61
+
62
+ def write_mlflow(writer: Any, loss: float | None, metrics: dict, iteration: int) -> None:
63
+ writer.log_metrics({"loss": float(loss)}, step=iteration) # type: ignore
64
+ writer.log_metrics(metrics, step=iteration) # logs the single metrics
65
+
66
+
67
+ def log_hyperparams_mlflow(writer: Any, hyperparams: dict, metrics: dict) -> None:
68
+ writer.log_params(hyperparams) # type: ignore
69
+
70
+
71
+ def plot_mlflow(
72
+ writer: Any,
73
+ model: Module,
74
+ iteration: int,
75
+ plotting_functions: tuple[PlottingFunction],
76
+ ) -> None:
77
+ for pf in plotting_functions:
78
+ descr, fig = pf(model, iteration)
79
+ writer.log_figure(fig, descr)
80
+
81
+
82
+ def log_model_mlflow(
83
+ writer: Any, model: Module, dataloader: DataLoader | DictDataLoader | None
84
+ ) -> None:
85
+ if dataloader is not None:
86
+ xs: InputData
87
+ xs, *_ = next(iter(dataloader))
88
+ preds = model(xs)
89
+ if isinstance(xs, Tensor):
90
+ xs = xs.numpy()
91
+ preds = preds.detach().numpy()
92
+ elif isinstance(xs, dict):
93
+ for key, val in xs.items():
94
+ xs[key] = val.numpy()
95
+ for key, val in preds.items():
96
+ preds[key] = val.detach.numpy()
97
+ signature = infer_signature(xs, preds)
98
+ else:
99
+ signature = None
100
+ writer.pytorch.log_model(model, artifact_path="model", signature=signature)
101
+
102
+
103
+ TRACKER_MAPPING: dict[ExperimentTrackingTool, Callable[..., None]] = {
104
+ ExperimentTrackingTool.TENSORBOARD: write_tensorboard,
105
+ ExperimentTrackingTool.MLFLOW: write_mlflow,
106
+ }
107
+
108
+ LOGGER_MAPPING: dict[ExperimentTrackingTool, Callable[..., None]] = {
109
+ ExperimentTrackingTool.TENSORBOARD: log_hyperparams_tensorboard,
110
+ ExperimentTrackingTool.MLFLOW: log_hyperparams_mlflow,
111
+ }
112
+
113
+ PLOTTER_MAPPING: dict[ExperimentTrackingTool, Callable[..., None]] = {
114
+ ExperimentTrackingTool.TENSORBOARD: plot_tensorboard,
115
+ ExperimentTrackingTool.MLFLOW: plot_mlflow,
116
+ }
117
+
118
+ MODEL_LOGGER_MAPPING: dict[ExperimentTrackingTool, Callable[..., None]] = {
119
+ ExperimentTrackingTool.TENSORBOARD: log_model_tensorboard,
120
+ ExperimentTrackingTool.MLFLOW: log_model_mlflow,
121
+ }
122
+
123
+
124
+ def write_tracker(
125
+ *args: Any, tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD
126
+ ) -> None:
127
+ return TRACKER_MAPPING[tracking_tool](*args)
128
+
129
+
130
+ def log_tracker(
131
+ *args: Any, tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD
132
+ ) -> None:
133
+ return LOGGER_MAPPING[tracking_tool](*args)
134
+
135
+
136
+ def plot_tracker(
137
+ *args: Any, tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD
138
+ ) -> None:
139
+ return PLOTTER_MAPPING[tracking_tool](*args)
140
+
141
+
142
+ def log_model_tracker(
143
+ *args: Any, tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD
144
+ ) -> None:
145
+ return MODEL_LOGGER_MAPPING[tracking_tool](*args)
@@ -1,10 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import importlib
3
4
  import math
4
5
  from logging import getLogger
5
6
  from typing import Callable, Union
6
7
 
7
- from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn
8
+ from rich.progress import (
9
+ BarColumn,
10
+ Progress,
11
+ TaskProgressColumn,
12
+ TextColumn,
13
+ TimeRemainingColumn,
14
+ )
8
15
  from torch import complex128, float32, float64
9
16
  from torch import device as torch_device
10
17
  from torch import dtype as torch_dtype
@@ -16,8 +23,15 @@ from torch.utils.tensorboard import SummaryWriter
16
23
  from qadence.ml_tools.config import TrainConfig
17
24
  from qadence.ml_tools.data import DictDataLoader, data_to_device
18
25
  from qadence.ml_tools.optimize_step import optimize_step
19
- from qadence.ml_tools.printing import print_metrics, write_tensorboard
26
+ from qadence.ml_tools.printing import (
27
+ log_model_tracker,
28
+ log_tracker,
29
+ plot_tracker,
30
+ print_metrics,
31
+ write_tracker,
32
+ )
20
33
  from qadence.ml_tools.saveload import load_checkpoint, write_checkpoint
34
+ from qadence.types import ExperimentTrackingTool
21
35
 
22
36
  logger = getLogger(__name__)
23
37
 
@@ -30,7 +44,6 @@ def train(
30
44
  loss_fn: Callable,
31
45
  device: torch_device = None,
32
46
  optimize_step: Callable = optimize_step,
33
- write_tensorboard: Callable = write_tensorboard,
34
47
  dtype: torch_dtype = None,
35
48
  ) -> tuple[Module, Optimizer]:
36
49
  """Runs the training loop with gradient-based optimizer.
@@ -48,15 +61,11 @@ def train(
48
61
  the model
49
62
  optimizer: The optimizer to use.
50
63
  config: `TrainConfig` with additional training options.
51
- loss_fn: Loss function returning (loss: float, metrics: dict[str, float])
64
+ loss_fn: Loss function returning (loss: float, metrics: dict[str, float], ...)
52
65
  device: String defining device to train on, pass 'cuda' for GPU.
53
66
  optimize_step: Customizable optimization callback which is called at every iteration.=
54
67
  The function must have the signature `optimize_step(model,
55
68
  optimizer, loss_fn, xs, device="cpu")`.
56
- write_tensorboard: Customizable tensorboard logging callback which is
57
- called every `config.write_every` iterations. The function must have
58
- the signature `write_tensorboard(writer, loss, metrics, iteration)`
59
- (see the example below).
60
69
  dtype: The dtype to use for the data.
61
70
 
62
71
  Example:
@@ -122,8 +131,11 @@ def train(
122
131
  model = model.module.to(device=device, dtype=dtype)
123
132
  else:
124
133
  model = model.to(device=device, dtype=dtype)
125
- # initialize tensorboard
126
- writer = SummaryWriter(config.folder, purge_step=init_iter)
134
+ # initialize tracking tool
135
+ if config.tracking_tool == ExperimentTrackingTool.TENSORBOARD:
136
+ writer = SummaryWriter(config.folder, purge_step=init_iter)
137
+ else:
138
+ writer = importlib.import_module("mlflow")
127
139
 
128
140
  perform_val = isinstance(config.val_every, int)
129
141
  if perform_val:
@@ -166,7 +178,7 @@ def train(
166
178
  best_val_loss, metrics = loss_fn(model, xs_to_device)
167
179
 
168
180
  metrics["val_loss"] = best_val_loss
169
- write_tensorboard(writer, None, metrics, init_iter)
181
+ write_tracker(writer, None, metrics, init_iter, tracking_tool=config.tracking_tool)
170
182
 
171
183
  if config.folder:
172
184
  if config.checkpoint_best_only:
@@ -174,6 +186,14 @@ def train(
174
186
  else:
175
187
  write_checkpoint(config.folder, model, optimizer, init_iter)
176
188
 
189
+ plot_tracker(
190
+ writer,
191
+ model,
192
+ init_iter,
193
+ config.plotting_functions,
194
+ tracking_tool=config.tracking_tool,
195
+ )
196
+
177
197
  except KeyboardInterrupt:
178
198
  logger.info("Terminating training gracefully after the current iteration.")
179
199
 
@@ -218,19 +238,31 @@ def train(
218
238
  print_metrics(loss, metrics, iteration - 1)
219
239
 
220
240
  if iteration % config.write_every == 0:
221
- write_tensorboard(writer, loss, metrics, iteration - 1)
241
+ write_tracker(
242
+ writer, loss, metrics, iteration, tracking_tool=config.tracking_tool
243
+ )
222
244
 
245
+ if iteration % config.plot_every == 0:
246
+ plot_tracker(
247
+ writer,
248
+ model,
249
+ iteration,
250
+ config.plotting_functions,
251
+ tracking_tool=config.tracking_tool,
252
+ )
223
253
  if perform_val:
224
254
  if iteration % config.val_every == 0:
225
255
  xs = next(dl_iter_val)
226
256
  xs_to_device = data_to_device(xs, device=device, dtype=data_dtype)
227
- val_loss, _ = loss_fn(model, xs_to_device)
257
+ val_loss, *_ = loss_fn(model, xs_to_device)
228
258
  if config.validation_criterion(val_loss, best_val_loss, config.val_epsilon): # type: ignore[misc]
229
259
  best_val_loss = val_loss
230
260
  if config.folder and config.checkpoint_best_only:
231
261
  write_checkpoint(config.folder, model, optimizer, iteration="best")
232
262
  metrics["val_loss"] = val_loss
233
- write_tensorboard(writer, None, metrics, iteration)
263
+ write_tracker(
264
+ writer, loss, metrics, iteration, tracking_tool=config.tracking_tool
265
+ )
234
266
 
235
267
  if config.folder:
236
268
  if iteration % config.checkpoint_every == 0 and not config.checkpoint_best_only:
@@ -245,17 +277,32 @@ def train(
245
277
  try:
246
278
  xs = next(dl_iter) if dataloader is not None else None # type: ignore[arg-type]
247
279
  xs_to_device = data_to_device(xs, device=device, dtype=data_dtype)
248
- loss, metrics = loss_fn(model, xs_to_device)
280
+ loss, metrics, *_ = loss_fn(model, xs_to_device)
281
+ if dataloader is None:
282
+ loss = loss.item()
249
283
  if iteration % config.print_every == 0 and config.verbose:
250
284
  print_metrics(loss, metrics, iteration)
251
285
 
252
286
  except KeyboardInterrupt:
253
287
  logger.info("Terminating training gracefully after the current iteration.")
254
288
 
255
- # Final printing, writing and checkpointing
289
+ # Final checkpointing and writing
256
290
  if config.folder and not config.checkpoint_best_only:
257
291
  write_checkpoint(config.folder, model, optimizer, iteration)
258
- write_tensorboard(writer, loss, metrics, iteration)
259
- writer.close()
292
+ write_tracker(writer, loss, metrics, iteration, tracking_tool=config.tracking_tool)
293
+
294
+ # writing hyperparameters
295
+ if config.hyperparams:
296
+ log_tracker(writer, config.hyperparams, metrics, tracking_tool=config.tracking_tool)
297
+
298
+ # logging the model
299
+ if config.log_model:
300
+ log_model_tracker(writer, model, dataloader, tracking_tool=config.tracking_tool)
301
+
302
+ # close tracker
303
+ if config.tracking_tool == ExperimentTrackingTool.TENSORBOARD:
304
+ writer.close()
305
+ elif config.tracking_tool == ExperimentTrackingTool.MLFLOW:
306
+ writer.end_run()
260
307
 
261
308
  return model, optimizer
@@ -1,11 +1,18 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import importlib
3
4
  from logging import getLogger
4
5
  from typing import Callable
5
6
 
6
7
  import nevergrad as ng
7
8
  from nevergrad.optimization.base import Optimizer as NGOptimizer
8
- from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn
9
+ from rich.progress import (
10
+ BarColumn,
11
+ Progress,
12
+ TaskProgressColumn,
13
+ TextColumn,
14
+ TimeRemainingColumn,
15
+ )
9
16
  from torch import Tensor
10
17
  from torch.nn import Module
11
18
  from torch.utils.data import DataLoader
@@ -14,9 +21,16 @@ from torch.utils.tensorboard import SummaryWriter
14
21
  from qadence.ml_tools.config import TrainConfig
15
22
  from qadence.ml_tools.data import DictDataLoader
16
23
  from qadence.ml_tools.parameters import get_parameters, set_parameters
17
- from qadence.ml_tools.printing import print_metrics, write_tensorboard
24
+ from qadence.ml_tools.printing import (
25
+ log_model_tracker,
26
+ log_tracker,
27
+ plot_tracker,
28
+ print_metrics,
29
+ write_tracker,
30
+ )
18
31
  from qadence.ml_tools.saveload import load_checkpoint, write_checkpoint
19
32
  from qadence.ml_tools.tensors import promote_to_tensor
33
+ from qadence.types import ExperimentTrackingTool
20
34
 
21
35
  logger = getLogger(__name__)
22
36
 
@@ -42,6 +56,7 @@ def train(
42
56
  dataloader: Dataloader constructed via `dictdataloader`
43
57
  optimizer: The optimizer to use taken from the Nevergrad library. If this is not
44
58
  the case the function will raise an AssertionError
59
+ config: `TrainConfig` with additional training options.
45
60
  loss_fn: Loss function returning (loss: float, metrics: dict[str, float])
46
61
  """
47
62
  init_iter = 0
@@ -63,8 +78,11 @@ def train(
63
78
  # TODO: support also Scipy optimizers
64
79
  assert isinstance(optimizer, NGOptimizer), "Use only optimizers from the Nevergrad library"
65
80
 
66
- # initialize tensorboard
67
- writer = SummaryWriter(config.folder, purge_step=init_iter)
81
+ # initialize tracking tool
82
+ if config.tracking_tool == ExperimentTrackingTool.TENSORBOARD:
83
+ writer = SummaryWriter(config.folder, purge_step=init_iter)
84
+ else:
85
+ writer = importlib.import_module("mlflow")
68
86
 
69
87
  # set optimizer configuration and initial parameters
70
88
  optimizer.budget = config.max_iter
@@ -100,7 +118,16 @@ def train(
100
118
  print_metrics(loss, metrics, iteration)
101
119
 
102
120
  if iteration % config.write_every == 0:
103
- write_tensorboard(writer, loss, metrics, iteration)
121
+ write_tracker(writer, loss, metrics, iteration, tracking_tool=config.tracking_tool)
122
+
123
+ if iteration % config.plot_every == 0:
124
+ plot_tracker(
125
+ writer,
126
+ model,
127
+ iteration,
128
+ config.plotting_functions,
129
+ tracking_tool=config.tracking_tool,
130
+ )
104
131
 
105
132
  if config.folder:
106
133
  if iteration % config.checkpoint_every == 0:
@@ -109,10 +136,22 @@ def train(
109
136
  if iteration >= init_iter + config.max_iter:
110
137
  break
111
138
 
112
- ## Final writing and stuff
139
+ # writing hyperparameters
140
+ if config.hyperparams:
141
+ log_tracker(writer, config.hyperparams, metrics, tracking_tool=config.tracking_tool)
142
+
143
+ if config.log_model:
144
+ log_model_tracker(writer, model, dataloader, tracking_tool=config.tracking_tool)
145
+
146
+ # Final writing and checkpointing
113
147
  if config.folder:
114
148
  write_checkpoint(config.folder, model, optimizer, iteration)
115
- write_tensorboard(writer, loss, metrics, iteration)
116
- writer.close()
149
+ write_tracker(writer, loss, metrics, iteration, tracking_tool=config.tracking_tool)
150
+
151
+ # close tracker
152
+ if config.tracking_tool == ExperimentTrackingTool.TENSORBOARD:
153
+ writer.close()
154
+ elif config.tracking_tool == ExperimentTrackingTool.MLFLOW:
155
+ writer.end_run()
117
156
 
118
157
  return model, optimizer
qadence/types.py CHANGED
@@ -6,9 +6,11 @@ from typing import Callable, Iterable, Tuple, Union
6
6
 
7
7
  import numpy as np
8
8
  import sympy
9
+ from matplotlib.figure import Figure
9
10
  from numpy.typing import ArrayLike
10
11
  from pyqtorch.utils import SolverType
11
12
  from torch import Tensor, pi
13
+ from torch.nn import Module
12
14
 
13
15
  TNumber = Union[int, float, complex, np.int64, np.float64]
14
16
  """Union of python and numpy numeric types."""
@@ -445,3 +447,13 @@ class ObservableTransform:
445
447
  """Use the given values as min and max."""
446
448
  NONE = "none"
447
449
  """No transformation."""
450
+
451
+
452
+ class ExperimentTrackingTool(StrEnum):
453
+ TENSORBOARD = "tensorboard"
454
+ """Use the tensorboard experiment tracker."""
455
+ MLFLOW = "mlflow"
456
+ """Use the ml-flow experiment tracker."""
457
+
458
+
459
+ LoggablePlotFunction = Callable[[Module, int], tuple[str, Figure]]
@@ -1,8 +1,8 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: qadence
3
- Version: 1.7.1
3
+ Version: 1.7.3
4
4
  Summary: Pasqal interface for circuit-based quantum computing SDKs
5
- Author-email: Aleksander Wennersteen <aleksander.wennersteen@pasqal.com>, Gert-Jan Both <gert-jan.both@pasqal.com>, Niklas Heim <niklas.heim@pasqal.com>, Mario Dagrada <mario.dagrada@pasqal.com>, Vincent Elfving <vincent.elfving@pasqal.com>, Dominik Seitz <dominik.seitz@pasqal.com>, Roland Guichard <roland.guichard@pasqal.com>, "Joao P. Moutinho" <joao.moutinho@pasqal.com>, Vytautas Abramavicius <vytautas.abramavicius@pasqal.com>, Gergana Velikova <gergana.velikova@pasqal.com>, Eduardo Maschio <eduardo.maschio@pasqal.com>, Smit Chaudhary <smit.chaudhary@pasqal.com>, Ignacio Fernández Graña <ignacio.fernandez-grana@pasqal.com>, Charles Moussa <charles.moussa@pasqal.com>
5
+ Author-email: Aleksander Wennersteen <aleksander.wennersteen@pasqal.com>, Gert-Jan Both <gert-jan.both@pasqal.com>, Niklas Heim <niklas.heim@pasqal.com>, Mario Dagrada <mario.dagrada@pasqal.com>, Vincent Elfving <vincent.elfving@pasqal.com>, Dominik Seitz <dominik.seitz@pasqal.com>, Roland Guichard <roland.guichard@pasqal.com>, "Joao P. Moutinho" <joao.moutinho@pasqal.com>, Vytautas Abramavicius <vytautas.abramavicius@pasqal.com>, Gergana Velikova <gergana.velikova@pasqal.com>, Eduardo Maschio <eduardo.maschio@pasqal.com>, Smit Chaudhary <smit.chaudhary@pasqal.com>, Ignacio Fernández Graña <ignacio.fernandez-grana@pasqal.com>, Charles Moussa <charles.moussa@pasqal.com>, Giorgio Tosti Balducci <giorgio.tosti-balducci@pasqal.com>
6
6
  License: Apache 2.0
7
7
  License-File: LICENSE
8
8
  Classifier: License :: OSI Approved :: Apache Software License
@@ -22,7 +22,7 @@ Requires-Dist: matplotlib
22
22
  Requires-Dist: nevergrad
23
23
  Requires-Dist: numpy
24
24
  Requires-Dist: openfermion
25
- Requires-Dist: pyqtorch==1.2.5
25
+ Requires-Dist: pyqtorch==1.3.2
26
26
  Requires-Dist: pyyaml
27
27
  Requires-Dist: rich
28
28
  Requires-Dist: scipy
@@ -33,6 +33,7 @@ Requires-Dist: torch
33
33
  Provides-Extra: all
34
34
  Requires-Dist: braket; extra == 'all'
35
35
  Requires-Dist: libs; extra == 'all'
36
+ Requires-Dist: mlflow; extra == 'all'
36
37
  Requires-Dist: protocols; extra == 'all'
37
38
  Requires-Dist: pulser; extra == 'all'
38
39
  Requires-Dist: visualization; extra == 'all'
@@ -51,6 +52,8 @@ Requires-Dist: optax; extra == 'horqrux'
51
52
  Requires-Dist: sympy2jax; extra == 'horqrux'
52
53
  Provides-Extra: libs
53
54
  Requires-Dist: qadence-libs; extra == 'libs'
55
+ Provides-Extra: mlflow
56
+ Requires-Dist: mlflow; extra == 'mlflow'
54
57
  Provides-Extra: protocols
55
58
  Requires-Dist: qadence-protocols; extra == 'protocols'
56
59
  Provides-Extra: pulser
@@ -4,7 +4,7 @@ qadence/circuit.py,sha256=3lQdjj_srxgk6f5M3eh3kE-Qdov4FA9TZxZZb0E1_mI,6966
4
4
  qadence/decompose.py,sha256=C4LYia_GcC9Rx3QO0ZLWTI9dN63a8WTEAXO0ZVQWuiE,5221
5
5
  qadence/divergences.py,sha256=JhpELhWSnuDvQxa9hJp_DE3EQg2Ban-Ta0mHZ_fVrHg,1832
6
6
  qadence/execution.py,sha256=JNvN8RVxbbysm5CzS9fdp5LpyVaDpk84h-BkC-S0Wj8,9587
7
- qadence/extensions.py,sha256=_RfP1572Ijb3ZinsRR8dI6DJXyLnrC4_H46COKaah2Q,5476
7
+ qadence/extensions.py,sha256=J4bNYX8SXNDnxiyL9k0pvIX8Cycy1vb5Dford6vFAnI,5741
8
8
  qadence/libs.py,sha256=HetkKO8TCTlVCViQdVQJvxwBekrhd-y_iMox4UJMY1M,410
9
9
  qadence/log_config.yaml,sha256=WwSpxqMSXgPJ7wO_wh46UnFzXdgX9NVA4MbN3TcJFyE,485
10
10
  qadence/logger.py,sha256=Hb76pK3VyQjVjJb4_NqFlOJgjYJVa8t7DHJFlzOM86M,407
@@ -18,7 +18,7 @@ qadence/register.py,sha256=mwmvS6PcTY0F9cIhTUXG3NT73FIagfMCwVqYa4DrQrk,13001
18
18
  qadence/serial_expr_grammar.peg,sha256=z5ytL7do9kO8o4h-V5GrsDuLdso0KsRcMuIYURFfmAY,328
19
19
  qadence/serialization.py,sha256=qEET6Gu9u2aSibPve3bJrqDzK2_gO3RPDJjt4ZY8GbE,15596
20
20
  qadence/states.py,sha256=5QIOBBYs8e2uLFiMa8iMYZ-MvWIFEqkZAjNYx0SyYPI,14843
21
- qadence/types.py,sha256=6Dw8Ibtn0ZKRVg5DcW3O5129LIwYT-LLT2V3XjxFQek,10729
21
+ qadence/types.py,sha256=Uex0set3DqJZb9Frkh9FvC5cp2ypp2U_A4e_y2XRDTo,11054
22
22
  qadence/utils.py,sha256=zb2j7wURfy8kazaS84r4t35vAeDpo4Tpm4HbmPH-kFA,9865
23
23
  qadence/analog/__init__.py,sha256=BCyS9R4KUjzUXN0Ax3b0eMo8ZAuSkGoJQVtZ4_pvAFs,279
24
24
  qadence/analog/addressing.py,sha256=fu5-xW9lquEbagApNp23S_ET1kl0iDtZUrIYSVNmw9s,6435
@@ -51,7 +51,7 @@ qadence/backends/pulser/waveforms.py,sha256=0uz95b7rUaUUtN0tuHBZmJ0H6UBmfHST_59o
51
51
  qadence/backends/pyqtorch/__init__.py,sha256=0OdVy6cq0oQggV48LO1WXdaZuSkDkz7OYNEPIkNAmfk,140
52
52
  qadence/backends/pyqtorch/backend.py,sha256=5ChkSD3D5totCMxwbzC31yIeOBxw6QEM5CS6qnq1Jqw,9287
53
53
  qadence/backends/pyqtorch/config.py,sha256=jK-if0OF6L_inP-oZhWI4-b8wcrOiK8-EVv3NYDOfBM,2056
54
- qadence/backends/pyqtorch/convert_ops.py,sha256=8xw43wWETIuD2N-5g-EZVeUYYOMnTWuIePfqqpQ3Mbg,14897
54
+ qadence/backends/pyqtorch/convert_ops.py,sha256=lrc2_l-xsMY0fBAtbYsStB_TwIiG27Bur_tVLpaaNDA,15160
55
55
  qadence/blocks/__init__.py,sha256=H6jEA_CptkE-eoB4UfSbUiDszbxxhZwECV_TgoZWXoU,960
56
56
  qadence/blocks/abstract.py,sha256=QFwKPagbTrn3V4c2DHpBd-QL_mVIUXfbvyBLUdD6zw4,12023
57
57
  qadence/blocks/analog.py,sha256=ymnnlSVoW1XL05ZvnnHCqRTHuOXIEY_7E9M0PNKJZy4,10812
@@ -59,7 +59,7 @@ qadence/blocks/block_to_tensor.py,sha256=Sg7YGKUoPUUHKvyB8Khztrk7UYnV5SD451_3I00
59
59
  qadence/blocks/composite.py,sha256=z_lXRBVnh-DdvfZdv6T0ZEmVhlU76zBt72P_FGGa-PQ,8897
60
60
  qadence/blocks/embedding.py,sha256=XC3_U4Dqi9jvU1TVbl2bZQJAGNL8Sww89eKOVLCsfiQ,6752
61
61
  qadence/blocks/manipulate.py,sha256=kPmzej7mnWFoqTJA2CkGulT7hcPha0GGPARC8rjZltg,2387
62
- qadence/blocks/matrix.py,sha256=k6CC3zN2i6nd7_9S9u4fJAxy9wfkM287945GpArwOhY,3771
62
+ qadence/blocks/matrix.py,sha256=r1JqyD3kzspLq3aYbdVCNbWpjUJXmUx2c34luIo8Mcc,3947
63
63
  qadence/blocks/primitive.py,sha256=RoEA9_VCI_8o4yg_pMe5T38z3LD6IFz9qlCiF3iHmOo,16631
64
64
  qadence/blocks/utils.py,sha256=iCJDi6HTYYaQQCoP3cdIKeCDuy8KQCxctrHN5QWXV-M,16349
65
65
  qadence/constructors/__init__.py,sha256=oEGuILUB8qEbSeaKV9Q-Tk-DAVx-U0wqn8VoSztVueg,984
@@ -90,7 +90,7 @@ qadence/engines/jax/differentiable_backend.py,sha256=W5rDA8wb-ECnFWoLj4dVugF9v1l
90
90
  qadence/engines/jax/differentiable_expectation.py,sha256=XBYHT1XKRuZfKxTcNy8KJpSDPt-2PR4ZCanImCPI9OI,3677
91
91
  qadence/engines/torch/__init__.py,sha256=iZFdD32ot0B0CVyC-f5dVViOBnqoalxa6M9Lj4WQuPE,160
92
92
  qadence/engines/torch/differentiable_backend.py,sha256=AWthwvKE8pCOih4dZ3tXxQX4W1ps9mBcvo7n4V9V24Y,3553
93
- qadence/engines/torch/differentiable_expectation.py,sha256=kAYl23Xq9MwkLm0MzmiIES-qUXc2zQsl9TuIF_c-qTE,9599
93
+ qadence/engines/torch/differentiable_expectation.py,sha256=ojSsde_5PtpebYYBWYmU5Jj76m7pfDki2uqChWfepts,9621
94
94
  qadence/exceptions/__init__.py,sha256=BU6vWrI9mshzr1aTPm1Ticr_o_42GjTrWI4OZXhThsI,203
95
95
  qadence/exceptions/exceptions.py,sha256=4j_VJpx2sZ2Mir5BJUWu4nwb131FY1ygO4q8-XlyfRc,190
96
96
  qadence/measurements/__init__.py,sha256=RIjG9tVJMqhNzyj7maZI250Um0KgHl2PizDcKJag-JU,161
@@ -104,17 +104,17 @@ qadence/mitigations/analog_zne.py,sha256=g0QkjSdF-N9Dv2N8Oza4sylnjUMid5ea-4NCT9T
104
104
  qadence/mitigations/protocols.py,sha256=Jq9MyLujfTyWmc7XVUGYVRUkJT1MmZw-GgmWpVjmX2Y,1608
105
105
  qadence/mitigations/readout.py,sha256=HPfYmdjRlieUdOBMZTghFK4DRWfveM4KkDkEI0bMI0E,6262
106
106
  qadence/ml_tools/__init__.py,sha256=HP4xjldkUZ9_WbZEDgpl31qoP9st5SBbC-DjI5pkx3k,1054
107
- qadence/ml_tools/config.py,sha256=c3vvQiNXlNoJnOuMFfqAd5fVkmpa7EzCN_ztvPV1jBU,14152
107
+ qadence/ml_tools/config.py,sha256=g-hFaVoG57p0elde0giSEP5_XTvyPquDg49xGOtj6gA,17686
108
108
  qadence/ml_tools/constructors.py,sha256=cE510DqCKBe4tImH90qHawEbXU-mlQuW9Wh15lUON6Q,27293
109
109
  qadence/ml_tools/data.py,sha256=8ZUFjhQSp94w7icX7RzM2J39Yo7P_T-AgjcThBc8miI,4283
110
110
  qadence/ml_tools/models.py,sha256=SjwAPbSl9zn9YqfmwqHc2lIXCkIpwG_ysz4jieRh7W0,16996
111
111
  qadence/ml_tools/optimize_step.py,sha256=ATXWmAqybJVK3QmAaDqVXB5mxjTo2MIi_e0a5WSPFpc,1800
112
112
  qadence/ml_tools/parameters.py,sha256=gew2Kq_5-RgRpaTvs8eauVhgo0sTqqDQEV6WHFEiLGM,1301
113
- qadence/ml_tools/printing.py,sha256=Mzdhmm-gPclhYL0NPN2zwJ19-kKQ4PDwFkefIOYEmzU,745
113
+ qadence/ml_tools/printing.py,sha256=aNJdUdh6UkNFIvgOFVFNXpHc_ilJSezPGgl7it4o7Q4,4549
114
114
  qadence/ml_tools/saveload.py,sha256=jeYG7Y1ime0P06SMWOiCgWlci-xHdEPrAARfM-awDH8,5798
115
115
  qadence/ml_tools/tensors.py,sha256=xZ9ZRzOqEaMgLUGWQf1najDmL6iLuN1ojCGVFs1Tm94,1337
116
- qadence/ml_tools/train_grad.py,sha256=VcBNr7g9tfjnlyG7XwQVHtEivnki34LJZcgJuRxyHhs,10924
117
- qadence/ml_tools/train_no_grad.py,sha256=PrOfPwu6C-YqfFxnRkbeyOQzqSyjRrx4AZZd6C-1xRw,4705
116
+ qadence/ml_tools/train_grad.py,sha256=tf162ZfK07NZeqmTDvA92kkojPxX8s2nwBP_VM2qSvw,12190
117
+ qadence/ml_tools/train_no_grad.py,sha256=4AvJok882Vq6EVE1NNLicd0aww6Fku0laZm05dcILaU,5868
118
118
  qadence/ml_tools/utils.py,sha256=PW8FyoV0mG_DtN1U8njTDV5qxZ0EK4mnFwMAsLBArfk,1410
119
119
  qadence/noise/__init__.py,sha256=r0nR8uEZeB1M9pI2UisjWq0bjw50fPFfVGzIMev923g,147
120
120
  qadence/noise/protocols.py,sha256=-aZ06JvMnpxCeT5v5lI_RNPOLbb9Ju1Pi1AB6uAXxVE,1653
@@ -133,7 +133,7 @@ qadence/transpile/digitalize.py,sha256=iWRwYAYQsD2INHj0HNbGJriv_3fRCuBW1nDBrwtKS
133
133
  qadence/transpile/flatten.py,sha256=EdhSG5WyF56nbnxINNLqrHgY84MRM1YFjT3fR4aph5Q,3427
134
134
  qadence/transpile/invert.py,sha256=KAefHTG2AWr39aengVhXrzCtJPhrZC-ZnL6vYvmbnY0,4867
135
135
  qadence/transpile/transpile.py,sha256=6MRRkk1OS279L1fwUQjazA6qlfpbd-T_EJMKT8hAhOU,2721
136
- qadence-1.7.1.dist-info/METADATA,sha256=Cd8hZ0vygRkiSbvu57RDAdL3JNvleJe8E334CMr0WXw,9774
137
- qadence-1.7.1.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
138
- qadence-1.7.1.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
139
- qadence-1.7.1.dist-info/RECORD,,
136
+ qadence-1.7.3.dist-info/METADATA,sha256=d3TUjDh_-Ragvuhicd9BWMr4orwEwNgjAXRnGWk_ruQ,9936
137
+ qadence-1.7.3.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
138
+ qadence-1.7.3.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
139
+ qadence-1.7.3.dist-info/RECORD,,