ultralytics 8.3.88__py3-none-any.whl → 8.3.90__py3-none-any.whl
This diff reflects the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- tests/conftest.py +2 -2
- tests/test_cli.py +13 -11
- tests/test_cuda.py +10 -1
- tests/test_integrations.py +1 -5
- tests/test_python.py +16 -16
- tests/test_solutions.py +9 -9
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +3 -1
- ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
- ultralytics/cfg/models/11/yolo11.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8.yaml +5 -5
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
- ultralytics/data/annotator.py +9 -14
- ultralytics/data/base.py +125 -39
- ultralytics/data/build.py +63 -24
- ultralytics/data/converter.py +34 -33
- ultralytics/data/dataset.py +207 -53
- ultralytics/data/loaders.py +1 -0
- ultralytics/data/split_dota.py +39 -12
- ultralytics/data/utils.py +33 -47
- ultralytics/engine/exporter.py +19 -17
- ultralytics/engine/model.py +69 -90
- ultralytics/engine/predictor.py +106 -21
- ultralytics/engine/trainer.py +32 -23
- ultralytics/engine/tuner.py +31 -38
- ultralytics/engine/validator.py +75 -41
- ultralytics/hub/__init__.py +21 -26
- ultralytics/hub/auth.py +9 -12
- ultralytics/hub/session.py +76 -21
- ultralytics/hub/utils.py +19 -17
- ultralytics/models/fastsam/model.py +23 -17
- ultralytics/models/fastsam/predict.py +36 -16
- ultralytics/models/fastsam/utils.py +5 -5
- ultralytics/models/fastsam/val.py +6 -6
- ultralytics/models/nas/model.py +29 -24
- ultralytics/models/nas/predict.py +14 -11
- ultralytics/models/nas/val.py +11 -13
- ultralytics/models/rtdetr/model.py +20 -11
- ultralytics/models/rtdetr/predict.py +21 -21
- ultralytics/models/rtdetr/train.py +25 -24
- ultralytics/models/rtdetr/val.py +47 -14
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +50 -4
- ultralytics/models/sam/model.py +8 -14
- ultralytics/models/sam/modules/decoders.py +18 -21
- ultralytics/models/sam/modules/encoders.py +25 -46
- ultralytics/models/sam/modules/memory_attention.py +19 -15
- ultralytics/models/sam/modules/sam.py +18 -25
- ultralytics/models/sam/modules/tiny_encoder.py +19 -29
- ultralytics/models/sam/modules/transformer.py +35 -57
- ultralytics/models/sam/modules/utils.py +15 -15
- ultralytics/models/sam/predict.py +0 -3
- ultralytics/models/utils/loss.py +87 -36
- ultralytics/models/utils/ops.py +26 -31
- ultralytics/models/yolo/classify/predict.py +30 -12
- ultralytics/models/yolo/classify/train.py +83 -19
- ultralytics/models/yolo/classify/val.py +45 -23
- ultralytics/models/yolo/detect/predict.py +29 -19
- ultralytics/models/yolo/detect/train.py +90 -23
- ultralytics/models/yolo/detect/val.py +150 -29
- ultralytics/models/yolo/model.py +1 -2
- ultralytics/models/yolo/obb/predict.py +18 -13
- ultralytics/models/yolo/obb/train.py +12 -8
- ultralytics/models/yolo/obb/val.py +35 -22
- ultralytics/models/yolo/pose/predict.py +28 -15
- ultralytics/models/yolo/pose/train.py +21 -8
- ultralytics/models/yolo/pose/val.py +51 -31
- ultralytics/models/yolo/segment/predict.py +27 -16
- ultralytics/models/yolo/segment/train.py +11 -8
- ultralytics/models/yolo/segment/val.py +110 -29
- ultralytics/models/yolo/world/train.py +43 -16
- ultralytics/models/yolo/world/train_world.py +61 -36
- ultralytics/nn/autobackend.py +28 -14
- ultralytics/nn/modules/__init__.py +12 -12
- ultralytics/nn/modules/activation.py +12 -3
- ultralytics/nn/modules/block.py +587 -84
- ultralytics/nn/modules/conv.py +418 -54
- ultralytics/nn/modules/head.py +3 -4
- ultralytics/nn/modules/transformer.py +320 -34
- ultralytics/nn/modules/utils.py +17 -3
- ultralytics/nn/tasks.py +226 -79
- ultralytics/solutions/ai_gym.py +2 -2
- ultralytics/solutions/analytics.py +4 -4
- ultralytics/solutions/heatmap.py +4 -4
- ultralytics/solutions/instance_segmentation.py +10 -4
- ultralytics/solutions/object_blurrer.py +2 -2
- ultralytics/solutions/object_counter.py +2 -2
- ultralytics/solutions/object_cropper.py +2 -2
- ultralytics/solutions/parking_management.py +9 -9
- ultralytics/solutions/queue_management.py +1 -1
- ultralytics/solutions/region_counter.py +2 -2
- ultralytics/solutions/security_alarm.py +7 -7
- ultralytics/solutions/solutions.py +7 -4
- ultralytics/solutions/speed_estimation.py +2 -2
- ultralytics/solutions/streamlit_inference.py +6 -6
- ultralytics/solutions/trackzone.py +9 -2
- ultralytics/solutions/vision_eye.py +4 -4
- ultralytics/trackers/basetrack.py +1 -1
- ultralytics/trackers/bot_sort.py +23 -22
- ultralytics/trackers/byte_tracker.py +4 -4
- ultralytics/trackers/track.py +2 -1
- ultralytics/trackers/utils/gmc.py +26 -27
- ultralytics/trackers/utils/kalman_filter.py +31 -29
- ultralytics/trackers/utils/matching.py +7 -7
- ultralytics/utils/__init__.py +37 -35
- ultralytics/utils/autobatch.py +5 -5
- ultralytics/utils/benchmarks.py +111 -18
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +11 -11
- ultralytics/utils/callbacks/comet.py +35 -22
- ultralytics/utils/callbacks/dvc.py +11 -10
- ultralytics/utils/callbacks/hub.py +8 -8
- ultralytics/utils/callbacks/mlflow.py +1 -1
- ultralytics/utils/callbacks/neptune.py +12 -10
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +6 -6
- ultralytics/utils/callbacks/wb.py +16 -16
- ultralytics/utils/checks.py +139 -68
- ultralytics/utils/dist.py +15 -2
- ultralytics/utils/downloads.py +37 -56
- ultralytics/utils/files.py +12 -13
- ultralytics/utils/instance.py +117 -52
- ultralytics/utils/loss.py +28 -33
- ultralytics/utils/metrics.py +246 -181
- ultralytics/utils/ops.py +65 -61
- ultralytics/utils/patches.py +8 -6
- ultralytics/utils/plotting.py +72 -59
- ultralytics/utils/tal.py +88 -57
- ultralytics/utils/torch_utils.py +202 -64
- ultralytics/utils/triton.py +13 -3
- ultralytics/utils/tuner.py +13 -25
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/METADATA +2 -2
- ultralytics-8.3.90.dist-info/RECORD +250 -0
- ultralytics-8.3.88.dist-info/RECORD +0 -250
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
ultralytics/engine/predictor.py
CHANGED
@@ -65,28 +65,54 @@ Example:
 
 class BasePredictor:
     """
-    BasePredictor.
-
     A base class for creating predictors.
 
+    This class provides the foundation for prediction functionality, handling model setup, inference,
+    and result processing across various input sources.
+
     Attributes:
         args (SimpleNamespace): Configuration for the predictor.
         save_dir (Path): Directory to save results.
         done_warmup (bool): Whether the predictor has finished setup.
-        model (nn.Module): Model used for prediction.
+        model (torch.nn.Module): Model used for prediction.
         data (dict): Data configuration.
         device (torch.device): Device used for prediction.
         dataset (Dataset): Dataset used for prediction.
-        vid_writer (dict): Dictionary of {save_path: video_writer
+        vid_writer (dict): Dictionary of {save_path: video_writer} for saving video output.
+        plotted_img (numpy.ndarray): Last plotted image.
+        source_type (SimpleNamespace): Type of input source.
+        seen (int): Number of images processed.
+        windows (List): List of window names for visualization.
+        batch (tuple): Current batch data.
+        results (List): Current batch results.
+        transforms (callable): Image transforms for classification.
+        callbacks (dict): Callback functions for different events.
+        txt_path (Path): Path to save text results.
+        _lock (threading.Lock): Lock for thread-safe inference.
+
+    Methods:
+        preprocess: Prepare input image before inference.
+        inference: Run inference on a given image.
+        postprocess: Process raw predictions into structured results.
+        predict_cli: Run prediction for command line interface.
+        setup_source: Set up input source and inference mode.
+        stream_inference: Stream inference on input source.
+        setup_model: Initialize and configure the model.
+        write_results: Write inference results to files.
+        save_predicted_images: Save prediction visualizations.
+        show: Display results in a window.
+        run_callbacks: Execute registered callbacks for an event.
+        add_callback: Register a new callback function.
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
         """
