qadence 1.10.3__py3-none-any.whl → 1.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- qadence/blocks/block_to_tensor.py +21 -24
- qadence/constructors/__init__.py +7 -1
- qadence/constructors/hamiltonians.py +105 -9
- qadence/draw/utils.py +2 -1
- qadence/mitigations/analog_zne.py +6 -2
- qadence/ml_tools/__init__.py +2 -2
- qadence/ml_tools/callbacks/callback.py +80 -50
- qadence/ml_tools/callbacks/callbackmanager.py +3 -2
- qadence/ml_tools/callbacks/writer_registry.py +3 -2
- qadence/ml_tools/config.py +66 -5
- qadence/ml_tools/constructors.py +15 -63
- qadence/ml_tools/data.py +4 -0
- qadence/ml_tools/models.py +64 -4
- qadence/ml_tools/optimize_step.py +1 -2
- qadence/ml_tools/train_utils/__init__.py +3 -1
- qadence/ml_tools/train_utils/accelerator.py +480 -0
- qadence/ml_tools/train_utils/config_manager.py +7 -7
- qadence/ml_tools/train_utils/distribution.py +209 -0
- qadence/ml_tools/train_utils/execution.py +421 -0
- qadence/ml_tools/trainer.py +179 -99
- qadence/model.py +23 -0
- qadence/register.py +5 -1
- qadence/types.py +7 -11
- qadence/utils.py +45 -0
- {qadence-1.10.3.dist-info → qadence-1.11.1.dist-info}/METADATA +14 -11
- {qadence-1.10.3.dist-info → qadence-1.11.1.dist-info}/RECORD +28 -25
- {qadence-1.10.3.dist-info → qadence-1.11.1.dist-info}/WHEEL +0 -0
- {qadence-1.10.3.dist-info → qadence-1.11.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,480 @@
+from __future__ import annotations
+from logging import getLogger
+from typing import Any, Callable
+import functools
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch import dtype as torch_dtype
+from torch import nn, optim
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from qadence.ml_tools.train_utils.distribution import Distributor
+from qadence.ml_tools.data import data_to_device, InfiniteTensorDataset, DictDataLoader
+from qadence.types import ExecutionType
+
+logger = getLogger("ml_tools")
+
+
+class Accelerator(Distributor):
+    """
+    A class for handling distributed training.
+
+    This class extends `Distributor` to manage distributed training using PyTorch's
+    `torch.distributed` API. It supports spawning multiple processes and wrapping models with
+    `DistributedDataParallel` (DDP) when required.
+
+    This class provides a head-level method - distribute() - which wraps a function at the head process level
+    before launching `nprocs` processes as required. Furthermore, it provides process-level methods,
+    such as prepare() and prepare_batch(), which can be run inside each process for the correct movement and
+    preparation of models, optimizers and datasets.
+
+    Inherited Attributes:
+        nprocs (int): Number of processes to launch for distributed training.
+        execution (BaseExecution): Detected execution instance for process launch (e.g., "torchrun", "default").
+        execution_type (ExecutionType): Type of execution used.
+        rank (int): Global rank of the process (to be set during environment setup).
+        world_size (int): Total number of processes (to be set during environment setup).
+        local_rank (int | None): Local rank on the node (to be set during environment setup).
+        master_addr (str): Master node address (to be set during environment setup).
+        master_port (str): Master node port (to be set during environment setup).
+        node_rank (int): Rank of the node in the cluster setup.
+
+    NOTE: There are three different indicators for the number of processes executed.
+    - 1. self._config_nprocs: Number of processes specified by the user.
+        Provided in the initialization of the Accelerator. (acc = Accelerator(nprocs = 2))
+    - 2. self.nprocs: Number of processes defined at the head level.
+        - When the accelerator is used to spawn processes (e.g., in the case of default, python execution),
+        nprocs = _config_nprocs.
+        - When an external elastic method is used to spawn processes (e.g., in the case of torchrun),
+        nprocs = 1. This is because the external launcher already spawns multiple processes,
+        and the accelerator __init__ is called from each process.
+    - 3. self.world_size: Number of processes actually executed.
+    """
+
+    # -----------------------------------------------------------------------------
+    # HEAD level methods
+    # -----------------------------------------------------------------------------
+    def __init__(
+        self,
+        nprocs: int = 1,
+        compute_setup: str = "auto",
+        log_setup: str = "cpu",
+        backend: str = "gloo",
+        dtype: torch_dtype | None = None,
+    ) -> None:
+        """
+        Initializes the Accelerator class.
+
+        Args:
+            nprocs (int): Number of processes to launch. Default is 1.
+            compute_setup (str): Compute device setup; options are "auto" (default), "gpu", or "cpu".
+                - "auto": Uses GPU if available, otherwise CPU.
+                - "gpu": Forces GPU usage, raising an error if no CUDA device is available.
+                - "cpu": Forces CPU usage.
+            log_setup (str): Logging device setup; options are "auto", "cpu" (default).
+                - "auto": Uses the same device for logging as for computation.
+                - "cpu": Forces CPU logging.
+            backend (str): The backend for distributed communication. Default is "gloo".
+            dtype (torch.dtype | None): Data type for controlling numerical precision. Default is None.
+        """
+        super().__init__(nprocs, compute_setup, log_setup, backend, dtype)
+
+        # Default values
+        self.rank = 0
+        self.local_rank = 0
+        self.world_size = self.execution.get_world_size(0, self.nprocs)
+
+    def distribute(self, fun: Callable) -> Callable:
+        """
+        Decorator to distribute the fit function across multiple processes.
+
+        This function is generic and can work with other methods as well,
+        whether they are bound or unbound.
+
+        When applied to a function (typically a fit function), this decorator
+        will execute the function in a distributed fashion using torch.multiprocessing.
+        The number of processes used is determined by `self.nprocs`,
+        and if multiple nodes are involved (`self.num_nodes > 1`), the process count is
+        adjusted accordingly. In single process mode (`self.nprocs` is 1), the function
+        is executed directly in the current process.
+
+        After execution, the decorator returns the model stored in `instance.model`.
+
+        Parameters:
+            fun (callable): The function to be decorated. This function usually implements
+                a model fitting or training routine.
+
+        Returns:
+            callable: The wrapped function. When called, it will execute in distributed mode
+                (if configured) and return the value of `instance.model`.
+        """
+
+        @functools.wraps(fun)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+
+            # Get the original picklable function
+            # for the case of a bound class method
+            # as well as a plain function
+            if self.is_class_method(fun, args):
+                instance = args[0]
+                method_name = fun.__name__
+                method = getattr(instance, method_name)
+                args = args[1:]
+                self._spawn_method(instance, method, args, kwargs)
+            else:
+                instance = None
+                # method_name = fun.__name__
+                # module = inspect.getmodule(fun)
+                # method = getattr(module, method_name) if module else fun
+                self._spawn_method(instance, fun, args, kwargs)
+
+            if instance and hasattr(instance, "accelerator"):
+                instance.accelerator.finalize()
+            else:
+                self.finalize()
+
+            # TODO: Return the original returns from fun
+            # Currently it only returns the model and optimizer
+            # similar to the fit method.
+            try:
+                return instance.model, instance.optimizer
+            except Exception:
+                return
+
+        return wrapper
+
+    def worker(self, rank: int, instance: Any, fun: Callable, args: tuple, kwargs: dict) -> None:
+        """
+        Worker function to be executed in each spawned process.
+
+        This function is called in every subprocess created by torch.multiprocessing (via mp.spawn).
+        It performs the following tasks:
+        1. Sets up the accelerator for the given process rank. This typically involves configuring
+            the GPU or other hardware resources for distributed training.
+        2. If the retrieved method has been decorated (i.e. it has a '__wrapped__' attribute),
+            the original, unwrapped function is invoked with the given arguments. Otherwise,
+            the method is called directly.
+
+        Args:
+            rank (int): The rank (or identifier) of the spawned process.
+            instance (object): The object (Trainer) that contains the method to execute.
+                This object is expected to have an `accelerator` attribute with a `setup_process(rank)` method.
+                This argument is optional; in case it is None, `fun` will be called independently.
+            fun (Callable): The function or method on the instance to be executed.
+            args (tuple): Positional arguments to pass to the target method.
+            kwargs (dict): Keyword arguments to pass to the target method.
+        """
+        # Setup the accelerator for the given process rank (e.g., configuring GPU)
+        if instance and instance.accelerator:
+            instance.accelerator.setup_process(rank)
+        else:
+            self.setup_process(rank)
+
+        if hasattr(fun, "__wrapped__"):
+            # Explicitly get the original (unbound) method, passing in the instance.
+            # We need to call the original method so that mp.spawn does not
+            # create multiple processes again (avoiding an infinite loop).
+            fun = fun.__wrapped__  # Unwrap if decorated
+            fun(instance, *args, **kwargs) if instance else fun(*args, **kwargs)
+        else:
+            fun(*args, **kwargs)
+
+    def is_class_method(self, fun: Callable, args: Any) -> bool:
+        """
+        Determines if `fun` is a class method or a standalone function.
+
+        The first argument of `args` should:
+        - Be an object with a __dict__: making it a class instance.
+        - Have a method named `fun`: making it a class that has this method.
+
+        Args:
+            fun (Callable): The function being checked.
+            args (tuple): The arguments passed to the function.
+
+        Returns:
+            bool: True if `fun` is a class method, False otherwise.
+        """
+        return (
+            bool(args)
+            and isinstance(args[0], object)
+            and hasattr(args[0], "__dict__")
+            and hasattr(args[0], fun.__name__)
+        )
+
+    def _spawn_method(self, instance: Any, method: Callable, args: Any, kwargs: Any) -> None:
+        """
+        This method spawns the required number of processes.
+
+        - if execution is `default`, it will spawn `nprocs` processes across all nodes
+        - otherwise, it will run a single process.
+
+        Args:
+            instance (object): The object (Trainer) that contains the method to execute.
+                This object is expected to have an `accelerator` attribute with a `setup_process(rank)` method.
+                This argument is optional; in case it is None, the method will be called independently.
+            method (Callable): The function or method on the instance to be executed.
+            args (tuple): Positional arguments to pass to the target method.
+            kwargs (dict): Keyword arguments to pass to the target method.
+        """
+
+        if self.execution_type == ExecutionType.DEFAULT and self.world_size > 1:
+            # Spawn multiple processes that will run the worker function.
+            nprocs = self.nprocs
+            if self.execution.num_nodes > 1:
+                nprocs //= self.execution.num_nodes
+            mp.spawn(
+                self.worker,
+                args=(instance, method, args, kwargs),
+                nprocs=int(nprocs),
+                join=True,
+            )
+        else:
+            # In single process mode, call the worker with rank 0.
+            self.worker(0, instance, method, args, kwargs)
+
+    # -----------------------------------------------------------------------------
+    # PROCESS level methods
+    # -----------------------------------------------------------------------------
+    def prepare(self, *args: Any) -> tuple[Any, ...]:
+        """
+        Prepares models, optimizers, and dataloaders for distributed training.
+
+        This method iterates over the provided objects and:
+        - Moves models to the specified device (e.g., GPU or CPU) and casts them to the
+            desired precision (specified by `self.dtype`). It then wraps models in
+            DistributedDataParallel (DDP) if more than one device is used.
+        - Passes through optimizers unchanged.
+        - For dataloaders, it adjusts them to use a distributed sampler (if applicable)
+            by calling a helper method. Note that only the sampler is prepared; moving the
+            actual batch data to the device is handled separately during training.
+            Please use the `prepare_batch` method to move the batch to the correct device/dtype.
+
+        Args:
+            *args (Any): A variable number of objects to be prepared. These can include:
+                - PyTorch models (`nn.Module`)
+                - Optimizers (`optim.Optimizer`)
+                - DataLoaders (or a dictionary-like `DictDataLoader` of dataloaders)
+
+        Returns:
+            tuple[Any, ...]: A tuple containing the prepared objects, where each object has been
+                modified as needed to support distributed training.
+        """
+        prepared: list = []
+        for obj in args:
+            if obj is None:
+                prepared.append(None)
+            elif isinstance(obj, nn.Module):
+                prepared.append(self._prepare_model(obj))
+            elif isinstance(obj, optim.Optimizer):
+                prepared.append(self._prepare_optimizer(obj))
+            elif isinstance(obj, (DataLoader, DictDataLoader)):
+                prepared.append(self._prepare_data(obj))
+            else:
+                prepared.append(obj)
+        return tuple(prepared)
+
+    def _prepare_model(self, model: nn.Module) -> nn.Module:
+        """
+        Moves the model to the desired device and casts it to the specified dtype.
+
+        In a distributed setting, if more than one device is used (i.e., self.world_size > 1),
+        the model is wrapped in DistributedDataParallel (DDP) to handle gradient synchronization
+        across devices.
+
+        Args:
+            model (nn.Module): The PyTorch model to prepare.
+
+        Returns:
+            nn.Module: The model moved to the correct device (and wrapped in DDP if applicable).
+        """
+        model = model.to(device=self.execution.device, dtype=self.execution.dtype)
+
+        # If using distributed training with more than one device:
+        if self.world_size > 1:
+            if self.execution.device.startswith("cuda"):
+                # For GPU-based training: wrap the model with DDP and specify the local GPU.
+                model = DDP(model, device_ids=[self.local_rank])
+            else:
+                # For CPU-based or other environments:
+                if not self.local_rank:
+                    model = DDP(model)
+
+        return model
+
+    def _prepare_optimizer(self, optimizer: optim.Optimizer) -> optim.Optimizer:
+        """
+        Passes through the optimizer without modification.
+
+        Args:
+            optimizer (optim.Optimizer): The optimizer to prepare.
+
+        Returns:
+            optim.Optimizer: The unmodified optimizer.
+        """
+        # Optimizers are not device-specific in this context, so no action is needed.
+        return optimizer
+
+    def _prepare_data(self, dataloader: DataLoader | DictDataLoader) -> DataLoader | DictDataLoader:
+        """
+        Adjusts DataLoader(s) for distributed training.
+
+        Args:
+            dataloader (Union[DataLoader, DictDataLoader]): The dataloader or dictionary of dataloaders to prepare.
+
+        Returns:
+            Union[DataLoader, DictDataLoader]: The prepared dataloader(s) with the correct distributed
+                sampling setup.
+        """
+        if isinstance(dataloader, DictDataLoader):
+            # If the input is a DictDataLoader, prepare each contained DataLoader.
+            prepared_dataloaders = {
+                key: self._prepare_dataloader(dl) for key, dl in dataloader.dataloaders.items()
+            }
+            return DictDataLoader(prepared_dataloaders)
+        else:
+            # For a single DataLoader, prepare it directly.
+            return self._prepare_dataloader(dataloader)
+
+    def _prepare_dataloader(self, dataloader: DataLoader) -> DataLoader:
+        """
+        Prepares a single DataLoader for distributed training.
+
+        When training in a distributed setting (i.e., when `self.world_size > 1`), data must be
+        divided among multiple processes. This is achieved by creating a
+        DistributedSampler that splits the dataset into distinct subsets for each process.
+
+        This method does the following:
+        - If distributed training is enabled:
+            - Checks if the dataset is not an instance of `InfiniteTensorDataset`.
+            - If so, creates a `DistributedSampler` for the dataset using the total number
+                of replicas (`self.world_size`) and the current process's rank (`self.local_rank`).
+            - Otherwise (i.e., for infinite datasets), no sampler is set (sampler remains `None`).
+        - Returns a new DataLoader configured with:
+            - The same dataset and batch size as the original.
+            - The distributed sampler (if applicable).
+            - The number of workers and pin_memory settings retrieved from the original DataLoader.
+        - If not in a distributed setting (i.e., `self.world_size <= 1`), returns the original DataLoader unmodified.
+
+        Args:
+            dataloader (DataLoader): The original DataLoader instance that loads the dataset.
+
+        Returns:
+            DataLoader: A new DataLoader prepared for distributed training if in a multi-process environment;
+                otherwise, the original DataLoader is returned.
+        """
+        if self.world_size > 1:
+            if not isinstance(dataloader.dataset, InfiniteTensorDataset):
+                # If the dataset is not an infinite dataset, create a DistributedSampler.
+                sampler = DistributedSampler(
+                    dataloader.dataset, num_replicas=self.world_size, rank=self.local_rank
+                )
+            else:
+                # For infinite datasets, we do not use a sampler since the dataset
+                # is designed to loop indefinitely.
+                sampler = None
+
+            return DataLoader(
+                dataloader.dataset,  # Use the same dataset as the original.
+                batch_size=dataloader.batch_size,  # Maintain the same batch size.
+                sampler=sampler,  # Use the created DistributedSampler (or None).
+                num_workers=getattr(dataloader, "num_workers", 0),
+                pin_memory=getattr(dataloader, "pin_memory", False),
+            )
+        return dataloader
+
+    def prepare_batch(self, batch: dict | list | tuple | torch.Tensor | None) -> Any:
+        """
+        Moves a batch of data to the target device and casts it to the desired data dtype.
+
+        This method is typically called within the optimization step of your training loop.
+        It supports various batch formats:
+        - If the batch is a dictionary, each value is moved individually.
+        - If the batch is a tuple or list, each element is processed and returned as a tuple.
+        - Otherwise, the batch is processed directly.
+
+        Args:
+            batch (Any): The batch of data to move to the device. This can be a dict, tuple, list,
+                or any type compatible with `data_to_device`.
+
+        Returns:
+            Any: The batch with all elements moved to `self.device` and cast to `self.data_dtype`.
+        """
+        if batch is None:
+            return None
+
+        if isinstance(batch, dict):
+            return {
+                key: data_to_device(
+                    value, device=self.execution.device, dtype=self.execution.data_dtype
+                )
+                for key, value in batch.items()
+            }
+        elif isinstance(batch, (tuple, list)):
+            return tuple(
+                data_to_device(x, device=self.execution.device, dtype=self.execution.data_dtype)
+                for x in batch
+            )
+        elif isinstance(batch, torch.Tensor):
+            return data_to_device(
+                batch, device=self.execution.device, dtype=self.execution.data_dtype
+            )
+        return
+
+    def all_reduce_dict(
+        self, d: dict[str, torch.Tensor], op: str = "mean"
+    ) -> dict[str, torch.Tensor]:
+        """
+        Performs an all-reduce operation on a dictionary of tensors, combining values across all processes.
+
+        Args:
+            d (dict[str, torch.Tensor]): A dictionary where values are tensors to be reduced across processes.
+            op (str): Operation to all_reduce with. Available options are `sum`, `max`, and `mean`.
+                Defaults to `mean`.
+
+        Returns:
+            dict[str, torch.Tensor]: A dictionary with the reduced tensors.
+        """
+        if dist.is_initialized():
+            world_size = dist.get_world_size()
+            reduced: dict[str, torch.Tensor] = {}
+            for key, tensor in d.items():
+                if not isinstance(tensor, torch.Tensor):
+                    tensor = torch.tensor(
+                        tensor, device=self.execution.device, dtype=self.execution.data_dtype
+                    )
+                tensor = tensor.detach().clone()
+                if op == "max":
+                    dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
+                elif op == "sum":
+                    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+                else:
+                    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+                    tensor /= world_size
+                reduced[key] = tensor
+            return reduced
+        else:
+            return d
+
+    def broadcast(self, obj: Any, src: int) -> Any:
+        """
+        Broadcasts an object from the source process to all processes.
+
+        On non-source processes, this value is ignored.
+
+        Args:
+            obj (Any): The object to broadcast on the source process.
+            src (int): The source process rank.
+
+        Returns:
+            Any: The broadcast object from the source process.
+        """
+        if dist.is_initialized():
+            obj_list = [obj] if self.rank == src else [None]
+            dist.broadcast_object_list(obj_list, src=src)
+            return obj_list[0]
+        else:
+            return obj
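For orientation, the following is a minimal sketch of the head-level / process-level split described in the `Accelerator` docstrings above: a Trainer-like object decorates its fit routine with `distribute()`, and inside each spawned process calls `prepare()` and `prepare_batch()`. The `MiniTrainer` class, its toy dataset and hyperparameters are illustrative assumptions for this sketch, not part of the qadence API (the docstrings name qadence's own Trainer as the typical caller).

# Illustrative sketch only: MiniTrainer, the toy dataset and hyperparameters are
# assumptions for demonstration and not part of the qadence API.
from __future__ import annotations

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

from qadence.ml_tools.train_utils.accelerator import Accelerator


class MiniTrainer:
    # Head-level object: one Accelerator shared by the decorated fit method.
    accelerator = Accelerator(nprocs=2, compute_setup="auto", backend="gloo")

    def __init__(self, model: nn.Module, optimizer: optim.Optimizer) -> None:
        self.model = model
        self.optimizer = optimizer

    @accelerator.distribute  # head level: spawns nprocs processes, runs fit in each
    def fit(self, dataloader: DataLoader) -> None:
        # Process level: move/wrap model, optimizer and dataloader for this rank.
        self.model, self.optimizer, dataloader = self.accelerator.prepare(
            self.model, self.optimizer, dataloader
        )
        for batch in dataloader:
            x, y = self.accelerator.prepare_batch(batch)  # batch to device/dtype
            self.optimizer.zero_grad()
            loss = ((self.model(x) - y) ** 2).mean()
            loss.backward()
            self.optimizer.step()


if __name__ == "__main__":
    model = nn.Linear(4, 1)
    trainer = MiniTrainer(model, optim.SGD(model.parameters(), lr=0.1))
    loader = DataLoader(TensorDataset(torch.rand(64, 4), torch.rand(64, 1)), batch_size=16)
    # Per the distribute() docstring, the wrapped call returns (model, optimizer).
    model, optimizer = trainer.fit(loader)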
@@ -41,10 +41,10 @@ class ConfigManager:
         handling hyperparameters, deriving additional parameters,
         and logging warnings.
         """
+        self._log_warnings()
         self._initialize_folder()
         self._handle_hyperparams()
         self._setup_additional_configuration()
-        self._log_warnings()
 
     def _initialize_folder(self) -> None:
         """
@@ -78,19 +78,19 @@ class ConfigManager:
             log_folder = root_folder_path / self.config._subfolders[-1]
         else:
             if self.config._subfolders:
-
-
-                else:
-                    log_folder = Path(self.config.log_folder)
+                # self.config.log_folder is an old subfolder.
+                log_folder = Path(self.config.log_folder)
             else:
                 if self.config.log_folder == Path("./"):
+                    # A subfolder must be created (no specific folder given to config).
                     self._add_subfolder()
                     log_folder = root_folder_path / self.config._subfolders[-1]
                 else:
+                    # The folder is one and fully specified by the user.
                    log_folder = Path(self.config.log_folder)
 
        log_folder.mkdir(parents=True, exist_ok=True)
-       return
+       return log_folder
 
    def _add_subfolder(self) -> None:
        """
@@ -147,7 +147,7 @@ class ConfigManager:
        Sets the stopping criterion if it is not already defined.
        """
        if self.config.trainstop_criterion is None:
-
+            return
 
    def _log_warnings(self) -> None:
        """
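Rounding off the ml_tools additions: the `Accelerator` shown earlier also exposes the collective helpers `all_reduce_dict` and `broadcast`, which fall back to returning their inputs unchanged when `torch.distributed` is not initialized. Below is a small hedged sketch of how they might be used inside a training step; the metric names and the early-stop flag are illustrative assumptions, not qadence API.

# Illustrative sketch of the process-level collective helpers from accelerator.py;
# the metric names and early-stop flag are assumptions, not qadence API.
from __future__ import annotations

import torch

from qadence.ml_tools.train_utils.accelerator import Accelerator

accelerator = Accelerator(nprocs=1)


def aggregate_metrics(loss: torch.Tensor, grad_norm: torch.Tensor) -> dict[str, torch.Tensor]:
    # "mean" is the default reduction; "sum" and "max" are also supported.
    # With torch.distributed uninitialized, the input dict is returned unchanged.
    return accelerator.all_reduce_dict({"loss": loss.detach(), "grad_norm": grad_norm}, op="mean")


def synchronized_early_stop(stop_on_rank0: bool) -> bool:
    # Rank 0 decides; every other rank receives the same decision.
    return bool(accelerator.broadcast(stop_on_rank0, src=0))


metrics = aggregate_metrics(torch.tensor(0.42), torch.tensor(1.3))
stop = synchronized_early_stop(metrics["loss"].item() < 1e-3)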