ultralytics 8.3.89__py3-none-any.whl → 8.3.90__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/conftest.py +2 -2
- tests/test_cli.py +13 -11
- tests/test_cuda.py +10 -1
- tests/test_integrations.py +1 -5
- tests/test_python.py +16 -16
- tests/test_solutions.py +9 -9
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +3 -1
- ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
- ultralytics/cfg/models/11/yolo11.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8.yaml +5 -5
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
- ultralytics/data/annotator.py +9 -14
- ultralytics/data/base.py +118 -30
- ultralytics/data/build.py +63 -24
- ultralytics/data/converter.py +5 -5
- ultralytics/data/dataset.py +207 -53
- ultralytics/data/loaders.py +1 -0
- ultralytics/data/split_dota.py +39 -12
- ultralytics/data/utils.py +13 -19
- ultralytics/engine/exporter.py +19 -17
- ultralytics/engine/model.py +67 -88
- ultralytics/engine/predictor.py +106 -21
- ultralytics/engine/trainer.py +32 -23
- ultralytics/engine/tuner.py +21 -18
- ultralytics/engine/validator.py +75 -41
- ultralytics/hub/__init__.py +12 -13
- ultralytics/hub/auth.py +9 -12
- ultralytics/hub/session.py +76 -21
- ultralytics/hub/utils.py +19 -17
- ultralytics/models/fastsam/model.py +20 -11
- ultralytics/models/fastsam/predict.py +36 -16
- ultralytics/models/fastsam/utils.py +5 -5
- ultralytics/models/fastsam/val.py +6 -6
- ultralytics/models/nas/model.py +22 -11
- ultralytics/models/nas/predict.py +9 -4
- ultralytics/models/nas/val.py +5 -5
- ultralytics/models/rtdetr/model.py +20 -11
- ultralytics/models/rtdetr/predict.py +18 -15
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +42 -6
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +50 -4
- ultralytics/models/sam/model.py +8 -14
- ultralytics/models/sam/modules/decoders.py +18 -21
- ultralytics/models/sam/modules/encoders.py +25 -46
- ultralytics/models/sam/modules/memory_attention.py +19 -15
- ultralytics/models/sam/modules/sam.py +18 -25
- ultralytics/models/sam/modules/tiny_encoder.py +19 -29
- ultralytics/models/sam/modules/transformer.py +35 -57
- ultralytics/models/sam/modules/utils.py +15 -15
- ultralytics/models/sam/predict.py +0 -3
- ultralytics/models/utils/loss.py +87 -36
- ultralytics/models/utils/ops.py +26 -31
- ultralytics/models/yolo/classify/predict.py +24 -3
- ultralytics/models/yolo/classify/train.py +77 -10
- ultralytics/models/yolo/classify/val.py +40 -15
- ultralytics/models/yolo/detect/predict.py +23 -10
- ultralytics/models/yolo/detect/train.py +85 -15
- ultralytics/models/yolo/detect/val.py +145 -21
- ultralytics/models/yolo/model.py +1 -2
- ultralytics/models/yolo/obb/predict.py +12 -4
- ultralytics/models/yolo/obb/train.py +7 -0
- ultralytics/models/yolo/obb/val.py +25 -7
- ultralytics/models/yolo/pose/predict.py +22 -6
- ultralytics/models/yolo/pose/train.py +17 -1
- ultralytics/models/yolo/pose/val.py +46 -21
- ultralytics/models/yolo/segment/predict.py +22 -8
- ultralytics/models/yolo/segment/train.py +6 -0
- ultralytics/models/yolo/segment/val.py +100 -14
- ultralytics/models/yolo/world/train.py +38 -8
- ultralytics/models/yolo/world/train_world.py +39 -10
- ultralytics/nn/autobackend.py +28 -14
- ultralytics/nn/modules/__init__.py +3 -0
- ultralytics/nn/modules/activation.py +12 -3
- ultralytics/nn/modules/block.py +587 -84
- ultralytics/nn/modules/conv.py +418 -54
- ultralytics/nn/modules/head.py +3 -4
- ultralytics/nn/modules/transformer.py +320 -34
- ultralytics/nn/modules/utils.py +17 -3
- ultralytics/nn/tasks.py +221 -69
- ultralytics/solutions/ai_gym.py +2 -2
- ultralytics/solutions/analytics.py +4 -4
- ultralytics/solutions/heatmap.py +4 -4
- ultralytics/solutions/instance_segmentation.py +10 -4
- ultralytics/solutions/object_blurrer.py +2 -2
- ultralytics/solutions/object_counter.py +2 -2
- ultralytics/solutions/object_cropper.py +2 -2
- ultralytics/solutions/parking_management.py +9 -9
- ultralytics/solutions/queue_management.py +1 -1
- ultralytics/solutions/region_counter.py +2 -2
- ultralytics/solutions/security_alarm.py +7 -7
- ultralytics/solutions/solutions.py +7 -4
- ultralytics/solutions/speed_estimation.py +2 -2
- ultralytics/solutions/streamlit_inference.py +6 -6
- ultralytics/solutions/trackzone.py +9 -2
- ultralytics/solutions/vision_eye.py +4 -4
- ultralytics/trackers/basetrack.py +1 -1
- ultralytics/trackers/bot_sort.py +23 -22
- ultralytics/trackers/byte_tracker.py +4 -4
- ultralytics/trackers/track.py +2 -1
- ultralytics/trackers/utils/gmc.py +26 -27
- ultralytics/trackers/utils/kalman_filter.py +31 -29
- ultralytics/trackers/utils/matching.py +7 -7
- ultralytics/utils/__init__.py +32 -27
- ultralytics/utils/autobatch.py +5 -5
- ultralytics/utils/benchmarks.py +111 -18
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +11 -11
- ultralytics/utils/callbacks/comet.py +35 -22
- ultralytics/utils/callbacks/dvc.py +11 -10
- ultralytics/utils/callbacks/hub.py +8 -8
- ultralytics/utils/callbacks/mlflow.py +1 -1
- ultralytics/utils/callbacks/neptune.py +12 -10
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +6 -6
- ultralytics/utils/callbacks/wb.py +16 -16
- ultralytics/utils/checks.py +116 -35
- ultralytics/utils/dist.py +15 -2
- ultralytics/utils/downloads.py +13 -9
- ultralytics/utils/files.py +12 -13
- ultralytics/utils/instance.py +112 -45
- ultralytics/utils/loss.py +28 -33
- ultralytics/utils/metrics.py +246 -181
- ultralytics/utils/ops.py +61 -53
- ultralytics/utils/patches.py +8 -6
- ultralytics/utils/plotting.py +64 -45
- ultralytics/utils/tal.py +88 -57
- ultralytics/utils/torch_utils.py +181 -33
- ultralytics/utils/triton.py +13 -3
- ultralytics/utils/tuner.py +8 -16
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/METADATA +1 -1
- ultralytics-8.3.90.dist-info/RECORD +250 -0
- ultralytics-8.3.89.dist-info/RECORD +0 -250
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
ultralytics/engine/validator.py
CHANGED
@@ -41,43 +41,66 @@ from ultralytics.utils.torch_utils import de_parallel, select_device, smart_infe
|
|
41
41
|
|
42
42
|
class BaseValidator:
|
43
43
|
"""
|
44
|
-
BaseValidator.
|
45
|
-
|
46
44
|
A base class for creating validators.
|
47
45
|
|
46
|
+
This class provides the foundation for validation processes, including model evaluation, metric computation, and
|
47
|
+
result visualization.
|
48
|
+
|
48
49
|
Attributes:
|
49
50
|
args (SimpleNamespace): Configuration for the validator.
|
50
51
|
dataloader (DataLoader): Dataloader to use for validation.
|
51
52
|
pbar (tqdm): Progress bar to update during validation.
|
52
53
|
model (nn.Module): Model to validate.
|
53
|
-
data (
|
54
|
+
data (Dict): Data dictionary containing dataset information.
|
54
55
|
device (torch.device): Device to use for validation.
|
55
56
|
batch_i (int): Current batch index.
|
56
57
|
training (bool): Whether the model is in training mode.
|
57
|
-
names (
|
58
|
-
seen:
|
59
|
-
stats:
|
60
|
-
confusion_matrix:
|
61
|
-
nc: Number of classes.
|
62
|
-
iouv
|
63
|
-
jdict (
|
64
|
-
speed (
|
65
|
-
|
58
|
+
names (Dict): Class names mapping.
|
59
|
+
seen (int): Number of images seen so far during validation.
|
60
|
+
stats (Dict): Statistics collected during validation.
|
61
|
+
confusion_matrix: Confusion matrix for classification evaluation.
|
62
|
+
nc (int): Number of classes.
|
63
|
+
iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
|
64
|
+
jdict (List): List to store JSON validation results.
|
65
|
+
speed (Dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
|
66
|
+
batch processing times in milliseconds.
|
66
67
|
save_dir (Path): Directory to save results.
|
67
|
-
plots (
|
68
|
-
callbacks (
|
68
|
+
plots (Dict): Dictionary to store plots for visualization.
|
69
|
+
callbacks (Dict): Dictionary to store various callback functions.
|
70
|
+
|
71
|
+
Methods:
|
72
|
+
__call__: Execute validation process, running inference on dataloader and computing performance metrics.
|
73
|
+
match_predictions: Match predictions to ground truth objects using IoU.
|
74
|
+
add_callback: Append the given callback to the specified event.
|
75
|
+
run_callbacks: Run all callbacks associated with a specified event.
|
76
|
+
get_dataloader: Get data loader from dataset path and batch size.
|
77
|
+
build_dataset: Build dataset from image path.
|
78
|
+
preprocess: Preprocess an input batch.
|
79
|
+
postprocess: Postprocess the predictions.
|
80
|
+
init_metrics: Initialize performance metrics for the YOLO model.
|
81
|
+
update_metrics: Update metrics based on predictions and batch.
|
82
|
+
finalize_metrics: Finalize and return all metrics.
|
83
|
+
get_stats: Return statistics about the model's performance.
|
84
|
+
check_stats: Check statistics.
|
85
|
+
print_results: Print the results of the model's predictions.
|
86
|
+
get_desc: Get description of the YOLO model.
|
87
|
+
on_plot: Register plots (e.g. to be consumed in callbacks).
|
88
|
+
plot_val_samples: Plot validation samples during training.
|
89
|
+
plot_predictions: Plot YOLO model predictions on batch images.
|
90
|
+
pred_to_json: Convert predictions to JSON format.
|
91
|
+
eval_json: Evaluate and return JSON format of prediction statistics.
|
69
92
|
"""
|
70
93
|
|
71
94
|
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
72
95
|
"""
|
73
|
-
|
96
|
+
Initialize a BaseValidator instance.
|
74
97
|
|
75
98
|
Args:
|
76
|
-
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
|
99
|
+
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
|
77
100
|
save_dir (Path, optional): Directory to save results.
|
78
|
-
pbar (tqdm.tqdm): Progress bar for displaying progress.
|
79
|
-
args (SimpleNamespace): Configuration for the validator.
|
80
|
-
_callbacks (
|
101
|
+
pbar (tqdm.tqdm, optional): Progress bar for displaying progress.
|
102
|
+
args (SimpleNamespace, optional): Configuration for the validator.
|
103
|
+
_callbacks (Dict, optional): Dictionary to store various callback functions.
|
81
104
|
"""
|
82
105
|
self.args = get_cfg(overrides=args)
|
83
106
|
self.dataloader = dataloader
|
@@ -107,13 +130,22 @@ class BaseValidator:
|
|
107
130
|
|
108
131
|
@smart_inference_mode()
|
109
132
|
def __call__(self, trainer=None, model=None):
|
110
|
-
"""
|
133
|
+
"""
|
134
|
+
Execute validation process, running inference on dataloader and computing performance metrics.
|
135
|
+
|
136
|
+
Args:
|
137
|
+
trainer (object, optional): Trainer object that contains the model to validate.
|
138
|
+
model (nn.Module, optional): Model to validate if not using a trainer.
|
139
|
+
|
140
|
+
Returns:
|
141
|
+
stats (dict): Dictionary containing validation statistics.
|
142
|
+
"""
|
111
143
|
self.training = trainer is not None
|
112
144
|
augment = self.args.augment and (not self.training)
|
113
145
|
if self.training:
|
114
146
|
self.device = trainer.device
|
115
147
|
self.data = trainer.data
|
116
|
-
#
|
148
|
+
# Force FP16 val during training
|
117
149
|
self.args.half = self.device.type != "cpu" and trainer.amp
|
118
150
|
model = trainer.ema.ema or trainer.model
|
119
151
|
model = model.half() if self.args.half else model.float()
|
@@ -221,18 +253,20 @@ class BaseValidator:
|
|
221
253
|
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
|
222
254
|
return stats
|
223
255
|
|
224
|
-
def match_predictions(
|
256
|
+
def match_predictions(
|
257
|
+
self, pred_classes: torch.Tensor, true_classes: torch.Tensor, iou: torch.Tensor, use_scipy: bool = False
|
258
|
+
) -> torch.Tensor:
|
225
259
|
"""
|
226
|
-
|
260
|
+
Match predictions to ground truth objects using IoU.
|
227
261
|
|
228
262
|
Args:
|
229
|
-
pred_classes (torch.Tensor): Predicted class indices of shape(N,).
|
230
|
-
true_classes (torch.Tensor): Target class indices of shape(M,).
|
231
|
-
iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground
|
263
|
+
pred_classes (torch.Tensor): Predicted class indices of shape (N,).
|
264
|
+
true_classes (torch.Tensor): Target class indices of shape (M,).
|
265
|
+
iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground truth.
|
232
266
|
use_scipy (bool): Whether to use scipy for matching (more precise).
|
233
267
|
|
234
268
|
Returns:
|
235
|
-
(torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
|
269
|
+
(torch.Tensor): Correct tensor of shape (N, 10) for 10 IoU thresholds.
|
236
270
|
"""
|
237
271
|
# Dx10 matrix, where D - detections, 10 - IoU thresholds
|
238
272
|
correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
|
@@ -264,11 +298,11 @@ class BaseValidator:
|
|
264
298
|
return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
|
265
299
|
|
266
300
|
def add_callback(self, event: str, callback):
|
267
|
-
"""
|
301
|
+
"""Append the given callback to the specified event."""
|
268
302
|
self.callbacks[event].append(callback)
|
269
303
|
|
270
304
|
def run_callbacks(self, event: str):
|
271
|
-
"""
|
305
|
+
"""Run all callbacks associated with a specified event."""
|
272
306
|
for callback in self.callbacks.get(event, []):
|
273
307
|
callback(self)
|
274
308
|
|
@@ -277,15 +311,15 @@ class BaseValidator:
|
|
277
311
|
raise NotImplementedError("get_dataloader function not implemented for this validator")
|
278
312
|
|
279
313
|
def build_dataset(self, img_path):
|
280
|
-
"""Build dataset."""
|
314
|
+
"""Build dataset from image path."""
|
281
315
|
raise NotImplementedError("build_dataset function not implemented in validator")
|
282
316
|
|
283
317
|
def preprocess(self, batch):
|
284
|
-
"""
|
318
|
+
"""Preprocess an input batch."""
|
285
319
|
return batch
|
286
320
|
|
287
321
|
def postprocess(self, preds):
|
288
|
-
"""
|
322
|
+
"""Postprocess the predictions."""
|
289
323
|
return preds
|
290
324
|
|
291
325
|
def init_metrics(self, model):
|
@@ -293,23 +327,23 @@ class BaseValidator:
|
|
293
327
|
pass
|
294
328
|
|
295
329
|
def update_metrics(self, preds, batch):
|
296
|
-
"""
|
330
|
+
"""Update metrics based on predictions and batch."""
|
297
331
|
pass
|
298
332
|
|
299
333
|
def finalize_metrics(self, *args, **kwargs):
|
300
|
-
"""
|
334
|
+
"""Finalize and return all metrics."""
|
301
335
|
pass
|
302
336
|
|
303
337
|
def get_stats(self):
|
304
|
-
"""
|
338
|
+
"""Return statistics about the model's performance."""
|
305
339
|
return {}
|
306
340
|
|
307
341
|
def check_stats(self, stats):
|
308
|
-
"""
|
342
|
+
"""Check statistics."""
|
309
343
|
pass
|
310
344
|
|
311
345
|
def print_results(self):
|
312
|
-
"""
|
346
|
+
"""Print the results of the model's predictions."""
|
313
347
|
pass
|
314
348
|
|
315
349
|
def get_desc(self):
|
@@ -318,20 +352,20 @@ class BaseValidator:
|
|
318
352
|
|
319
353
|
@property
|
320
354
|
def metric_keys(self):
|
321
|
-
"""
|
355
|
+
"""Return the metric keys used in YOLO training/validation."""
|
322
356
|
return []
|
323
357
|
|
324
358
|
def on_plot(self, name, data=None):
|
325
|
-
"""
|
359
|
+
"""Register plots (e.g. to be consumed in callbacks)."""
|
326
360
|
self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
|
327
361
|
|
328
362
|
# TODO: may need to put these following functions into callback
|
329
363
|
def plot_val_samples(self, batch, ni):
|
330
|
-
"""
|
364
|
+
"""Plot validation samples during training."""
|
331
365
|
pass
|
332
366
|
|
333
367
|
def plot_predictions(self, batch, preds, ni):
|
334
|
-
"""
|
368
|
+
"""Plot YOLO model predictions on batch images."""
|
335
369
|
pass
|
336
370
|
|
337
371
|
def pred_to_json(self, preds, batch):
|
ultralytics/hub/__init__.py
CHANGED
@@ -23,7 +23,7 @@ __all__ = (
|
|
23
23
|
)
|
24
24
|
|
25
25
|
|
26
|
-
def login(api_key: str = None, save=True) -> bool:
|
26
|
+
def login(api_key: str = None, save: bool = True) -> bool:
|
27
27
|
"""
|
28
28
|
Log in to the Ultralytics HUB API using the provided API key.
|
29
29
|
|
@@ -31,8 +31,8 @@ def login(api_key: str = None, save=True) -> bool:
|
|
31
31
|
environment variable if successfully authenticated.
|
32
32
|
|
33
33
|
Args:
|
34
|
-
api_key (str, optional): API key to use for authentication.
|
35
|
-
|
34
|
+
api_key (str, optional): API key to use for authentication. If not provided, it will be retrieved from SETTINGS
|
35
|
+
or HUB_API_KEY environment variable.
|
36
36
|
save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.
|
37
37
|
|
38
38
|
Returns:
|
@@ -79,7 +79,7 @@ def logout():
|
|
79
79
|
LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo login'.")
|
80
80
|
|
81
81
|
|
82
|
-
def reset_model(model_id=""):
|
82
|
+
def reset_model(model_id: str = ""):
|
83
83
|
"""Reset a trained model to an untrained state."""
|
84
84
|
r = requests.post(f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key})
|
85
85
|
if r.status_code == 200:
|
@@ -95,8 +95,8 @@ def export_fmts_hub():
|
|
95
95
|
return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]
|
96
96
|
|
97
97
|
|
98
|
-
def export_model(model_id="", format="torchscript"):
|
99
|
-
"""Export a model to
|
98
|
+
def export_model(model_id: str = "", format: str = "torchscript"):
|
99
|
+
"""Export a model to the specified format."""
|
100
100
|
assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
|
101
101
|
r = requests.post(
|
102
102
|
f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key}
|
@@ -105,7 +105,7 @@ def export_model(model_id="", format="torchscript"):
|
|
105
105
|
LOGGER.info(f"{PREFIX}{format} export started ✅")
|
106
106
|
|
107
107
|
|
108
|
-
def get_export(model_id="", format="torchscript"):
|
108
|
+
def get_export(model_id: str = "", format: str = "torchscript"):
|
109
109
|
"""Get an exported model dictionary with download URL."""
|
110
110
|
assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
|
111
111
|
r = requests.post(
|
@@ -119,17 +119,12 @@ def get_export(model_id="", format="torchscript"):
|
|
119
119
|
|
120
120
|
def check_dataset(path: str, task: str) -> None:
|
121
121
|
"""
|
122
|
-
|
123
|
-
to the HUB. Usage examples are given below.
|
122
|
+
Check HUB dataset Zip file for errors before upload.
|
124
123
|
|
125
124
|
Args:
|
126
125
|
path (str): Path to data.zip (with data.yaml inside data.zip).
|
127
126
|
task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify', 'obb'.
|
128
127
|
|
129
|
-
Note:
|
130
|
-
Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
|
131
|
-
i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
|
132
|
-
|
133
128
|
Examples:
|
134
129
|
>>> from ultralytics.hub import check_dataset
|
135
130
|
>>> check_dataset("path/to/coco8.zip", task="detect") # detect dataset
|
@@ -137,6 +132,10 @@ def check_dataset(path: str, task: str) -> None:
|
|
137
132
|
>>> check_dataset("path/to/coco8-pose.zip", task="pose") # pose dataset
|
138
133
|
>>> check_dataset("path/to/dota8.zip", task="obb") # OBB dataset
|
139
134
|
>>> check_dataset("path/to/imagenet10.zip", task="classify") # classification dataset
|
135
|
+
|
136
|
+
Note:
|
137
|
+
Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
|
138
|
+
i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
|
140
139
|
"""
|
141
140
|
HUBDatasetStats(path=path, task=task).get_json()
|
142
141
|
LOGGER.info(f"Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.")
|
ultralytics/hub/auth.py
CHANGED
@@ -18,14 +18,14 @@ class Auth:
|
|
18
18
|
3. Prompting the user to enter an API key.
|
19
19
|
|
20
20
|
Attributes:
|
21
|
-
id_token (str
|
22
|
-
api_key (str
|
21
|
+
id_token (str | bool): Token used for identity verification, initialized as False.
|
22
|
+
api_key (str | bool): API key for authentication, initialized as False.
|
23
23
|
model_key (bool): Placeholder for model key, initialized as False.
|
24
24
|
"""
|
25
25
|
|
26
26
|
id_token = api_key = model_key = False
|
27
27
|
|
28
|
-
def __init__(self, api_key="", verbose=False):
|
28
|
+
def __init__(self, api_key: str = "", verbose: bool = False):
|
29
29
|
"""
|
30
30
|
Initialize Auth class and authenticate user.
|
31
31
|
|
@@ -70,12 +70,8 @@ class Auth:
|
|
70
70
|
elif verbose:
|
71
71
|
LOGGER.info(f"{PREFIX}Get API key from {API_KEY_URL} and then run 'yolo login API_KEY'")
|
72
72
|
|
73
|
-
def request_api_key(self, max_attempts=3):
|
74
|
-
"""
|
75
|
-
Prompt the user to input their API key.
|
76
|
-
|
77
|
-
Returns the model ID.
|
78
|
-
"""
|
73
|
+
def request_api_key(self, max_attempts: int = 3) -> bool:
|
74
|
+
"""Prompt the user to input their API key."""
|
79
75
|
import getpass
|
80
76
|
|
81
77
|
for attempts in range(max_attempts):
|
@@ -107,8 +103,9 @@ class Auth:
|
|
107
103
|
|
108
104
|
def auth_with_cookies(self) -> bool:
|
109
105
|
"""
|
110
|
-
Attempt to fetch authentication via cookies and set id_token.
|
111
|
-
|
106
|
+
Attempt to fetch authentication via cookies and set id_token.
|
107
|
+
|
108
|
+
User must be logged in to HUB and running in a supported browser.
|
112
109
|
|
113
110
|
Returns:
|
114
111
|
(bool): True if authentication is successful, False otherwise.
|
@@ -131,7 +128,7 @@ class Auth:
|
|
131
128
|
Get the authentication header for making API requests.
|
132
129
|
|
133
130
|
Returns:
|
134
|
-
(dict): The authentication header if id_token or API key is set, None otherwise.
|
131
|
+
(dict | None): The authentication header if id_token or API key is set, None otherwise.
|
135
132
|
"""
|
136
133
|
if self.id_token:
|
137
134
|
return {"authorization": f"Bearer {self.id_token}"}
|
ultralytics/hub/session.py
CHANGED
@@ -20,13 +20,25 @@ class HUBTrainingSession:
|
|
20
20
|
"""
|
21
21
|
HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
|
22
22
|
|
23
|
+
This class encapsulates the functionality for interacting with Ultralytics HUB during model training, including
|
24
|
+
model creation, metrics tracking, and checkpoint uploading.
|
25
|
+
|
23
26
|
Attributes:
|
24
27
|
model_id (str): Identifier for the YOLO model being trained.
|
25
28
|
model_url (str): URL for the model in Ultralytics HUB.
|
26
|
-
rate_limits (
|
27
|
-
timers (
|
28
|
-
metrics_queue (
|
29
|
-
|
29
|
+
rate_limits (Dict): Rate limits for different API calls (in seconds).
|
30
|
+
timers (Dict): Timers for rate limiting.
|
31
|
+
metrics_queue (Dict): Queue for the model's metrics.
|
32
|
+
metrics_upload_failed_queue (Dict): Queue for metrics that failed to upload.
|
33
|
+
model (Dict): Model data fetched from Ultralytics HUB.
|
34
|
+
model_file (str): Path to the model file.
|
35
|
+
train_args (Dict): Arguments for training the model.
|
36
|
+
client (HUBClient): Client for interacting with Ultralytics HUB.
|
37
|
+
filename (str): Filename of the model.
|
38
|
+
|
39
|
+
Examples:
|
40
|
+
>>> session = HUBTrainingSession("https://hub.ultralytics.com/models/example-model")
|
41
|
+
>>> session.upload_metrics()
|
30
42
|
"""
|
31
43
|
|
32
44
|
def __init__(self, identifier):
|
@@ -78,7 +90,16 @@ class HUBTrainingSession:
|
|
78
90
|
|
79
91
|
@classmethod
|
80
92
|
def create_session(cls, identifier, args=None):
|
81
|
-
"""
|
93
|
+
"""
|
94
|
+
Create an authenticated HUBTrainingSession or return None.
|
95
|
+
|
96
|
+
Args:
|
97
|
+
identifier (str): Model identifier used to initialize the HUB training session.
|
98
|
+
args (Dict, optional): Arguments for creating a new model if identifier is not a HUB model URL.
|
99
|
+
|
100
|
+
Returns:
|
101
|
+
(HUBTrainingSession | None): An authenticated session or None if creation fails.
|
102
|
+
"""
|
82
103
|
try:
|
83
104
|
session = cls(identifier)
|
84
105
|
if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # not a HUB model URL
|
@@ -90,7 +111,15 @@ class HUBTrainingSession:
|
|
90
111
|
return None
|
91
112
|
|
92
113
|
def load_model(self, model_id):
|
93
|
-
"""
|
114
|
+
"""
|
115
|
+
Load an existing model from Ultralytics HUB using the provided model identifier.
|
116
|
+
|
117
|
+
Args:
|
118
|
+
model_id (str): The identifier of the model to load.
|
119
|
+
|
120
|
+
Raises:
|
121
|
+
ValueError: If the specified HUB model does not exist.
|
122
|
+
"""
|
94
123
|
self.model = self.client.model(model_id)
|
95
124
|
if not self.model.data: # then model does not exist
|
96
125
|
raise ValueError(emojis("❌ The specified HUB model does not exist")) # TODO: improve error handling
|
@@ -108,7 +137,15 @@ class HUBTrainingSession:
|
|
108
137
|
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
|
109
138
|
|
110
139
|
def create_model(self, model_args):
|
111
|
-
"""
|
140
|
+
"""
|
141
|
+
Initialize a HUB training session with the specified model arguments.
|
142
|
+
|
143
|
+
Args:
|
144
|
+
model_args (Dict): Arguments for creating the model, including batch size, epochs, image size, etc.
|
145
|
+
|
146
|
+
Returns:
|
147
|
+
(None): If the model could not be created.
|
148
|
+
"""
|
112
149
|
payload = {
|
113
150
|
"config": {
|
114
151
|
"batchSize": model_args.get("batch", -1),
|
@@ -146,7 +183,7 @@ class HUBTrainingSession:
|
|
146
183
|
@staticmethod
|
147
184
|
def _parse_identifier(identifier):
|
148
185
|
"""
|
149
|
-
|
186
|
+
Parse the given identifier to determine the type and extract relevant components.
|
150
187
|
|
151
188
|
The method supports different identifier formats:
|
152
189
|
- A HUB model URL https://hub.ultralytics.com/models/MODEL
|
@@ -176,7 +213,7 @@ class HUBTrainingSession:
|
|
176
213
|
|
177
214
|
def _set_train_args(self):
|
178
215
|
"""
|
179
|
-
|
216
|
+
Initialize training arguments and create a model entry on the Ultralytics HUB.
|
180
217
|
|
181
218
|
This method sets up training arguments based on the model's state and updates them with any additional
|
182
219
|
arguments provided. It handles different states of the model, such as whether it's resumable, pretrained,
|
@@ -218,10 +255,26 @@ class HUBTrainingSession:
|
|
218
255
|
*args,
|
219
256
|
**kwargs,
|
220
257
|
):
|
221
|
-
"""
|
258
|
+
"""
|
259
|
+
Attempt to execute `request_func` with retries, timeout handling, optional threading, and progress tracking.
|
260
|
+
|
261
|
+
Args:
|
262
|
+
request_func (callable): The function to execute.
|
263
|
+
retry (int): Number of retry attempts.
|
264
|
+
timeout (int): Maximum time to wait for the request to complete.
|
265
|
+
thread (bool): Whether to run the request in a separate thread.
|
266
|
+
verbose (bool): Whether to log detailed messages.
|
267
|
+
progress_total (int, optional): Total size for progress tracking.
|
268
|
+
stream_response (bool, optional): Whether to stream the response.
|
269
|
+
*args (Any): Additional positional arguments for request_func.
|
270
|
+
**kwargs (Any): Additional keyword arguments for request_func.
|
271
|
+
|
272
|
+
Returns:
|
273
|
+
(requests.Response | None): The response object if thread=False, otherwise None.
|
274
|
+
"""
|
222
275
|
|
223
276
|
def retry_request():
|
224
|
-
"""
|
277
|
+
"""Attempt to call `request_func` with retries, timeout, and optional threading."""
|
225
278
|
t0 = time.time() # Record the start time for the timeout
|
226
279
|
response = None
|
227
280
|
for i in range(retry + 1):
|
@@ -274,7 +327,15 @@ class HUBTrainingSession:
|
|
274
327
|
|
275
328
|
@staticmethod
|
276
329
|
def _should_retry(status_code):
|
277
|
-
"""
|
330
|
+
"""
|
331
|
+
Determine if a request should be retried based on the HTTP status code.
|
332
|
+
|
333
|
+
Args:
|
334
|
+
status_code (int): The HTTP status code from the response.
|
335
|
+
|
336
|
+
Returns:
|
337
|
+
(bool): True if the request should be retried, False otherwise.
|
338
|
+
"""
|
278
339
|
retry_codes = {
|
279
340
|
HTTPStatus.REQUEST_TIMEOUT,
|
280
341
|
HTTPStatus.BAD_GATEWAY,
|
@@ -287,9 +348,9 @@ class HUBTrainingSession:
|
|
287
348
|
Generate a retry message based on the response status code.
|
288
349
|
|
289
350
|
Args:
|
290
|
-
response: The HTTP response object.
|
291
|
-
retry: The number of retry attempts allowed.
|
292
|
-
timeout: The maximum timeout duration.
|
351
|
+
response (requests.Response): The HTTP response object.
|
352
|
+
retry (int): The number of retry attempts allowed.
|
353
|
+
timeout (int): The maximum timeout duration.
|
293
354
|
|
294
355
|
Returns:
|
295
356
|
(str): The retry message.
|
@@ -367,9 +428,6 @@ class HUBTrainingSession:
|
|
367
428
|
Args:
|
368
429
|
content_length (int): The total size of the content to be downloaded in bytes.
|
369
430
|
response (requests.Response): The response object from the file download request.
|
370
|
-
|
371
|
-
Returns:
|
372
|
-
None
|
373
431
|
"""
|
374
432
|
with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
|
375
433
|
for data in response.iter_content(chunk_size=1024):
|
@@ -382,9 +440,6 @@ class HUBTrainingSession:
|
|
382
440
|
|
383
441
|
Args:
|
384
442
|
response (requests.Response): The response object from the file download request.
|
385
|
-
|
386
|
-
Returns:
|
387
|
-
None
|
388
443
|
"""
|
389
444
|
for _ in response.iter_content(chunk_size=1024):
|
390
445
|
pass # Do nothing with data chunks
|
ultralytics/hub/utils.py
CHANGED
@@ -43,7 +43,7 @@ def request_with_credentials(url: str) -> any:
|
|
43
43
|
url (str): The URL to make the request to.
|
44
44
|
|
45
45
|
Returns:
|
46
|
-
(
|
46
|
+
(Any): The response data from the AJAX request.
|
47
47
|
|
48
48
|
Raises:
|
49
49
|
OSError: If the function is not run in a Google Colab environment.
|
@@ -83,14 +83,14 @@ def requests_with_progress(method, url, **kwargs):
|
|
83
83
|
Args:
|
84
84
|
method (str): The HTTP method to use (e.g. 'GET', 'POST').
|
85
85
|
url (str): The URL to send the request to.
|
86
|
-
**kwargs (
|
86
|
+
**kwargs (Any): Additional keyword arguments to pass to the underlying `requests.request` function.
|
87
87
|
|
88
88
|
Returns:
|
89
89
|
(requests.Response): The response object from the HTTP request.
|
90
90
|
|
91
|
-
|
91
|
+
Notes:
|
92
92
|
- If 'progress' is set to True, the progress bar will display the download progress for responses with a known
|
93
|
-
|
93
|
+
content length.
|
94
94
|
- If 'progress' is a number then progress bar will display assuming content length = progress.
|
95
95
|
"""
|
96
96
|
progress = kwargs.pop("progress", False)
|
@@ -110,18 +110,18 @@ def requests_with_progress(method, url, **kwargs):
|
|
110
110
|
|
111
111
|
def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbose=True, progress=False, **kwargs):
|
112
112
|
"""
|
113
|
-
|
113
|
+
Make an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
|
114
114
|
|
115
115
|
Args:
|
116
116
|
method (str): The HTTP method to use for the request. Choices are 'post' and 'get'.
|
117
117
|
url (str): The URL to make the request to.
|
118
|
-
retry (int, optional): Number of retries to attempt before giving up.
|
119
|
-
timeout (int, optional): Timeout in seconds after which the function will give up retrying.
|
120
|
-
thread (bool, optional): Whether to execute the request in a separate daemon thread.
|
121
|
-
code (int, optional): An identifier for the request, used for logging purposes.
|
122
|
-
verbose (bool, optional): A flag to determine whether to print out to console or not.
|
123
|
-
progress (bool, optional): Whether to show a progress bar during the request.
|
124
|
-
**kwargs (
|
118
|
+
retry (int, optional): Number of retries to attempt before giving up.
|
119
|
+
timeout (int, optional): Timeout in seconds after which the function will give up retrying.
|
120
|
+
thread (bool, optional): Whether to execute the request in a separate daemon thread.
|
121
|
+
code (int, optional): An identifier for the request, used for logging purposes.
|
122
|
+
verbose (bool, optional): A flag to determine whether to print out to console or not.
|
123
|
+
progress (bool, optional): Whether to show a progress bar during the request.
|
124
|
+
**kwargs (Any): Keyword arguments to be passed to the requests function specified in method.
|
125
125
|
|
126
126
|
Returns:
|
127
127
|
(requests.Response): The HTTP response object. If the request is executed in a separate thread, returns None.
|
@@ -169,20 +169,22 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos
|
|
169
169
|
|
170
170
|
class Events:
|
171
171
|
"""
|
172
|
-
A class for collecting anonymous event analytics.
|
173
|
-
|
172
|
+
A class for collecting anonymous event analytics.
|
173
|
+
|
174
|
+
Event analytics are enabled when sync=True in settings and disabled when sync=False. Run 'yolo settings' to see and
|
175
|
+
update settings.
|
174
176
|
|
175
177
|
Attributes:
|
176
178
|
url (str): The URL to send anonymous events.
|
177
179
|
rate_limit (float): The rate limit in seconds for sending events.
|
178
|
-
metadata (
|
180
|
+
metadata (Dict): A dictionary containing metadata about the environment.
|
179
181
|
enabled (bool): A flag to enable or disable Events based on certain conditions.
|
180
182
|
"""
|
181
183
|
|
182
184
|
url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw"
|
183
185
|
|
184
186
|
def __init__(self):
|
185
|
-
"""
|
187
|
+
"""Initialize the Events object with default values for events, rate_limit, and metadata."""
|
186
188
|
self.events = [] # events list
|
187
189
|
self.rate_limit = 30.0 # rate limit (seconds)
|
188
190
|
self.t = 0.0 # rate limit timer (seconds)
|
@@ -205,7 +207,7 @@ class Events:
|
|
205
207
|
|
206
208
|
def __call__(self, cfg):
|
207
209
|
"""
|
208
|
-
|
210
|
+
Attempt to add a new event to the events list and send events if the rate limit is reached.
|
209
211
|
|
210
212
|
Args:
|
211
213
|
cfg (IterableSimpleNamespace): The configuration object containing mode and task information.
|