-
+        Initialize the BasePredictor class.
 
         Args:
-            cfg (str
-            overrides (dict
+            cfg (str | dict): Path to a configuration file or a configuration dictionary.
+            overrides (dict | None): Configuration overrides.
+            _callbacks (dict | None): Dictionary of callback functions.
         """
         self.args = get_cfg(cfg, overrides)
         self.save_dir = get_save_dir(self.args)
@@ -120,7 +146,7 @@ class BasePredictor:
         Prepares input image before inference.
 
         Args:
-            im (torch.Tensor | List(np.ndarray)):
+            im (torch.Tensor | List(np.ndarray)): Images of shape (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
         """
         not_tensor = not isinstance(im, torch.Tensor)
         if not_tensor:
@@ -136,7 +162,7 @@ class BasePredictor:
         return im
 
     def inference(self, im, *args, **kwargs):
-        """
+        """Run inference on a given image using the specified model and arguments."""
         visualize = (
             increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
             if self.args.visualize and (not self.source_type.tensor)
@@ -149,10 +175,10 @@ class BasePredictor:
         Pre-transform input image before inference.
 
         Args:
-            im (List
+            im (List[np.ndarray]): Images of shape (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
 
         Returns:
-            (
+            (List[np.ndarray]): A list of transformed images.
         """
         same_shapes = len({x.shape for x in im}) == 1
         letterbox = LetterBox(
@@ -163,11 +189,24 @@ class BasePredictor:
         return [letterbox(image=x) for x in im]
 
     def postprocess(self, preds, img, orig_imgs):
-        """Post-
+        """Post-process predictions for an image and return them."""
         return preds
 
     def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
-        """
+        """
+        Perform inference on an image or stream.
+
+        Args:
+            source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor | None):
+                Source for inference.
+            model (str | Path | torch.nn.Module | None): Model for inference.
+            stream (bool): Whether to stream the inference results. If True, returns a generator.
+            *args (Any): Additional arguments for the inference method.
+            **kwargs (Any): Additional keyword arguments for the inference method.
+
+        Returns:
+            (List[ultralytics.engine.results.Results] | generator): Results objects or generator of Results objects.
+        """
         self.stream = stream
         if stream:
             return self.stream_inference(source, model, *args, **kwargs)
@@ -182,6 +221,11 @@ class BasePredictor:
         the inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the
         generator without storing results.
 
+        Args:
+            source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor | None):
+                Source for inference.
+            model (str | Path | torch.nn.Module | None): Model for inference.
+
         Note:
             Do not modify this function or remove the generator. The generator ensures that no outputs are
             accumulated in memory, which is critical for preventing memory issues during long-running predictions.
@@ -191,7 +235,13 @@ class BasePredictor:
             pass
 
     def setup_source(self, source):
-        """
+        """
+        Set up source and inference mode.
+
+        Args:
+            source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor):
+                Source for inference.
+        """
         self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size
         self.transforms = (
             getattr(
@@ -220,7 +270,19 @@ class BasePredictor:
 
     @smart_inference_mode()
     def stream_inference(self, source=None, model=None, *args, **kwargs):
-        """
+        """
+        Stream real-time inference on camera feed and save results to file.
+
+        Args:
+            source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor | None):
+                Source for inference.
+            model (str | Path | torch.nn.Module | None): Model for inference.
+            *args (Any): Additional arguments for the inference method.
+            **kwargs (Any): Additional keyword arguments for the inference method.
+
+        Yields:
+            (ultralytics.engine.results.Results): Results objects.
+        """
         if self.args.verbose:
             LOGGER.info("")
 
@@ -306,7 +368,13 @@ class BasePredictor:
         self.run_callbacks("on_predict_end")
 
     def setup_model(self, model, verbose=True):
-        """
+        """
+        Initialize YOLO model with given parameters and set it to evaluation mode.
+
+        Args:
+            model (str | Path | torch.nn.Module | None): Model to load or use.
+            verbose (bool): Whether to print verbose output.
+        """
         self.model = AutoBackend(
             weights=model or self.args.model,
             device=select_device(self.args.device, verbose=verbose),
@@ -323,7 +391,18 @@ class BasePredictor:
         self.model.eval()
 
     def write_results(self, i, p, im, s):
-        """
+        """
+        Write inference results to a file or directory.
+
+        Args:
+            i (int): Index of the current image in the batch.
+            p (Path): Path to the current image.
+            im (torch.Tensor): Preprocessed image tensor.
+            s (List[str]): List of result strings.
+
+        Returns:
+            (str): String with result information.
+        """
         string = ""  # print string
         if len(im.shape) == 3:
             im = im[None]  # expand for batch dim
@@ -363,7 +442,13 @@ class BasePredictor:
         return string
 
     def save_predicted_images(self, save_path="", frame=0):
-        """
+        """
+        Save video predictions as mp4 or images as jpg at specified path.
+
+        Args:
+            save_path (str): Path to save the results.
+            frame (int): Frame number for video mode.
+        """
         im = self.plotted_img
 
         # Save videos and streams
@@ -391,7 +476,7 @@ class BasePredictor:
         cv2.imwrite(str(Path(save_path).with_suffix(".jpg")), im)  # save to JPG for best support
 
     def show(self, p=""):
-        """Display an image in a window
+        """Display an image in a window."""
         im = self.plotted_img
         if platform.system() == "Linux" and p not in self.windows:
             self.windows.append(p)
@@ -401,10 +486,10 @@ class BasePredictor:
         cv2.waitKey(300 if self.dataset.mode == "image" else 1)  # 1 millisecond
 
     def run_callbacks(self, event: str):
-        """
+        """Run all registered callbacks for a specific event."""
         for callback in self.callbacks.get(event, []):
             callback(self)
 
     def add_callback(self, event: str, func):
-        """Add callback."""
+        """Add a callback function for a specific event."""
         self.callbacks[event].append(func)
ultralytics/engine/trainer.py
CHANGED
@@ -87,17 +87,20 @@ class BaseTrainer:
         fitness (float): Current fitness value.
         loss (float): Current loss value.
         tloss (float): Total loss value.
-        loss_names (
+        loss_names (List): List of loss names.
         csv (Path): Path to results CSV file.
+        metrics (Dict): Dictionary of metrics.
+        plots (Dict): Dictionary of plots.
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
         """
-
+        Initialize the BaseTrainer class.
 
         Args:
             cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
-            overrides (
+            overrides (Dict, optional): Configuration overrides. Defaults to None.
+            _callbacks (List, optional): List of callback functions. Defaults to None.
         """
         self.args = get_cfg(cfg, overrides)
         self.check_resume(overrides)
@@ -156,11 +159,11 @@ class BaseTrainer:
         callbacks.add_integration_callbacks(self)
 
     def add_callback(self, event: str, callback):
-        """
+        """Append the given callback to the event's callback list."""
         self.callbacks[event].append(callback)
 
     def set_callback(self, event: str, callback):
-        """
+        """Override the existing callbacks with the given callback for the specified event."""
         self.callbacks[event] = [callback]
 
     def run_callbacks(self, event: str):
@@ -216,7 +219,7 @@ class BaseTrainer:
         self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
 
     def _setup_ddp(self, world_size):
-        """
+        """Initialize and set the DistributedDataParallel parameters for training."""
         torch.cuda.set_device(RANK)
         self.device = torch.device("cuda", RANK)
         # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
@@ -229,7 +232,7 @@ class BaseTrainer:
         )
 
     def _setup_train(self, world_size):
-        """
+        """Build dataloaders and optimizer on correct rank process."""
         # Model
         self.run_callbacks("on_pretrain_routine_start")
         ckpt = self.setup_model()
@@ -317,7 +320,7 @@ class BaseTrainer:
         self.run_callbacks("on_pretrain_routine_end")
 
     def _do_train(self, world_size=1):
-        """Train
+        """Train the model with the specified world size."""
         if world_size > 1:
             self._setup_ddp(world_size)
         self._setup_train(world_size)
@@ -477,7 +480,7 @@ class BaseTrainer:
         self.run_callbacks("teardown")
 
     def auto_batch(self, max_num_obj=0):
-        """
+        """Calculate optimal batch size based on model and device memory constraints."""
         return check_train_batch_size(
             model=self.model,
             imgsz=self.args.imgsz,
@@ -487,12 +490,12 @@ class BaseTrainer:
         )  # returns batch size
 
     def _get_memory(self, fraction=False):
-        """Get accelerator memory utilization in GB or fraction."""
+        """Get accelerator memory utilization in GB or as a fraction of total memory."""
         memory, total = 0, 0
         if self.device.type == "mps":
             memory = torch.mps.driver_allocated_memory()
             if fraction:
-
+                return __import__("psutil").virtual_memory().percent / 100
         elif self.device.type == "cpu":
             pass
         else:
@@ -502,7 +505,7 @@ class BaseTrainer:
         return ((memory / total) if total > 0 else 0) if fraction else (memory / 2**30)
 
     def _clear_memory(self):
-        """Clear accelerator memory
+        """Clear accelerator memory by calling garbage collector and emptying cache."""
         gc.collect()
         if self.device.type == "mps":
             torch.mps.empty_cache()
@@ -512,7 +515,7 @@ class BaseTrainer:
             torch.cuda.empty_cache()
 
     def read_results_csv(self):
-        """Read results.csv into a
+        """Read results.csv into a dictionary using pandas."""
         import pandas as pd  # scope for faster 'import ultralytics'
 
         return pd.read_csv(self.csv).to_dict(orient="list")
@@ -554,9 +557,10 @@ class BaseTrainer:
 
     def get_dataset(self):
         """
-        Get train
+        Get train and validation datasets from data dictionary.
 
-        Returns
+        Returns:
+            (tuple): A tuple containing the training and validation/test datasets.
         """
         try:
             if self.args.task == "classify":
@@ -580,7 +584,12 @@ class BaseTrainer:
         return data["train"], data.get("val") or data.get("test")
 
     def setup_model(self):
-        """
+        """
+        Load, create, or download model for any task.
+
+        Returns:
+            (dict): Optional checkpoint to resume training from.
+        """
         if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
             return
 
@@ -610,9 +619,10 @@ class BaseTrainer:
 
     def validate(self):
         """
-
+        Run validation on test set using self.validator.
 
-
+        Returns:
+            (tuple): A tuple containing metrics dictionary and fitness score.
         """
         metrics = self.validator(self)
         fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
@@ -646,7 +656,7 @@ class BaseTrainer:
         return {"loss": loss_items} if loss_items is not None else ["loss"]
 
     def set_model_attributes(self):
-        """
+        """Set or update model parameters before training."""
         self.model.names = self.data["names"]
 
     def build_targets(self, preds, targets):
@@ -667,7 +677,7 @@ class BaseTrainer:
         pass
 
     def save_metrics(self, metrics):
-        """
+        """Save training metrics to a CSV file."""
         keys, vals = list(metrics.keys()), list(metrics.values())
         n = len(metrics) + 2  # number of cols
         s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n")  # header
@@ -685,7 +695,7 @@ class BaseTrainer:
         self.plots[path] = {"data": data, "timestamp": time.time()}
 
     def final_eval(self):
-        """
+        """Perform final evaluation and validation for object detection YOLO model."""
         ckpt = {}
         for f in self.last, self.best:
             if f.exists():
@@ -769,8 +779,7 @@ class BaseTrainer:
 
     def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
         """
-
-        weight decay, and number of iterations.
+        Construct an optimizer for the given model.
 
         Args:
             model (torch.nn.Module): The model for which to build an optimizer.
ultralytics/engine/tuner.py
CHANGED
@@ -7,14 +7,11 @@ Hyperparameter tuning is the process of systematically searching for the optimal
 that yield the best model performance. This is particularly crucial in deep learning models like YOLO,
 where small changes in hyperparameters can lead to significant differences in model accuracy and efficiency.
 
-
+Examples:
     Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
-
-
-
-    model = YOLO("yolo11n.pt")
-    model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False)
-    ```
+    >>> from ultralytics import YOLO
+    >>> model = YOLO("yolo11n.pt")
+    >>> model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False)
 """
 
 import random
@@ -32,39 +29,33 @@ from ultralytics.utils.plotting import plot_tune_results
 
 class Tuner:
     """
-
+    A class for hyperparameter tuning of YOLO models.
 
-    The class evolves YOLO model hyperparameters over a given number of iterations
-
+    The class evolves YOLO model hyperparameters over a given number of iterations by mutating them according to the
+    search space and retraining the model to evaluate their performance.
 
     Attributes:
-        space (
+        space (Dict): Hyperparameter search space containing bounds and scaling factors for mutation.
         tune_dir (Path): Directory where evolution logs and results will be saved.
         tune_csv (Path): Path to the CSV file where evolution logs are saved.
+        args (Dict): Configuration arguments for the tuning process.
+        callbacks (List): Callback functions to be executed during tuning.
+        prefix (str): Prefix string for logging messages.
 
     Methods:
-        _mutate
-
-
-        __call__():
-            Executes the hyperparameter evolution across multiple iterations.
+        _mutate: Mutates the given hyperparameters within the specified bounds.
+        __call__: Executes the hyperparameter evolution across multiple iterations.
 
-
+    Examples:
         Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
-
-
-
-
-
-        ```
+        >>> from ultralytics import YOLO
+        >>> model = YOLO("yolo11n.pt")
+        >>> model.tune(
+        ...     data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False
+        ... )
 
         Tune with custom search space.
