autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
@@ -1,395 +0,0 @@
|
|
1
|
-
import contextlib
|
2
|
-
import logging
|
3
|
-
import math
|
4
|
-
import sys
|
5
|
-
import warnings
|
6
|
-
from typing import Dict, List, Optional, Tuple, Union
|
7
|
-
|
8
|
-
import torch
|
9
|
-
from lightning.pytorch.accelerators import find_usable_cuda_devices
|
10
|
-
from torch import nn
|
11
|
-
|
12
|
-
from autogluon.common.utils.resource_utils import ResourceManager
|
13
|
-
|
14
|
-
from ..constants import DDP_STRATEGIES, OBJECT_DETECTION, OCR
|
15
|
-
|
16
|
-
logger = logging.getLogger(__name__)
|
17
|
-
|
18
|
-
|
19
|
-
def is_interactive_env():
    """
    Tell whether the current process is running in interactive mode.

    The check is whether `sys` exposes the attribute `ps1`.
    See also https://stackoverflow.com/a/64523765

    Returns
    -------
    True if `sys.ps1` exists, otherwise False.
    """
    try:
        sys.ps1
    except AttributeError:
        return False
    return True
|
25
|
-
|
26
|
-
|
27
|
-
def is_interactive_strategy(strategy: str):
    """
    Tell whether a strategy name denotes an interactive DDP variant.

    Parameters
    ----------
    strategy
        The strategy name. Anything that is not a non-empty string yields False.

    Returns
    -------
    True if the name starts with "ddp_fork" or "ddp_notebook", otherwise False.
    """
    if not isinstance(strategy, str) or not strategy:
        return False

    interactive_prefixes = ("ddp_fork", "ddp_notebook")
    return strategy.startswith(interactive_prefixes)
|
32
|
-
|
33
|
-
|
34
|
-
def compute_num_gpus(config_num_gpus: Union[int, float, List], accelerator: str):
    """
    Work out how many GPUs the lightning trainer should be initialized with.

    Parameters
    ----------
    config_num_gpus
        The gpu number provided by config. A number is floored; for any other
        value its length is used.
    accelerator
        "cpu", "gpu", or "auto" (compared case-insensitively).

    Returns
    -------
    A valid gpu number for the current environment and config. It is 0 when the
    accelerator is a string other than "gpu"/"auto".
    """
    gpu_accelerators = ["gpu", "auto"]
    if isinstance(accelerator, str) and accelerator.lower() not in gpu_accelerators:
        return 0

    if isinstance(config_num_gpus, (int, float)):
        requested = math.floor(config_num_gpus)
    else:
        requested = len(config_num_gpus)

    detected = ResourceManager.get_gpu_count_torch()

    # A negative request (e.g. -1) means "use every detected GPU".
    if requested < 0:
        return detected

    if detected < requested:
        warnings.warn(
            f"Using the detected GPU number {detected}, "
            f"smaller than the GPU number {requested} in the config.",
            UserWarning,
        )

    return min(requested, detected)
|
69
|
-
|
70
|
-
|
71
|
-
def convert_to_torch_precision(precision: Union[int, str]):
    """
    Translate a precision integer or string into the matching torch dtype.

    Parameters
    ----------
    precision
        a precision integer or string from the config.

    Returns
    -------
    A torch precision object.

    Raises
    ------
    ValueError
        If the precision is not one of the known aliases.
    """
    # Each torch dtype together with every config spelling that maps to it.
    dtype_aliases = (
        (torch.half, (16, "16", "16-mixed", "16-true")),
        (torch.bfloat16, ("bf16", "bf16-mixed", "bf16-true")),
        (torch.float32, (32, "32", "32-true")),
        (torch.float64, (64, "64", "64-true")),
    )
    lookup = {alias: dtype for dtype, aliases in dtype_aliases for alias in aliases}

    if precision not in lookup:
        raise ValueError(f"Unknown precision: {precision}")

    return lookup[precision]
