qadence 1.10.3__py3-none-any.whl → 1.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,480 @@
+ from __future__ import annotations
+ from logging import getLogger
+ from typing import Any, Callable
+ import functools
+
+ import torch
+ import torch.distributed as dist
+ import torch.multiprocessing as mp
+ from torch import dtype as torch_dtype
+ from torch import nn, optim
+ from torch.utils.data import DataLoader
+ from torch.utils.data.distributed import DistributedSampler
+ from torch.nn.parallel import DistributedDataParallel as DDP
+
+ from qadence.ml_tools.train_utils.distribution import Distributor
+ from qadence.ml_tools.data import data_to_device, InfiniteTensorDataset, DictDataLoader
+ from qadence.types import ExecutionType
+
+ logger = getLogger("ml_tools")
+
+
+ class Accelerator(Distributor):
+     """
+     A class for handling distributed training.
+
+     This class extends `Distributor` to manage distributed training using PyTorch's
+     `torch.distributed` API. It supports spawning multiple processes and wrapping models with
+     `DistributedDataParallel` (DDP) when required.
+
+     This class provides a head-level method - distribute() - which wraps a function at the head-process level
+     before launching `nprocs` processes as required. It also provides process-level methods,
+     such as prepare() and prepare_batch(), which can be run inside each process for the correct movement
+     and preparation of the model, optimizers, and datasets.
+
+     Inherited Attributes:
+         nprocs (int): Number of processes to launch for distributed training.
+         execution (BaseExecution): Detected execution instance for process launch (e.g., "torchrun", "default").
+         execution_type (ExecutionType): Type of execution used.
+         rank (int): Global rank of the process (to be set during environment setup).
+         world_size (int): Total number of processes (to be set during environment setup).
+         local_rank (int | None): Local rank on the node (to be set during environment setup).
+         master_addr (str): Master node address (to be set during environment setup).
+         master_port (str): Master node port (to be set during environment setup).
+         node_rank (int): Rank of the node in the cluster setup.
+
+     NOTE: There are three different indicators for the number of processes executed.
+         - 1. self._config_nprocs: Number of processes specified by the user.
+             Provided in the initialization of the Accelerator (acc = Accelerator(nprocs=2)).
+         - 2. self.nprocs: Number of processes defined at the head level.
+             - When the accelerator itself spawns the processes (e.g., the default, plain-python execution),
+                 nprocs = _config_nprocs.
+             - When an external elastic launcher spawns the processes (e.g., torchrun),
+                 nprocs = 1. This is because the external launcher already spawns multiple processes,
+                 and the accelerator __init__ is called from each process.
+         - 3. self.world_size: Number of processes actually executed.
+     """
+
+     # -----------------------------------------------------------------------------
+     # HEAD level methods
+     # -----------------------------------------------------------------------------
+     def __init__(
+         self,
+         nprocs: int = 1,
+         compute_setup: str = "auto",
+         log_setup: str = "cpu",
+         backend: str = "gloo",
+         dtype: torch_dtype | None = None,
+     ) -> None:
+         """
+         Initializes the Accelerator class.
+
+         Args:
+             nprocs (int): Number of processes to launch. Default is 1.
+             compute_setup (str): Compute device setup; options are "auto" (default), "gpu", or "cpu".
+                 - "auto": Uses GPU if available, otherwise CPU.
+                 - "gpu": Forces GPU usage, raising an error if no CUDA device is available.
+                 - "cpu": Forces CPU usage.
+             log_setup (str): Logging device setup; options are "auto" and "cpu" (default).
+                 - "auto": Logs on the same device used for computation.
+                 - "cpu": Forces CPU logging.
+             backend (str): The backend for distributed communication. Default is "gloo".
+             dtype (torch.dtype | None): Data type for controlling numerical precision. Default is None.
+         """
+         super().__init__(nprocs, compute_setup, log_setup, backend, dtype)
+
+         # Default values
+         self.rank = 0
+         self.local_rank = 0
+         self.world_size = self.execution.get_world_size(0, self.nprocs)
+
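For reference, a minimal construction sketch (the import path is inferred from the sibling imports above and may differ in your installed version; the argument values are illustrative):

    import torch
    from qadence.ml_tools.train_utils.accelerator import Accelerator  # assumed module path

    # Two processes on CPU over gloo; for multi-GPU runs, "nccl" is the usual backend choice.
    accelerator = Accelerator(nprocs=2, compute_setup="auto", backend="gloo", dtype=torch.float32)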
+     def distribute(self, fun: Callable) -> Callable:
+         """
+         Decorator to distribute the fit function across multiple processes.
+
+         This decorator is generic and works with other methods as well,
+         whether they are bound or unbound.
+
+         When applied to a function (typically a fit function), this decorator
+         will execute the function in a distributed fashion using torch.multiprocessing.
+         The number of processes used is determined by `self.nprocs`,
+         and if multiple nodes are involved (`self.num_nodes > 1`), the process count is
+         adjusted accordingly. In single-process mode (`self.nprocs` is 1), the function
+         is executed directly in the current process.
+
+         After execution, the decorator returns the model stored in `instance.model`.
+
+         Parameters:
+             fun (callable): The function to be decorated. This function usually implements
+                 a model fitting or training routine.
+
+         Returns:
+             callable: The wrapped function. When called, it will execute in distributed mode
+                 (if configured) and return the value of `instance.model`.
+         """
+
+         @functools.wraps(fun)
+         def wrapper(*args: Any, **kwargs: Any) -> Any:
+
+             # Get the original picklable function,
+             # both for the case of a bound class method
+             # and for a plain function.
+             if self.is_class_method(fun, args):
+                 instance = args[0]
+                 method_name = fun.__name__
+                 method = getattr(instance, method_name)
+                 args = args[1:]
+                 self._spawn_method(instance, method, args, kwargs)
+             else:
+                 instance = None
+                 # method_name = fun.__name__
+                 # module = inspect.getmodule(fun)
+                 # method = getattr(module, method_name) if module else fun
+                 self._spawn_method(instance, fun, args, kwargs)
+
+             if instance and hasattr(instance, "accelerator"):
+                 instance.accelerator.finalize()
+             else:
+                 self.finalize()
+
+             # TODO: Return the original returns from fun.
+             # Currently it only returns the model and optimizer,
+             # similar to the fit method.
+             try:
+                 return instance.model, instance.optimizer
+             except Exception:
+                 return
+
+         return wrapper
+
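A minimal sketch of the decorator on a standalone function (the function name and body are illustrative; with a bound Trainer-like method, the wrapper additionally returns `instance.model` and `instance.optimizer`):

    accelerator = Accelerator(nprocs=2)

    @accelerator.distribute
    def fit(n_steps: int) -> None:
        # Runs once in each spawned process; per-rank setup happens in worker()/setup_process().
        for _ in range(n_steps):
            pass  # training logic would go here

    fit(100)  # spawns the processes, runs fit in each of them, then finalizes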
+     def worker(self, rank: int, instance: Any, fun: Callable, args: tuple, kwargs: dict) -> None:
+         """
+         Worker function to be executed in each spawned process.
+
+         This function is called in every subprocess created by torch.multiprocessing (via mp.spawn).
+         It performs the following tasks:
+         1. Sets up the accelerator for the given process rank. This typically involves configuring
+             the GPU or other hardware resources for distributed training.
+         2. If the retrieved method has been decorated (i.e. it has a '__wrapped__' attribute),
+             the original, unwrapped function is invoked with the given arguments. Otherwise,
+             the method is called directly.
+
+         Args:
+             rank (int): The rank (or identifier) of the spawned process.
+             instance (object): The object (Trainer) that contains the method to execute.
+                 This object is expected to have an `accelerator` attribute with a `setup_process(rank)` method.
+                 This argument is optional; if it is None, `fun` is called on its own.
+             fun (Callable): The function or method on the instance to be executed.
+             args (tuple): Positional arguments to pass to the target method.
+             kwargs (dict): Keyword arguments to pass to the target method.
+         """
+         # Setup the accelerator for the given process rank (e.g., configuring GPU)
+         if instance and instance.accelerator:
+             instance.accelerator.setup_process(rank)
+         else:
+             self.setup_process(rank)
+
+         if hasattr(fun, "__wrapped__"):
+             # Explicitly get the original (unbound) method, passing in the instance.
+             # We need to call the original method here so that mp.spawn does not
+             # create further processes. (To avoid an infinite loop.)
+             fun = fun.__wrapped__  # Unwrap if decorated
+             fun(instance, *args, **kwargs) if instance else fun(*args, **kwargs)
+         else:
+             fun(*args, **kwargs)
+
+     def is_class_method(self, fun: Callable, args: Any) -> bool:
+         """
+         Determines if `fun` is a class method or a standalone function.
+
+         The first element of `args` should be:
+         - An object with a __dict__ attribute: making it a class instance.
+         - An object with a method named after `fun`: making it an instance that owns this method.
+
+         Args:
+             fun (Callable): The function being checked.
+             args (tuple): The arguments passed to the function.
+
+         Returns:
+             bool: True if `fun` is a class method, False otherwise.
+         """
+         return (
+             bool(args)
+             and isinstance(args[0], object)
+             and hasattr(args[0], "__dict__")
+             and hasattr(args[0], fun.__name__)
+         )
+
+     def _spawn_method(self, instance: Any, method: Callable, args: Any, kwargs: Any) -> None:
+         """
+         This method spawns the required number of processes.
+
+         - If execution is `default`, it spawns `nprocs` processes across all nodes.
+         - Otherwise, it runs a single process.
+
+         Args:
+             instance (object): The object (Trainer) that contains the method to execute.
+                 This object is expected to have an `accelerator` attribute with a `setup_process(rank)` method.
+                 This argument is optional; if it is None, the method is called on its own.
+             method (Callable): The function or method on the instance to be executed.
+             args (tuple): Positional arguments to pass to the target method.
+             kwargs (dict): Keyword arguments to pass to the target method.
+         """
+
+         if self.execution_type == ExecutionType.DEFAULT and self.world_size > 1:
+             # Spawn multiple processes that will run the worker function.
+             nprocs = self.nprocs
+             if self.execution.num_nodes > 1:
+                 nprocs //= self.execution.num_nodes
+             mp.spawn(
+                 self.worker,
+                 args=(instance, method, args, kwargs),
+                 nprocs=int(nprocs),
+                 join=True,
+             )
+         else:
+             # In single process mode, call the worker with rank 0.
+             self.worker(0, instance, method, args, kwargs)
+
+     # -----------------------------------------------------------------------------
+     # PROCESS level methods
+     # -----------------------------------------------------------------------------
+     def prepare(self, *args: Any) -> tuple[Any, ...]:
+         """
+         Prepares models, optimizers, and dataloaders for distributed training.
+
+         This method iterates over the provided objects and:
+         - Moves models to the specified device (e.g., GPU or CPU) and casts them to the
+             desired precision (specified by `self.dtype`). It then wraps models in
+             DistributedDataParallel (DDP) if more than one device is used.
+         - Passes through optimizers unchanged.
+         - For dataloaders, it adjusts them to use a distributed sampler (if applicable)
+             by calling a helper method. Note that only the sampler is prepared; moving the
+             actual batch data to the device is handled separately during training.
+             Please use the `prepare_batch` method to move the batch to the correct device/dtype.
+
+         Args:
+             *args (Any): A variable number of objects to be prepared. These can include:
+                 - PyTorch models (`nn.Module`)
+                 - Optimizers (`optim.Optimizer`)
+                 - DataLoaders (or a dictionary-like `DictDataLoader` of dataloaders)
+
+         Returns:
+             tuple[Any, ...]: A tuple containing the prepared objects, where each object has been
+                 modified as needed to support distributed training.
+         """
+         prepared: list = []
+         for obj in args:
+             if obj is None:
+                 prepared.append(None)
+             elif isinstance(obj, nn.Module):
+                 prepared.append(self._prepare_model(obj))
+             elif isinstance(obj, optim.Optimizer):
+                 prepared.append(self._prepare_optimizer(obj))
+             elif isinstance(obj, (DataLoader, DictDataLoader)):
+                 prepared.append(self._prepare_data(obj))
+             else:
+                 prepared.append(obj)
+         return tuple(prepared)
+
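Inside each process, a typical single call prepares all training objects at once. A sketch with stand-in objects (the `accelerator` instance is assumed to exist already):

    import torch
    from torch import nn, optim
    from torch.utils.data import DataLoader, TensorDataset

    model = nn.Linear(4, 1)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    dataloader = DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8)

    # Order is preserved; objects of unknown type are passed through unchanged.
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)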
+     def _prepare_model(self, model: nn.Module) -> nn.Module:
+         """
+         Moves the model to the desired device and casts it to the specified dtype.
+
+         In a distributed setting, if more than one device is used (i.e., self.world_size > 1),
+         the model is wrapped in DistributedDataParallel (DDP) to handle gradient synchronization
+         across devices.
+
+         Args:
+             model (nn.Module): The PyTorch model to prepare.
+
+         Returns:
+             nn.Module: The model moved to the correct device (and wrapped in DDP if applicable).
+         """
+         model = model.to(device=self.execution.device, dtype=self.execution.dtype)
+
+         # If using distributed training with more than one device:
+         if self.world_size > 1:
+             if self.execution.device.startswith("cuda"):
+                 # For GPU-based training: wrap the model with DDP and specify the local GPU.
+                 model = DDP(model, device_ids=[self.local_rank])
+             else:
+                 # For CPU-based or other environments:
+                 if not self.local_rank:
+                     model = DDP(model)
+
+         return model
+
+     def _prepare_optimizer(self, optimizer: optim.Optimizer) -> optim.Optimizer:
+         """
+         Passes through the optimizer without modification.
+
+         Args:
+             optimizer (optim.Optimizer): The optimizer to prepare.
+
+         Returns:
+             optim.Optimizer: The unmodified optimizer.
+         """
+         # Optimizers are not device-specific in this context, so no action is needed.
+         return optimizer
+
+     def _prepare_data(self, dataloader: DataLoader | DictDataLoader) -> DataLoader | DictDataLoader:
+         """
+         Adjusts DataLoader(s) for distributed training.
+
+         Args:
+             dataloader (Union[DataLoader, DictDataLoader]): The dataloader or dictionary of dataloaders to prepare.
+
+         Returns:
+             Union[DataLoader, DictDataLoader]: The prepared dataloader(s) with the correct distributed
+                 sampling setup.
+         """
+         if isinstance(dataloader, DictDataLoader):
+             # If the input is a DictDataLoader, prepare each contained DataLoader.
+             prepared_dataloaders = {
+                 key: self._prepare_dataloader(dl) for key, dl in dataloader.dataloaders.items()
+             }
+             return DictDataLoader(prepared_dataloaders)
+         else:
+             # For a single DataLoader, prepare it directly.
+             return self._prepare_dataloader(dataloader)
+
+     def _prepare_dataloader(self, dataloader: DataLoader) -> DataLoader:
+         """
+         Prepares a single DataLoader for distributed training.
+
+         When training in a distributed setting (i.e., when `self.world_size > 1`), data must be
+         divided among multiple processes. This is achieved by creating a
+         DistributedSampler that splits the dataset into distinct subsets for each process.
+
+         This method does the following:
+         - If distributed training is enabled:
+             - Checks if the dataset is not an instance of `InfiniteTensorDataset`.
+             - If so, creates a `DistributedSampler` for the dataset using the total number
+                 of replicas (`self.world_size`) and the current process's rank (`self.local_rank`).
+             - Otherwise (i.e., for infinite datasets), no sampler is set (sampler remains `None`).
+         - Returns a new DataLoader configured with:
+             - The same dataset and batch size as the original.
+             - The distributed sampler (if applicable).
+             - The number of workers and pin_memory settings retrieved from the original DataLoader.
+         - If not in a distributed setting (i.e., `self.world_size <= 1`), returns the original DataLoader unmodified.
+
+         Args:
+             dataloader (DataLoader): The original DataLoader instance that loads the dataset.
+
+         Returns:
+             DataLoader: A new DataLoader prepared for distributed training if in a multi-process environment;
+                 otherwise, the original DataLoader is returned.
+         """
+         if self.world_size > 1:
+             if not isinstance(dataloader.dataset, InfiniteTensorDataset):
+                 # If the dataset is not an infinite dataset, create a DistributedSampler.
+                 sampler = DistributedSampler(
+                     dataloader.dataset, num_replicas=self.world_size, rank=self.local_rank
+                 )
+             else:
+                 # For infinite datasets, we do not use a sampler since the dataset
+                 # is designed to loop indefinitely.
+                 sampler = None
+
+             return DataLoader(
+                 dataloader.dataset,  # Use the same dataset as the original.
+                 batch_size=dataloader.batch_size,  # Maintain the same batch size.
+                 sampler=sampler,  # Use the created DistributedSampler (or None).
+                 num_workers=getattr(dataloader, "num_workers", 0),
+                 pin_memory=getattr(dataloader, "pin_memory", False),
+             )
+         return dataloader
+
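One practical follow-up when a `DistributedSampler` ends up attached to the returned DataLoader (standard PyTorch usage, not specific to this class): set the epoch on the sampler so each epoch reshuffles consistently across ranks:

    from torch.utils.data.distributed import DistributedSampler

    for epoch in range(10):
        if isinstance(dataloader.sampler, DistributedSampler):
            dataloader.sampler.set_epoch(epoch)  # different, rank-consistent shuffle per epoch
        for batch in dataloader:
            ...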
+     def prepare_batch(self, batch: dict | list | tuple | torch.Tensor | None) -> Any:
+         """
+         Moves a batch of data to the target device and casts it to the desired data dtype.
+
+         This method is typically called within the optimization step of your training loop.
+         It supports various batch formats:
+         - If the batch is a dictionary, each value is moved individually.
+         - If the batch is a tuple or list, each element is processed and returned as a tuple.
+         - Otherwise, the batch is processed directly.
+
+         Args:
+             batch (Any): The batch of data to move to the device. This can be a dict, tuple, list,
+                 or any type compatible with `data_to_device`.
+
+         Returns:
+             Any: The batch with all elements moved to `self.device` and cast to `self.data_dtype`.
+         """
+         if batch is None:
+             return None
+
+         if isinstance(batch, dict):
+             return {
+                 key: data_to_device(
+                     value, device=self.execution.device, dtype=self.execution.data_dtype
+                 )
+                 for key, value in batch.items()
+             }
+         elif isinstance(batch, (tuple, list)):
+             return tuple(
+                 data_to_device(x, device=self.execution.device, dtype=self.execution.data_dtype)
+                 for x in batch
+             )
+         elif isinstance(batch, torch.Tensor):
+             return data_to_device(
+                 batch, device=self.execution.device, dtype=self.execution.data_dtype
+             )
+         return
+
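A sketch of the intended call site, inside the per-process optimization loop (`model`, `optimizer`, and `dataloader` as returned by `prepare` above):

    for batch in dataloader:
        x, y = accelerator.prepare_batch(batch)  # moved to execution.device / data_dtype
        loss = torch.nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()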
+     def all_reduce_dict(
+         self, d: dict[str, torch.Tensor], op: str = "mean"
+     ) -> dict[str, torch.Tensor]:
+         """
+         Performs an all-reduce operation on a dictionary of tensors, combining values across all processes.
+
+         Args:
+             d (dict[str, torch.Tensor]): A dictionary where values are tensors to be reduced across processes.
+             op (str): Reduction to apply in all_reduce. Available options are `sum`, `max`, and `mean`.
+                 Defaults to `mean`.
+
+         Returns:
+             dict[str, torch.Tensor]: A dictionary with the reduced tensors.
+         """
+         if dist.is_initialized():
+             world_size = dist.get_world_size()
+             reduced: dict[str, torch.Tensor] = {}
+             for key, tensor in d.items():
+                 if not isinstance(tensor, torch.Tensor):
+                     tensor = torch.tensor(
+                         tensor, device=self.execution.device, dtype=self.execution.data_dtype
+                     )
+                 tensor = tensor.detach().clone()
+                 if op == "max":
+                     dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
+                 elif op == "sum":
+                     dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+                 else:
+                     dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+                     tensor /= world_size
+                 reduced[key] = tensor
+             return reduced
+         else:
+             return d
+
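For example, a per-step metric can be synchronized across ranks as sketched below (with `op="mean"`, the default, each tensor is summed and divided by the world size; on a single, uninitialized process the dict is returned unchanged):

    metrics = {"loss": loss.detach()}
    metrics = accelerator.all_reduce_dict(metrics, op="mean")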
+     def broadcast(self, obj: Any, src: int) -> Any:
+         """
+         Broadcasts an object from the source process to all processes.
+
+         On non-source processes, the passed-in value is ignored and replaced by the broadcast object.
+
+         Args:
+             obj (Any): The object to broadcast on the source process.
+             src (int): The source process rank.
+
+         Returns:
+             Any: The broadcast object from the source process.
+         """
+         if dist.is_initialized():
+             obj_list = [obj] if self.rank == src else [None]
+             dist.broadcast_object_list(obj_list, src=src)
+             return obj_list[0]
+         else:
+             return obj
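A small usage sketch (the value and its meaning are illustrative): only the source rank needs to supply a real object, and every rank receives it:

    best_lr = 1e-3 if accelerator.rank == 0 else None
    best_lr = accelerator.broadcast(best_lr, src=0)  # every rank now holds 1e-3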
@@ -41,10 +41,10 @@ class ConfigManager:
          handling hyperparameters, deriving additional parameters,
          and logging warnings.
          """
+         self._log_warnings()
          self._initialize_folder()
          self._handle_hyperparams()
          self._setup_additional_configuration()
-         self._log_warnings()

      def _initialize_folder(self) -> None:
          """
@@ -78,19 +78,19 @@ class ConfigManager:
              log_folder = root_folder_path / self.config._subfolders[-1]
          else:
              if self.config._subfolders:
-                 if self.config.log_folder == root_folder_path / self.config._subfolders[-1]:
-                     log_folder = root_folder_path / self.config._subfolders[-1]
-                 else:
-                     log_folder = Path(self.config.log_folder)
+                 # self.config.log_folder is an old subfolder.
+                 log_folder = Path(self.config.log_folder)
              else:
                  if self.config.log_folder == Path("./"):
+                     # A subfolder must be created (no specific folder given to config).
                      self._add_subfolder()
                      log_folder = root_folder_path / self.config._subfolders[-1]
                  else:
+                     # The folder is fully specified by the user.
                      log_folder = Path(self.config.log_folder)

          log_folder.mkdir(parents=True, exist_ok=True)
-         return Path(log_folder)
+         return log_folder

      def _add_subfolder(self) -> None:
          """
@@ -147,7 +147,7 @@ class ConfigManager:
          Sets the stopping criterion if it is not already defined.
          """
          if self.config.trainstop_criterion is None:
-             self.config.trainstop_criterion = lambda x: x <= self.config.max_iter
+             return

      def _log_warnings(self) -> None:
          """