evograd-diff 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evograd/__init__.py +67 -0
- evograd/algorithms/__init__.py +138 -0
- evograd/algorithms/cmaes.py +1365 -0
- evograd/algorithms/de.py +895 -0
- evograd/algorithms/ga.py +532 -0
- evograd/algorithms/pso.py +648 -0
- evograd/algorithms/shade.py +1165 -0
- evograd/benchmarks/functions/__init__.py +229 -0
- evograd/benchmarks/functions/base.py +217 -0
- evograd/benchmarks/functions/cec2017/__init__.py +250 -0
- evograd/benchmarks/functions/cec2017/basic.py +413 -0
- evograd/benchmarks/functions/cec2017/composition.py +580 -0
- evograd/benchmarks/functions/cec2017/data.pkl +0 -0
- evograd/benchmarks/functions/cec2017/data.py +350 -0
- evograd/benchmarks/functions/cec2017/hybrid.py +406 -0
- evograd/benchmarks/functions/cec2017/simple.py +326 -0
- evograd/benchmarks/functions/classical.py +649 -0
- evograd/benchmarks/functions/smoothed_funnel.py +476 -0
- evograd/benchmarks/functions/transforms.py +463 -0
- evograd/benchmarks/run_benchmark_functions.py +1208 -0
- evograd/core/__init__.py +73 -0
- evograd/core/algorithm.py +778 -0
- evograd/core/maximize.py +269 -0
- evograd/core/minimize.py +740 -0
- evograd/core/problem.py +444 -0
- evograd/core/result.py +571 -0
- evograd/core/termination.py +602 -0
- evograd/operators/__init__.py +178 -0
- evograd/operators/crossover.py +1117 -0
- evograd/operators/mutation.py +1098 -0
- evograd/operators/relaxations.py +175 -0
- evograd/operators/repair.py +601 -0
- evograd/operators/sampling.py +577 -0
- evograd/operators/selection.py +981 -0
- evograd/operators/survival.py +1000 -0
- evograd/tests/__init__.py +11 -0
- evograd/tests/run_all.py +78 -0
- evograd/tests/test_core.py +528 -0
- evograd/tests/test_ga.py +572 -0
- evograd/tests/test_operators.py +662 -0
- evograd/tests/test_per_individual.py +326 -0
- evograd/tests/test_utils.py +328 -0
- evograd/utils/__init__.py +97 -0
- evograd/utils/callbacks.py +926 -0
- evograd/utils/device.py +502 -0
- evograd/utils/duplicates.py +421 -0
- evograd_diff-0.1.0.dist-info/METADATA +439 -0
- evograd_diff-0.1.0.dist-info/RECORD +50 -0
- evograd_diff-0.1.0.dist-info/WHEEL +4 -0
- evograd_diff-0.1.0.dist-info/licenses/LICENSE +201 -0
evograd/utils/device.py
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Device utilities for EvoGrad.
|
|
3
|
+
|
|
4
|
+
This module provides functions for:
|
|
5
|
+
- Automatic device detection (CUDA > MPS > CPU)
|
|
6
|
+
- Tensor conversion with device/dtype handling
|
|
7
|
+
- Device-aware operations
|
|
8
|
+
|
|
9
|
+
Design Goals
|
|
10
|
+
------------
|
|
11
|
+
- Seamless GPU acceleration when available
|
|
12
|
+
- Consistent API across different hardware backends
|
|
13
|
+
- Safe tensor conversion from various input types
|
|
14
|
+
|
|
15
|
+
Example
|
|
16
|
+
-------
|
|
17
|
+
>>> from evograd.utils.device import get_device, ensure_tensor
|
|
18
|
+
>>> device = get_device() # Automatically selects best device
|
|
19
|
+
>>> x = ensure_tensor([1.0, 2.0, 3.0], device=device)
|
|
20
|
+
>>> x.device
|
|
21
|
+
device(type='cuda', index=0) # or 'mps' or 'cpu'
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
from typing import Any, List, Optional, Sequence, Union
|
|
27
|
+
|
|
28
|
+
import numpy as np
|
|
29
|
+
import torch
|
|
30
|
+
|
|
31
|
+
# Type alias for the inputs accepted by ``ensure_tensor`` / ``ensure_bounds``:
# existing tensors, NumPy arrays, Python scalars, and sequences of numbers.
# It is only used in annotations; ``ensure_tensor`` dispatches on the concrete
# runtime types itself.
TensorLike = Union[
    torch.Tensor,
    np.ndarray,
    float,
    int,
    List[float],
    List[int],
    Sequence[float],
    Sequence[int],
]
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def get_device(
    preference: Optional[Union[str, torch.device]] = None,
    fallback: str = "cpu",
) -> torch.device:
    """Get the best available device for computation.

    Priority order (when preference is None):
        1. CUDA (NVIDIA GPU)
        2. MPS (Apple Silicon)
        3. CPU

    Parameters
    ----------
    preference:
        Explicit device preference. If specified and available, this device
        is returned. Options: "cuda", "mps", "cpu", a specific device string
        like "cuda:0", or a ``torch.device``. Matching is case-insensitive
        and surrounding whitespace is ignored.
    fallback:
        Device to use when the preferred device is not available, when a
        requested CUDA index does not exist on this machine, or when
        ``preference`` is not a valid device string. Default is "cpu".

    Returns
    -------
    torch.device
        The selected device.

    Examples
    --------
    >>> device = get_device()  # Auto-detect best device
    >>> device = get_device("cuda")  # Prefer CUDA, fall back to CPU
    >>> device = get_device("cuda:1")  # Specific GPU, falls back if absent
    """
    if preference is None:
        # Auto-detect best available device.
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    # torch.device() is case-sensitive ("CUDA" is rejected), so the device is
    # built from the normalised string, not from the raw preference.
    normalized = str(preference).strip().lower()
    try:
        requested = torch.device(normalized)
    except RuntimeError:
        # Not a parseable device string.
        return torch.device(fallback)

    if requested.type == "cuda":
        if not torch.cuda.is_available():
            return torch.device(fallback)
        # "cuda:N" is only usable if GPU N actually exists on this machine.
        if requested.index is not None and requested.index >= torch.cuda.device_count():
            return torch.device(fallback)
        return requested

    if requested.type == "mps":
        # Covers both "mps" and indexed forms such as "mps:0".
        if torch.backends.mps.is_available():
            return requested
        return torch.device(fallback)

    # "cpu" and any other backend type are returned as requested.
    return requested
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def get_default_dtype() -> torch.dtype:
    """Return the floating-point dtype EvoGrad uses when none is requested.

    The value is fixed by EvoGrad itself; it does not consult
    ``torch.get_default_dtype()``.

    Returns
    -------
    torch.dtype
        Always ``torch.float32``: single precision is a good balance between
        precision and performance for evolutionary computation.
    """
    default_dtype: torch.dtype = torch.float32
    return default_dtype
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def ensure_tensor(
    value: TensorLike,
    dim: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[torch.dtype] = None,
    copy: bool = False,
) -> torch.Tensor:
    """Convert a value to a tensor with specified properties.

    This function handles various input types and ensures the result
    has the correct device, dtype, and optionally shape.

    Parameters
    ----------
    value:
        Input value to convert. Can be:
        - A scalar (int, float)
        - A list or sequence of numbers
        - A numpy array (including views with negative strides such as
          ``arr[::-1]``)
        - An existing torch tensor
    dim:
        If provided, the tensor is broadcast/repeated to have this length
        along the first (and only) dimension. Only valid for 1D outputs.
        If value is a scalar, it is repeated `dim` times.
        If value is already a 1D tensor of length `dim`, it is unchanged.
        If value is a 1D tensor of length 1, it is repeated `dim` times.
    device:
        Target device. If None, uses the input tensor's device (if it's
        already a tensor) or ``get_device()`` otherwise.
    dtype:
        Target dtype. If None, keeps an input tensor's dtype; every other
        input (including NumPy arrays) is converted to
        ``get_default_dtype()`` (float32).
    copy:
        If True, always return a new tensor that shares no memory with
        ``value`` (NumPy inputs included), even if the input already
        satisfies all requirements. If False, the result may be ``value``
        itself (tensor input) or share its memory (NumPy input whose dtype
        and device already match).

    Returns
    -------
    torch.Tensor
        The converted tensor.

    Raises
    ------
    ValueError
        If `dim` is < 1, or if `dim` is specified but the input cannot be
        broadcast to that shape.
    TypeError
        If `value` has an unsupported type that ``torch.as_tensor`` cannot
        convert.

    Examples
    --------
    >>> # Scalar to tensor
    >>> ensure_tensor(3.14)
    tensor(3.1400)

    >>> # Scalar broadcast to dimension
    >>> ensure_tensor(-100.0, dim=10)
    tensor([-100., -100., -100., -100., -100., -100., -100., -100., -100., -100.])

    >>> # List to tensor with device
    >>> ensure_tensor([1, 2, 3], device="cuda")
    tensor([1., 2., 3.], device='cuda:0')

    >>> # Numpy array conversion
    >>> import numpy as np
    >>> ensure_tensor(np.array([1.0, 2.0]))
    tensor([1., 2.])
    """
    # Determine target dtype
    if dtype is None:
        if isinstance(value, torch.Tensor):
            dtype = value.dtype
        else:
            dtype = get_default_dtype()

    # Determine target device
    if device is None:
        if isinstance(value, torch.Tensor):
            device = value.device
        else:
            device = get_device()
    elif isinstance(device, str):
        device = torch.device(device)

    # Convert to tensor
    if isinstance(value, torch.Tensor):
        tensor = value
        needs_conversion = (
            copy
            or tensor.device != device
            or tensor.dtype != dtype
        )
        if needs_conversion:
            tensor = tensor.to(device=device, dtype=dtype)
            # .to() is a no-op when nothing changes, so force a real copy.
            if copy and tensor.data_ptr() == value.data_ptr():
                tensor = tensor.clone()
    elif isinstance(value, np.ndarray):
        array = value
        if copy or any(stride < 0 for stride in array.strides):
            # torch.from_numpy shares the array's memory (so copy=True needs
            # a fresh buffer) and rejects negative strides (e.g. arr[::-1]);
            # ndarray.copy() returns a C-ordered copy that fixes both.
            array = array.copy()
        tensor = torch.from_numpy(array).to(device=device, dtype=dtype)
    elif isinstance(value, (int, float)):
        tensor = torch.tensor(value, device=device, dtype=dtype)
    elif isinstance(value, (list, tuple)):
        tensor = torch.tensor(value, device=device, dtype=dtype)
    else:
        # Try generic conversion
        try:
            tensor = torch.as_tensor(value, device=device, dtype=dtype)
        except (TypeError, ValueError) as e:
            raise TypeError(
                f"Cannot convert {type(value).__name__} to tensor: {e}"
            ) from e

    # Handle dimension broadcasting
    if dim is not None:
        dim = int(dim)
        if dim < 1:
            raise ValueError(f"dim must be >= 1, got {dim}")

        if tensor.ndim == 0:
            # Scalar: repeat to create 1D tensor (clone so the result owns
            # its storage instead of being a stride-0 expanded view).
            tensor = tensor.expand(dim).clone()
        elif tensor.ndim == 1:
            if tensor.shape[0] == dim:
                # Already correct size
                pass
            elif tensor.shape[0] == 1:
                # Single element: broadcast
                tensor = tensor.expand(dim).clone()
            else:
                raise ValueError(
                    f"Cannot broadcast tensor of shape {tuple(tensor.shape)} "
                    f"to dim={dim}. Expected shape ({dim},) or (1,)."
                )
        else:
            raise ValueError(
                f"Cannot broadcast {tensor.ndim}D tensor to 1D. "
                f"Got shape {tuple(tensor.shape)}."
            )

    return tensor
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def ensure_bounds(
    lower: TensorLike,
    upper: TensorLike,
    dim: int,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[torch.dtype] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert and validate lower/upper bounds.

    Parameters
    ----------
    lower:
        Lower bounds. Scalar or 1D tensor of length `dim`.
    upper:
        Upper bounds. Scalar or 1D tensor of length `dim`.
    dim:
        Number of dimensions (variables).
    device:
        Target device. If None, each bound is placed the way
        ``ensure_tensor`` would place it. If the two bounds then end up on
        different devices (e.g. a CPU tensor paired with a Python scalar on
        a GPU machine), both are moved to the device of the tensor input,
        preferring ``lower``.
    dtype:
        Target dtype. If None and the two bounds end up with different
        dtypes, both are cast to their promoted common dtype
        (``torch.promote_types``).

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        Tuple of (lower, upper) tensors, each of shape (dim,), sharing the
        same device and dtype.

    Raises
    ------
    ValueError
        If bounds have incompatible shapes or if lower > upper for any dimension.

    Examples
    --------
    >>> lb, ub = ensure_bounds(-100.0, 100.0, dim=10)
    >>> lb.shape, ub.shape
    (torch.Size([10]), torch.Size([10]))

    >>> lb, ub = ensure_bounds([-1, -2, -3], [1, 2, 3], dim=3)
    """
    lb = ensure_tensor(lower, dim=dim, device=device, dtype=dtype)
    ub = ensure_tensor(upper, dim=dim, device=device, dtype=dtype)

    # With device/dtype left as None, ensure_tensor keeps a tensor input's own
    # device/dtype but uses library defaults for other inputs, so the pair can
    # disagree. Comparing across devices raises, and mixed dtypes leak to
    # callers, so harmonise them here.
    if lb.device != ub.device or lb.dtype != ub.dtype:
        if isinstance(lower, torch.Tensor) or not isinstance(upper, torch.Tensor):
            common_device = lb.device
        else:
            common_device = ub.device
        common_dtype = torch.promote_types(lb.dtype, ub.dtype)
        lb = lb.to(device=common_device, dtype=common_dtype)
        ub = ub.to(device=common_device, dtype=common_dtype)

    # Validate bounds
    if torch.any(lb > ub):
        violations = (lb > ub).nonzero(as_tuple=False).view(-1)
        raise ValueError(
            f"Lower bounds must be <= upper bounds. "
            f"Violations at indices: {violations.tolist()}"
        )

    return lb, ub
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def to_device(
    *tensors: torch.Tensor,
    device: Union[str, torch.device],
) -> tuple[torch.Tensor, ...]:
    """Transfer any number of tensors to one target device.

    Parameters
    ----------
    *tensors:
        The tensors to transfer, in order.
    device:
        Destination device (keyword-only), given either as a
        ``torch.device`` or as a device string such as ``"cuda:0"``.

    Returns
    -------
    tuple[torch.Tensor, ...]
        ``tensor.to(device)`` for every input, in the original order. An
        empty tuple when no tensors are passed.

    Examples
    --------
    >>> x, y, z = to_device(x, y, z, device="cuda")
    """
    target = torch.device(device) if isinstance(device, str) else device
    moved = [tensor.to(target) for tensor in tensors]
    return tuple(moved)
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def sync_device(device: Optional[Union[str, torch.device]] = None) -> None:
    """Wait for queued GPU work to finish (useful for timing).

    Parameters
    ----------
    device:
        Optional device to synchronize. If None (the default), every
        available accelerator backend is synchronized: CUDA via
        ``torch.cuda.synchronize()`` and MPS via ``torch.mps.synchronize()``
        when this PyTorch build provides it. If given, only that device's
        backend is synchronized; CPU and other devices are a no-op, as is a
        backend that is not available.
    """
    if device is not None and not isinstance(device, torch.device):
        device = torch.device(device)

    want_cuda = device is None or device.type == "cuda"
    want_mps = device is None or device.type == "mps"

    if want_cuda and torch.cuda.is_available():
        # device=None synchronizes the current CUDA device.
        torch.cuda.synchronize(device)

    if want_mps and torch.backends.mps.is_available():
        # MPS kernels run asynchronously too; without this, wall-clock timing
        # on Apple Silicon measures only kernel launch time. Older PyTorch
        # builds lack torch.mps.synchronize, so look it up defensively.
        mps_module = getattr(torch, "mps", None)
        mps_sync = getattr(mps_module, "synchronize", None)
        if mps_sync is not None:
            mps_sync()
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def get_memory_info(device: Optional[Union[str, torch.device]] = None) -> dict[str, int]:
    """Report memory usage for a device.

    Parameters
    ----------
    device:
        Device to inspect. Strings are converted with ``torch.device``; when
        None, the device returned by ``get_device()`` is used.

    Returns
    -------
    dict[str, int]
        For CUDA devices, three byte counts:
        - "allocated": Currently allocated memory (bytes)
        - "reserved": Currently reserved memory (bytes)
        - "max_allocated": Peak allocated memory (bytes)
        For MPS, only "allocated", or an empty dict when this PyTorch build
        has no ``torch.mps.current_allocated_memory``. For every other
        device type (e.g. CPU), an empty dict.

    Examples
    --------
    >>> info = get_memory_info("cuda")
    >>> print(f"Allocated: {info['allocated'] / 1e9:.2f} GB")
    """
    if device is None:
        target = get_device()
    elif isinstance(device, str):
        target = torch.device(device)
    else:
        target = device

    if target.type == "mps":
        # MPS exposes very little memory introspection.
        try:
            allocated = torch.mps.current_allocated_memory()
        except AttributeError:
            return {}
        return {"allocated": allocated}

    if target.type != "cuda":
        return {}

    return {
        "allocated": torch.cuda.memory_allocated(target),
        "reserved": torch.cuda.memory_reserved(target),
        "max_allocated": torch.cuda.max_memory_allocated(target),
    }
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
def set_seed(
    seed: int,
    deterministic: bool = False,
) -> None:
    """Set random seeds for reproducibility.

    Parameters
    ----------
    seed:
        Random seed value, passed unchanged to every generator below.
    deterministic:
        If True, enables deterministic algorithms in PyTorch.
        This may reduce performance but ensures reproducibility.
        It also sets the ``CUBLAS_WORKSPACE_CONFIG`` environment variable to
        ``":4096:8"`` unless it is already set. Without it, cuBLAS-backed
        CUDA ops raise a RuntimeError once deterministic algorithms are
        enabled (CUDA >= 10.2).

    Notes
    -----
    This sets seeds for:
    - Python's random module
    - NumPy
    - PyTorch (CPU and all CUDA devices)

    Passing ``deterministic=False`` does not undo an earlier
    ``deterministic=True`` call; the flags are simply left unchanged.
    """
    import os
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # PyTorch 1.8+
        if hasattr(torch, "use_deterministic_algorithms"):
            # Required by deterministic cuBLAS kernels. setdefault keeps a
            # user-provided value. NOTE: per the PyTorch reproducibility notes
            # this presumably only takes effect if cuBLAS has not been
            # initialised yet in this process.
            os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
            torch.use_deterministic_algorithms(True)
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
class DeviceContext:
    """Context manager that temporarily selects the current CUDA device.

    On entry, if CUDA is available and ``device`` is a CUDA device, the
    currently selected CUDA device index is remembered. If ``device`` also
    carries an explicit index (e.g. ``"cuda:1"``), ``torch.cuda.set_device``
    makes that GPU current. On exit, the remembered index is restored. For
    CPU, MPS, or CUDA-less machines, entering and leaving the context changes
    nothing.

    This does *not* change where factory functions such as ``torch.randn``
    allocate: tensors created without an explicit ``device=`` argument still
    use PyTorch's global default device (CPU unless changed elsewhere). Pass
    the device returned by ``__enter__`` explicitly.

    The same instance may be entered again while already active (nested
    ``with`` blocks); each exit restores the device that was current at the
    matching entry.

    Parameters
    ----------
    device:
        Device to use within the context.

    Examples
    --------
    >>> with DeviceContext("cuda:1") as dev:
    ...     x = torch.randn(100, 100, device=dev)
    ...     result = x @ x.T
    """

    def __init__(self, device: Union[str, torch.device]) -> None:
        if isinstance(device, str):
            device = torch.device(device)
        self.device = device
        # One entry per active __enter__: the CUDA device index that was
        # current before it, or None when there is nothing to restore.
        self._previous_devices: List[Optional[int]] = []

    def __enter__(self) -> torch.device:
        previous: Optional[int] = None
        # Store current default device (if CUDA)
        if torch.cuda.is_available() and self.device.type == "cuda":
            previous = torch.cuda.current_device()
            if self.device.index is not None:
                torch.cuda.set_device(self.device.index)
        self._previous_devices.append(previous)
        return self.device

    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
        # Tolerate __exit__ without a matching __enter__.
        previous = self._previous_devices.pop() if self._previous_devices else None
        # Restore previous device
        if previous is not None and torch.cuda.is_available():
            torch.cuda.set_device(previous)
        return False  # Don't suppress exceptions
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
# ---------------------------------------------------------------------------
# Module-level convenience
# ---------------------------------------------------------------------------

# Cached result of get_device(). None until default_device() first fills it
# in, and set back to None by reset_default_device().
_default_device: Optional[torch.device] = None
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def default_device() -> torch.device:
    """Return the module-wide default device, detecting it on first use.

    The first call runs ``get_device()`` and stores the result in the
    module-level cache; later calls return the cached device without
    re-detecting. Call ``get_device()`` directly for fresh detection or to
    pass a preference, and ``reset_default_device()`` to clear the cache.
    """
    global _default_device
    if _default_device is not None:
        return _default_device
    _default_device = get_device()
    return _default_device
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def reset_default_device() -> None:
    """Forget the cached default device.

    After this call, the next ``default_device()`` call runs device detection
    again. Use it when the available hardware may have changed since the
    cache was filled.
    """
    global _default_device
    _default_device = None  # next default_device() call re-detects
|