|
106
|
-
|
107
|
-
|
108
|
-
def infer_precision(
    num_gpus: int, precision: Union[int, str], as_torch: Optional[bool] = False, cpu_only_warning: bool = True
):
    """
    Pick a usable precision given the environment and the configured precision.

    Parameters
    ----------
    num_gpus
        GPU number.
    precision
        The precision provided in config.
    as_torch
        Whether to convert the precision to the Pytorch format.
    cpu_only_warning
        Whether to turn on warning if the instance has only CPU.

    Returns
    -------
    The inferred precision. It falls back to 32 when no GPU is used, or when a
    "bf16" precision is requested but `torch.cuda.is_bf16_supported()` is False.
    """
    fallback_precision = 32

    if num_gpus == 0:  # CPU only prediction
        if cpu_only_warning:
            warnings.warn(
                "Only CPU is detected in the instance. "
                "This may result in slow speed for MultiModalPredictor. "
                "Consider using an instance with GPU support.",
                UserWarning,
            )
        # Force to use fp32 for training since 16-mixed is not available in CPU
        precision = fallback_precision
    elif isinstance(precision, str) and "bf16" in precision and not torch.cuda.is_bf16_supported():
        warnings.warn(
            f"{precision} is not supported by the GPU device / cuda version. "
            "Consider using GPU devices with versions after Amphere or upgrading cuda to be >=11.0. "
            f"MultiModalPredictor is switching precision from {precision} to 32.",
            UserWarning,
        )
        precision = fallback_precision

    if not as_torch:
        return precision

    return convert_to_torch_precision(precision=precision)
|
152
|
-
|
153
|
-
|
154
|
-
def move_to_device(obj: Union[torch.Tensor, nn.Module, Dict, List, Tuple], device: torch.device):
    """
    Recursively move an object onto the given device.

    Parameters
    ----------
    obj
        A tensor, a module, a dict, a list, or a tuple. Nested int/float/str
        values are passed through untouched.
    device
        A Pytorch device instance.

    Returns
    -------
    The object on the device. Dicts are rebuilt as new dicts; lists and tuples
    are both rebuilt as lists.

    Raises
    ------
    ValueError
        If `device` is not a `torch.device`.
    TypeError
        If `obj` (or a nested value) has an unsupported type.
    """
    if not isinstance(device, torch.device):
        raise ValueError(f"Invalid device: {device}. Ensure the device type is `torch.device`.")

    if torch.is_tensor(obj) or isinstance(obj, nn.Module):
        return obj.to(device)

    if isinstance(obj, dict):
        return {key: move_to_device(value, device) for key, value in obj.items()}

    if isinstance(obj, (list, tuple)):
        return [move_to_device(item, device) for item in obj]

    if isinstance(obj, (int, float, str)):
        return obj

    raise TypeError(
        f"Invalid type {type(obj)} for move_to_device. "
        f"Make sure the object is one of these: a Pytorch tensor, a Pytorch module, "
        f"a dict or list of tensors or modules."
    )
|
192
|
-
|
193
|
-
|
194
|
-
def compute_inference_batch_size(
    per_gpu_batch_size: int,
    eval_batch_size_ratio: Union[int, float],
    per_gpu_batch_size_evaluation: int,
    num_gpus: int,
    strategy: str,
):
    """
    Derive the batch size to use at inference time.

    Parameters
    ----------
    per_gpu_batch_size
        Per gpu batch size from the config.
    eval_batch_size_ratio
        Multiplier applied to per_gpu_batch_size when no explicit evaluation
        batch size is given.
    per_gpu_batch_size_evaluation
        Per gpu evaluation batch size from the config. Used as-is when truthy.
    num_gpus
        Number of GPUs.
    strategy
        A pytorch lightning strategy.

    Returns
    -------
    Batch size for inference.
    """
    batch_size = per_gpu_batch_size_evaluation or per_gpu_batch_size * eval_batch_size_ratio

    # If using 'dp', the per_gpu_batch_size would be split by all GPUs.
    # So, we need to use the GPU number as a multiplier to compute the batch size.
    uses_data_parallel = num_gpus > 1 and strategy == "dp"
    if uses_data_parallel:
        batch_size = batch_size * num_gpus

    return batch_size