-
-        from ultralytics import YOLO
-
-        model = YOLO("yolo11n.pt")
-        model.tune(space={key1: val1, key2: val2})  # custom search space dictionary
-        ```
+        >>> model.tune(space={key1: val1, key2: val2})  # custom search space dictionary
     """
 
     def __init__(self, args=DEFAULT_CFG, _callbacks=None):
@@ -72,7 +63,8 @@ class Tuner:
         Initialize the Tuner with configurations.
 
         Args:
-            args (
+            args (Dict): Configuration for hyperparameter evolution.
+            _callbacks (List, optional): Callback functions to be executed during tuning.
         """
         self.space = args.pop("space", None) or {  # key: (min, max, gain(optional))
             # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
@@ -114,7 +106,7 @@ class Tuner:
 
     def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2):
         """
-
+        Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
 
         Args:
             parent (str): Parent selection method: 'single' or 'weighted'.
@@ -123,7 +115,7 @@ class Tuner:
             sigma (float): Standard deviation for Gaussian random number generator.
 
         Returns:
-            (
+            (Dict): A dictionary containing mutated hyperparameters.
         """
         if self.tune_csv.exists():  # if CSV file exists: select best hyps and mutate
             # Select parent(s)
@@ -160,22 +152,23 @@ class Tuner:
 
     def __call__(self, model=None, iterations=10, cleanup=True):
         """
-
+        Execute the hyperparameter evolution process when the Tuner instance is called.
 
         This method iterates through the number of iterations, performing the following steps in each iteration:
+
         1. Load the existing hyperparameters or initialize new ones.
         2. Mutate the hyperparameters using the `mutate` method.
         3. Train a YOLO model with the mutated hyperparameters.
         4. Log the fitness score and mutated hyperparameters to a CSV file.
 
         Args:
-
-
-
+            model (Model): A pre-initialized YOLO model to be used for training.
+            iterations (int): The number of generations to run the evolution for.
+            cleanup (bool): Whether to delete iteration weights to reduce storage space used during tuning.
 
         Note:
-
-
+            The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.
+            Ensure this path is set correctly in the Tuner instance.
         """
         t0 = time.time()
         best_save_dir, best_metrics = None, None