|
232
|
-
|
233
|
-
|
234
|
-
@contextlib.contextmanager
def double_precision_context():
    """
    Double precision context manager.

    Switches the torch default dtype to `torch.float64` for the duration of the
    `with` block and restores the previous default on exit, including when the
    block raises.
    """
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        yield
    finally:
        # Restore even on error; otherwise float64 would leak as the
        # process-wide default dtype after a failure inside the block.
        torch.set_default_dtype(default_dtype)
|
243
|
-
|
244
|
-
|
245
|
-
def get_precision_context(precision: Union[int, str], device_type: Optional[str] = None):
    """
    Select the context manager that matches the requested precision.

    Parameters
    ----------
    precision
        The precision, as an integer or string accepted by
        `convert_to_torch_precision`.
    device_type
        Forwarded to `torch.autocast` for the half / bfloat16 precisions.

    Returns
    -------
    A precision context manager: `torch.autocast` for half / bfloat16,
    a null context for float32, and `double_precision_context()` for float64.
    """
    torch_dtype = convert_to_torch_precision(precision=precision)

    if torch_dtype in (torch.half, torch.float16, torch.bfloat16):
        return torch.autocast(device_type=device_type, dtype=torch_dtype)

    if torch_dtype == torch.float32:
        # float32 needs no special context, but the default dtype must already match.
        assert torch.get_default_dtype() == torch.float32
        return contextlib.nullcontext()

    if torch_dtype == torch.float64:
        return double_precision_context()

    raise ValueError(f"Unknown precision: {torch_dtype}")
|
271
|
-
|
272
|
-
|
273
|
-
def check_if_packages_installed(problem_type: str = None, package_names: List[str] = None):
    """
    Check if necessary packages are installed for some problem types.
    Raise an error if a required package can't be imported.

    Parameters
    ----------
    problem_type
        Problem type. If it contains the object detection or OCR keyword, mmcv and
        mmdet must be importable; if it contains the OCR keyword, mmocr as well.
    package_names
        Names of packages to probe. Only "mmcv", "mmdet" and "mmengine" are accepted.

    Raises
    ------
    ValueError
        If a package required by `problem_type` cannot be imported, if "mmengine"
        is requested and cannot be imported, or if `package_names` contains a
        name other than the three accepted ones.
    """
    if problem_type:
        problem_type = problem_type.lower()
        if any(p in problem_type for p in [OBJECT_DETECTION, OCR]):
            try:
                # mmcv may emit warnings at import time; keep them out of the user's output.
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    import mmcv
            except ImportError as e:
                raise ValueError(
                    f"Encountered error while importing mmcv: {e}. {_get_mmlab_installation_guide('mmcv')}"
                ) from e

            try:
                import mmdet
            except ImportError as e:
                raise ValueError(
                    f"Encountered error while importing mmdet: {e}. {_get_mmlab_installation_guide('mmdet')}"
                ) from e

            if OCR in problem_type:
                try:
                    import mmocr
                except ImportError as e:
                    raise ValueError(
                        f'Encountered error while importing mmocr: {e}. Try to install mmocr: pip install "mmocr<1.0".'
                    ) from e

    if package_names:
        for package_name in package_names:
            if package_name == "mmcv":
                try:
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        import mmcv
                    from mmcv import ConfigDict
                    from mmcv.runner import load_checkpoint
                    from mmcv.transforms import Compose
                except ImportError as e:
                    # Previously the message was built and discarded, so the failure was
                    # swallowed without a trace. Keep the probe best-effort but record it.
                    # NOTE(review): not promoted to an error because some of the probed names
                    # (e.g. `mmcv.runner`) may be absent from the mmcv version recommended by
                    # the installation guide; confirm before raising here.
                    logger.debug(
                        f"Encountered error while importing {package_name}: {e}. {_get_mmlab_installation_guide(package_name)}"
                    )
            elif package_name == "mmdet":
                try:
                    import mmdet
                    from mmdet.datasets.transforms import ImageToTensor
                    from mmdet.registry import MODELS
                except ImportError as e:
                    # Same best-effort handling as for mmcv above: record, do not raise.
                    logger.debug(
                        f"Encountered error while importing {package_name}: {e}. {_get_mmlab_installation_guide(package_name)}"
                    )
            elif package_name == "mmengine":
                try:
                    import mmengine
                    from mmengine.dataset import pseudo_collate as collate
                    from mmengine.runner import load_checkpoint
                except ImportError as e:
                    # warnings.warn expects a message string (or a Warning instance).
                    warnings.warn(str(e))
                    raise ValueError(
                        f"Encountered error while importing {package_name}: {e}. {_get_mmlab_installation_guide(package_name)}"
                    ) from e
            else:
                raise ValueError(f"package_name {package_name} is not required.")
|
340
|
-
|
341
|
-
|
342
|
-
def get_available_devices(num_gpus: int, auto_select_gpus: bool):
    """
    Decide which devices to use.

    Parameters
    ----------
    num_gpus
        Number of GPUs.
    auto_select_gpus
        Whether to pick GPU indices that are "accessible". See here: https://github.com/Lightning-AI/lightning/blob/accd2b9e61063ba3c683764043030545ed87c71f/src/lightning/fabric/accelerators/cuda.py#L79

    Returns
    -------
    The available devices: "auto" when num_gpus is not positive; num_gpus itself
    when auto selection is off; otherwise a list of the first num_gpus indices in
    an interactive environment, or the result of `find_usable_cuda_devices`.
    """
    if num_gpus > 0:
        if not auto_select_gpus:
            return num_gpus
        if is_interactive_env():
            return list(range(num_gpus))
        return find_usable_cuda_devices(num_gpus)

    return "auto"
|
369
|
-
|
370
|
-
|
371
|
-
def _get_mmlab_installation_guide(package_name):
|
372
|
-
if package_name == "mmdet":
|
373
|
-
err_msg = 'Please install MMDetection by: pip install "mmdet==3.2.0"'
|
374
|
-
elif package_name == "mmcv":
|
375
|
-
err_msg = 'Please install MMCV by: mim install "mmcv==2.1.0"'
|
376
|
-
elif package_name == "mmengine":
|
377
|
-
err_msg = "Please install MMEngine by: mim install mmengine"
|
378
|
-
else:
|
379
|
-
raise ValueError("Available package_name are: mmdet, mmcv, mmengine.")
|
380
|
-
|
381
|
-
err_msg += " Pytorch version larger than 2.1 is not supported yet. To use Autogluon for object detection, please downgrade PyTorch version to <=2.1."
|
382
|
-
|
383
|
-
return err_msg
|
384
|
-
|
385
|
-
|
386
|
-
def run_ddp_only_once(num_gpus: int, strategy: str):
    """
    Allow a multi-GPU DDP strategy to run only once per process.

    Parameters
    ----------
    num_gpus
        Number of GPUs requested.
    strategy
        The requested strategy name.

    Returns
    -------
    A `(num_gpus, strategy)` tuple. For strategies outside DDP_STRATEGIES the
    inputs are returned unchanged. For DDP strategies, the first call with more
    than one GPU records that DDP has been used and returns the inputs unchanged;
    later calls return at most one GPU and the "auto" strategy.
    """
    global FIRST_DDP_RUN  # Use the global variable to make sure it is tracked per process

    if strategy not in DDP_STRATEGIES:
        return num_gpus, strategy

    # The flag only exists once a multi-GPU DDP run has happened; a missing flag
    # is treated the same as "this is still the first run".
    ddp_already_used = not globals().get("FIRST_DDP_RUN", True)
    if ddp_already_used:
        # not the first time running DDP, set number of devices to 1 (use single GPU)
        return min(1, num_gpus), "auto"

    if num_gpus > 1:
        FIRST_DDP_RUN = False  # run DDP for the first time, disable the following runs

    return num_gpus, strategy
|