ultralytics 8.3.142__py3-none-any.whl → 8.3.144__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/conftest.py +7 -24
- tests/test_cli.py +1 -1
- tests/test_cuda.py +7 -2
- tests/test_engine.py +7 -8
- tests/test_exports.py +16 -16
- tests/test_integrations.py +1 -1
- tests/test_solutions.py +12 -12
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -13
- ultralytics/data/annotator.py +6 -5
- ultralytics/data/augment.py +127 -126
- ultralytics/data/base.py +54 -51
- ultralytics/data/build.py +47 -23
- ultralytics/data/converter.py +47 -43
- ultralytics/data/dataset.py +51 -50
- ultralytics/data/loaders.py +77 -44
- ultralytics/data/split.py +22 -9
- ultralytics/data/split_dota.py +63 -39
- ultralytics/data/utils.py +59 -39
- ultralytics/engine/exporter.py +79 -27
- ultralytics/engine/model.py +39 -39
- ultralytics/engine/predictor.py +37 -28
- ultralytics/engine/results.py +187 -157
- ultralytics/engine/trainer.py +36 -19
- ultralytics/engine/tuner.py +12 -9
- ultralytics/engine/validator.py +7 -9
- ultralytics/hub/__init__.py +11 -13
- ultralytics/hub/auth.py +22 -2
- ultralytics/hub/google/__init__.py +19 -19
- ultralytics/hub/session.py +37 -51
- ultralytics/hub/utils.py +19 -5
- ultralytics/models/fastsam/model.py +30 -12
- ultralytics/models/fastsam/predict.py +5 -6
- ultralytics/models/fastsam/utils.py +3 -3
- ultralytics/models/fastsam/val.py +10 -6
- ultralytics/models/nas/model.py +9 -5
- ultralytics/models/nas/predict.py +6 -6
- ultralytics/models/nas/val.py +3 -3
- ultralytics/models/rtdetr/model.py +7 -6
- ultralytics/models/rtdetr/predict.py +14 -7
- ultralytics/models/rtdetr/train.py +10 -4
- ultralytics/models/rtdetr/val.py +36 -9
- ultralytics/models/sam/amg.py +30 -12
- ultralytics/models/sam/build.py +22 -22
- ultralytics/models/sam/model.py +10 -9
- ultralytics/models/sam/modules/blocks.py +76 -80
- ultralytics/models/sam/modules/decoders.py +6 -8
- ultralytics/models/sam/modules/encoders.py +23 -26
- ultralytics/models/sam/modules/memory_attention.py +13 -1
- ultralytics/models/sam/modules/sam.py +57 -26
- ultralytics/models/sam/modules/tiny_encoder.py +232 -237
- ultralytics/models/sam/modules/transformer.py +13 -13
- ultralytics/models/sam/modules/utils.py +11 -19
- ultralytics/models/sam/predict.py +114 -101
- ultralytics/models/utils/loss.py +98 -77
- ultralytics/models/utils/ops.py +116 -67
- ultralytics/models/yolo/classify/predict.py +5 -5
- ultralytics/models/yolo/classify/train.py +32 -28
- ultralytics/models/yolo/classify/val.py +7 -8
- ultralytics/models/yolo/detect/predict.py +1 -0
- ultralytics/models/yolo/detect/train.py +15 -14
- ultralytics/models/yolo/detect/val.py +37 -36
- ultralytics/models/yolo/model.py +106 -23
- ultralytics/models/yolo/obb/predict.py +3 -4
- ultralytics/models/yolo/obb/train.py +14 -6
- ultralytics/models/yolo/obb/val.py +29 -23
- ultralytics/models/yolo/pose/predict.py +9 -8
- ultralytics/models/yolo/pose/train.py +24 -16
- ultralytics/models/yolo/pose/val.py +44 -26
- ultralytics/models/yolo/segment/predict.py +5 -5
- ultralytics/models/yolo/segment/train.py +11 -7
- ultralytics/models/yolo/segment/val.py +2 -2
- ultralytics/models/yolo/world/train.py +33 -23
- ultralytics/models/yolo/world/train_world.py +11 -3
- ultralytics/models/yolo/yoloe/predict.py +11 -11
- ultralytics/models/yolo/yoloe/train.py +73 -21
- ultralytics/models/yolo/yoloe/train_seg.py +10 -7
- ultralytics/models/yolo/yoloe/val.py +42 -18
- ultralytics/nn/autobackend.py +59 -15
- ultralytics/nn/modules/__init__.py +4 -4
- ultralytics/nn/modules/activation.py +4 -1
- ultralytics/nn/modules/block.py +178 -111
- ultralytics/nn/modules/conv.py +6 -5
- ultralytics/nn/modules/head.py +469 -121
- ultralytics/nn/modules/transformer.py +147 -58
- ultralytics/nn/tasks.py +227 -20
- ultralytics/nn/text_model.py +30 -33
- ultralytics/solutions/ai_gym.py +1 -1
- ultralytics/solutions/analytics.py +7 -4
- ultralytics/solutions/config.py +10 -10
- ultralytics/solutions/distance_calculation.py +11 -10
- ultralytics/solutions/heatmap.py +1 -1
- ultralytics/solutions/instance_segmentation.py +6 -3
- ultralytics/solutions/object_blurrer.py +3 -3
- ultralytics/solutions/object_counter.py +16 -8
- ultralytics/solutions/object_cropper.py +12 -5
- ultralytics/solutions/parking_management.py +29 -28
- ultralytics/solutions/queue_management.py +6 -6
- ultralytics/solutions/region_counter.py +10 -3
- ultralytics/solutions/security_alarm.py +3 -3
- ultralytics/solutions/similarity_search.py +85 -24
- ultralytics/solutions/solutions.py +215 -85
- ultralytics/solutions/speed_estimation.py +28 -22
- ultralytics/solutions/streamlit_inference.py +17 -12
- ultralytics/solutions/trackzone.py +4 -4
- ultralytics/trackers/basetrack.py +16 -23
- ultralytics/trackers/bot_sort.py +30 -20
- ultralytics/trackers/byte_tracker.py +70 -64
- ultralytics/trackers/track.py +4 -8
- ultralytics/trackers/utils/gmc.py +31 -58
- ultralytics/trackers/utils/kalman_filter.py +37 -37
- ultralytics/trackers/utils/matching.py +1 -1
- ultralytics/utils/__init__.py +105 -89
- ultralytics/utils/autobatch.py +16 -3
- ultralytics/utils/autodevice.py +54 -24
- ultralytics/utils/benchmarks.py +42 -28
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +9 -9
- ultralytics/utils/callbacks/comet.py +67 -25
- ultralytics/utils/callbacks/dvc.py +7 -10
- ultralytics/utils/callbacks/mlflow.py +2 -5
- ultralytics/utils/callbacks/neptune.py +7 -13
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +5 -6
- ultralytics/utils/callbacks/wb.py +14 -14
- ultralytics/utils/checks.py +14 -13
- ultralytics/utils/dist.py +5 -5
- ultralytics/utils/downloads.py +94 -67
- ultralytics/utils/errors.py +5 -5
- ultralytics/utils/export.py +61 -47
- ultralytics/utils/files.py +23 -22
- ultralytics/utils/instance.py +48 -52
- ultralytics/utils/loss.py +78 -40
- ultralytics/utils/metrics.py +186 -130
- ultralytics/utils/ops.py +186 -190
- ultralytics/utils/patches.py +15 -17
- ultralytics/utils/plotting.py +71 -27
- ultralytics/utils/tal.py +21 -15
- ultralytics/utils/torch_utils.py +53 -50
- ultralytics/utils/triton.py +5 -4
- ultralytics/utils/tuner.py +5 -5
- {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/METADATA +1 -1
- ultralytics-8.3.144.dist-info/RECORD +272 -0
- ultralytics-8.3.142.dist-info/RECORD +0 -272
- {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.142.dist-info → ultralytics-8.3.144.dist-info}/top_level.txt +0 -0
ultralytics/engine/trainer.py
CHANGED
@@ -60,6 +60,9 @@ class BaseTrainer:
     """
     A base class for creating trainers.
 
+    This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing,
+    and various training utilities. It supports both single-GPU and multi-GPU distributed training.
+
     Attributes:
         args (SimpleNamespace): Configuration for the trainer.
         validator (BaseValidator): Validator instance.
@@ -89,6 +92,19 @@ class BaseTrainer:
         csv (Path): Path to results CSV file.
         metrics (dict): Dictionary of metrics.
         plots (dict): Dictionary of plots.
+
+    Methods:
+        train: Execute the training process.
+        validate: Run validation on the test set.
+        save_model: Save model training checkpoints.
+        get_dataset: Get train and validation datasets.
+        setup_model: Load, create, or download model.
+        build_optimizer: Construct an optimizer for the model.
+
+    Examples:
+        Initialize a trainer and start training
+        >>> trainer = BaseTrainer(cfg="config.yaml")
+        >>> trainer.train()
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
@@ -96,14 +112,14 @@ class BaseTrainer:
         Initialize the BaseTrainer class.
 
         Args:
-            cfg (str, optional): Path to a configuration file.
-            overrides (dict, optional): Configuration overrides.
-            _callbacks (list, optional): List of callback functions.
+            cfg (str, optional): Path to a configuration file.
+            overrides (dict, optional): Configuration overrides.
+            _callbacks (list, optional): List of callback functions.
         """
         self.args = get_cfg(cfg, overrides)
         self.check_resume(overrides)
         self.device = select_device(self.args.device, self.args.batch)
-        #
+        # Update "-1" devices so post-training val does not repeat search
         self.args.device = os.getenv("CUDA_VISIBLE_DEVICES") if "cuda" in str(self.device) else str(self.device)
         self.validator = None
         self.metrics = None
@@ -626,7 +642,7 @@ class BaseTrainer:
             self.ema.update(self.model)
 
     def preprocess_batch(self, batch):
-        """
+        """Allow custom preprocessing model inputs and ground truths depending on task type."""
         return batch
 
     def validate(self):
@@ -634,7 +650,8 @@ class BaseTrainer:
         Run validation on test set using self.validator.
 
         Returns:
-            (
+            metrics (dict): Dictionary of validation metrics.
+            fitness (float): Fitness score for the validation.
         """
         metrics = self.validator(self)
         fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
@@ -647,11 +664,11 @@ class BaseTrainer:
         raise NotImplementedError("This task trainer doesn't support loading cfg files")
 
     def get_validator(self):
-        """
+        """Return a NotImplementedError when the get_validator function is called."""
        raise NotImplementedError("get_validator function not implemented in trainer")
 
     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
-        """
+        """Return dataloader derived from torch.data.Dataloader."""
        raise NotImplementedError("get_dataloader function not implemented in trainer")
 
     def build_dataset(self, img_path, mode="train", batch=None):
@@ -660,7 +677,7 @@ class BaseTrainer:
 
     def label_loss_items(self, loss_items=None, prefix="train"):
         """
-
+        Return a loss dict with labelled training loss items tensor.
 
         Note:
             This is not needed for classification but necessary for segmentation & detection
@@ -672,20 +689,20 @@ class BaseTrainer:
         self.model.names = self.data["names"]
 
     def build_targets(self, preds, targets):
-        """
+        """Build target tensors for training YOLO model."""
         pass
 
     def progress_string(self):
-        """
+        """Return a string describing training progress."""
         return ""
 
     # TODO: may need to put these following functions into callback
     def plot_training_samples(self, batch, ni):
-        """
+        """Plot training samples during YOLO training."""
         pass
 
     def plot_training_labels(self):
-        """
+        """Plot training labels for YOLO model."""
         pass
 
     def save_metrics(self, metrics):
@@ -702,7 +719,7 @@ class BaseTrainer:
         pass
 
     def on_plot(self, name, data=None):
-        """
+        """Register plots (e.g. to be consumed in callbacks)."""
         path = Path(name)
         self.plots[path] = {"data": data, "timestamp": time.time()}
 
@@ -796,12 +813,12 @@ class BaseTrainer:
         Args:
             model (torch.nn.Module): The model for which to build an optimizer.
             name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
-                based on the number of iterations.
-            lr (float, optional): The learning rate for the optimizer.
-            momentum (float, optional): The momentum factor for the optimizer.
-            decay (float, optional): The weight decay for the optimizer.
+                based on the number of iterations.
+            lr (float, optional): The learning rate for the optimizer.
+            momentum (float, optional): The momentum factor for the optimizer.
+            decay (float, optional): The weight decay for the optimizer.
             iterations (float, optional): The number of iterations, which determines the optimizer if
-                name is 'auto'.
+                name is 'auto'.
 
         Returns:
             (torch.optim.Optimizer): The constructed optimizer.
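The expanded `BaseTrainer` docstring above documents a train → validate → build_optimizer workflow. A minimal sketch of how those entry points are typically exercised through a concrete task trainer; `DetectionTrainer` and the `yolo11n.pt`/`coco8.yaml` assets are illustrative assumptions, not part of this diff:

```python
# Hypothetical usage sketch; DetectionTrainer, yolo11n.pt and coco8.yaml are
# illustrative assumptions, not taken from this diff.
from ultralytics.models.yolo.detect import DetectionTrainer

trainer = DetectionTrainer(overrides={"model": "yolo11n.pt", "data": "coco8.yaml", "epochs": 1, "imgsz": 640})
trainer.train()          # runs the training loop described in the new Methods section
print(trainer.metrics)   # metrics dict populated via trainer.validate()
```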
ultralytics/engine/tuner.py
CHANGED
@@ -18,6 +18,7 @@ import random
 import shutil
 import subprocess
 import time
+from typing import Dict, List, Optional
 
 import numpy as np
 import torch
@@ -35,7 +36,7 @@ class Tuner:
    search space and retraining the model to evaluate their performance.
 
    Attributes:
-        space (
+        space (Dict[str, tuple]): Hyperparameter search space containing bounds and scaling factors for mutation.
        tune_dir (Path): Directory where evolution logs and results will be saved.
        tune_csv (Path): Path to the CSV file where evolution logs are saved.
        args (dict): Configuration arguments for the tuning process.
@@ -43,8 +44,8 @@ class Tuner:
        prefix (str): Prefix string for logging messages.
 
    Methods:
-        _mutate:
-        __call__:
+        _mutate: Mutate hyperparameters based on bounds and scaling factors.
+        __call__: Execute the hyperparameter evolution across multiple iterations.
 
    Examples:
        Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
@@ -58,13 +59,13 @@ class Tuner:
        >>> model.tune(space={key1: val1, key2: val2})  # custom search space dictionary
    """
 
-    def __init__(self, args=DEFAULT_CFG, _callbacks=None):
+    def __init__(self, args=DEFAULT_CFG, _callbacks: Optional[List] = None):
        """
        Initialize the Tuner with configurations.
 
        Args:
            args (dict): Configuration for hyperparameter evolution.
-            _callbacks (
+            _callbacks (List, optional): Callback functions to be executed during tuning.
        """
        self.space = args.pop("space", None) or {  # key: (min, max, gain(optional))
            # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
@@ -106,7 +107,9 @@ class Tuner:
            f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
        )
 
-    def _mutate(
+    def _mutate(
+        self, parent: str = "single", n: int = 5, mutation: float = 0.8, sigma: float = 0.2
+    ) -> Dict[str, float]:
        """
        Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
 
@@ -117,7 +120,7 @@ class Tuner:
            sigma (float): Standard deviation for Gaussian random number generator.
 
        Returns:
-            (
+            (Dict[str, float]): A dictionary containing mutated hyperparameters.
        """
        if self.tune_csv.exists():  # if CSV file exists: select best hyps and mutate
            # Select parent(s)
@@ -152,14 +155,14 @@ class Tuner:
 
        return hyp
 
-    def __call__(self, model=None, iterations=10, cleanup=True):
+    def __call__(self, model=None, iterations: int = 10, cleanup: bool = True):
        """
        Execute the hyperparameter evolution process when the Tuner instance is called.
 
        This method iterates through the number of iterations, performing the following steps in each iteration:
 
        1. Load the existing hyperparameters or initialize new ones.
-        2. Mutate the hyperparameters using the `
+        2. Mutate the hyperparameters using the `_mutate` method.
        3. Train a YOLO model with the mutated hyperparameters.
        4. Log the fitness score and mutated hyperparameters to a CSV file.
 
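The `Tuner` docstring above describes the evolution loop (mutate, train, log fitness) and its `model.tune(...)` entry point. A minimal sketch following that docstring example; the exact argument values are illustrative:

```python
# Sketch of the tuning workflow described in the Tuner docstring above;
# argument values are illustrative assumptions.
from ultralytics import YOLO

model = YOLO("yolo11n.pt")
# Drives Tuner.__call__: mutate hyperparameters, train, and log fitness to tune_csv each iteration.
model.tune(data="coco8.yaml", epochs=30, iterations=300, imgsz=640, plots=False, save=False, val=False)
```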
ultralytics/engine/validator.py
CHANGED
@@ -67,6 +67,8 @@ class BaseValidator:
        save_dir (Path): Directory to save results.
        plots (dict): Dictionary to store plots for visualization.
        callbacks (dict): Dictionary to store various callback functions.
+        stride (int): Model stride for padding calculations.
+        loss (torch.Tensor): Accumulated loss during training validation.
 
    Methods:
        __call__: Execute validation process, running inference on dataloader and computing performance metrics.
@@ -84,7 +86,7 @@ class BaseValidator:
        check_stats: Check statistics.
        print_results: Print the results of the model's predictions.
        get_desc: Get description of the YOLO model.
-        on_plot: Register plots
+        on_plot: Register plots for visualization.
        plot_val_samples: Plot validation samples during training.
        plot_predictions: Plot YOLO model predictions on batch images.
        pred_to_json: Convert predictions to JSON format.
@@ -138,7 +140,7 @@ class BaseValidator:
            model (nn.Module, optional): Model to validate if not using a trainer.
 
        Returns:
-
+            (dict): Dictionary containing validation statistics.
        """
        self.training = trainer is not None
        augment = self.args.augment and (not self.training)
@@ -149,7 +151,6 @@ class BaseValidator:
            self.args.half = self.device.type != "cpu" and trainer.amp
            model = trainer.ema.ema or trainer.model
            model = model.half() if self.args.half else model.float()
-            # self.model = model
            self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
            self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
            model.eval()
@@ -164,7 +165,6 @@ class BaseValidator:
                data=self.args.data,
                fp16=self.args.half,
            )
-            # self.model = model
            self.device = model.device  # update device
            self.args.half = model.fp16  # update half
            stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
@@ -184,7 +184,7 @@ class BaseValidator:
 
        if self.device.type in {"cpu", "mps"}:
            self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
-        if not (pt or getattr(model, "dynamic", False)):
+        if not (pt or (getattr(model, "dynamic", False) and not model.imx)):
            self.args.rect = False
        self.stride = model.stride  # used in get_dataloader() for padding
        self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)
@@ -263,7 +263,7 @@ class BaseValidator:
            pred_classes (torch.Tensor): Predicted class indices of shape (N,).
            true_classes (torch.Tensor): Target class indices of shape (M,).
            iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground truth.
-            use_scipy (bool): Whether to use scipy for matching (more precise).
+            use_scipy (bool, optional): Whether to use scipy for matching (more precise).
 
        Returns:
            (torch.Tensor): Correct tensor of shape (N, 10) for 10 IoU thresholds.
@@ -292,7 +292,6 @@ class BaseValidator:
                if matches.shape[0] > 1:
                    matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
                    matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
-                    # matches = matches[matches[:, 2].argsort()[::-1]]
                    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                correct[matches[:, 1].astype(int), i] = True
        return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
@@ -356,10 +355,9 @@ class BaseValidator:
        return []
 
    def on_plot(self, name, data=None):
-        """Register plots
+        """Register plots for visualization."""
        self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
 
-    # TODO: may need to put these following functions into callback
    def plot_val_samples(self, batch, ni):
        """Plot validation samples during training."""
        pass
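`BaseValidator.__call__` now documents a `(dict)` return of validation statistics. A minimal sketch of standalone validation that drives it; the model and dataset names are illustrative:

```python
# Sketch only: model/dataset names are illustrative assumptions.
from ultralytics import YOLO

model = YOLO("yolo11n.pt")
metrics = model.val(data="coco8.yaml", imgsz=640)  # builds a task validator and calls it
print(metrics.results_dict)                        # dictionary of validation statistics
```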
ultralytics/hub/__init__.py
CHANGED
@@ -31,8 +31,8 @@ def login(api_key: str = None, save: bool = True) -> bool:
    environment variable if successfully authenticated.
 
    Args:
-        api_key (str, optional): API key to use for authentication. If not provided, it will be retrieved from
-            or HUB_API_KEY environment variable.
+        api_key (str, optional): API key to use for authentication. If not provided, it will be retrieved from
+            SETTINGS or HUB_API_KEY environment variable.
        save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.
 
    Returns:
@@ -68,13 +68,7 @@ def login(api_key: str = None, save: bool = True) -> bool:
 
 
 def logout():
-    """
-    Log out of Ultralytics HUB by removing the API key from the settings file. To log in again, use 'yolo login'.
-
-    Examples:
-        >>> from ultralytics import hub
-        >>> hub.logout()
-    """
+    """Log out of Ultralytics HUB by removing the API key from the settings file."""
    SETTINGS["api_key"] = ""
    LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo login'.")
 
@@ -89,7 +83,7 @@ def reset_model(model_id: str = ""):
 
 
 def export_fmts_hub():
-    """
+    """Return a list of HUB-supported export formats."""
    from ultralytics.engine.exporter import export_formats
 
    return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]
@@ -125,14 +119,18 @@ def get_export(model_id: str = "", format: str = "torchscript"):
 
    Args:
        model_id (str): The ID of the model to retrieve from Ultralytics HUB.
-        format (str): The export format to retrieve. Must be one of the supported formats returned by
+        format (str): The export format to retrieve. Must be one of the supported formats returned by
+            export_fmts_hub().
+
+    Returns:
+        (dict): JSON response containing the exported model information.
 
    Raises:
        AssertionError: If the specified format is not supported or if the API request fails.
 
    Examples:
        >>> from ultralytics import hub
-        >>> hub.get_export(model_id="your_model_id", format="torchscript")
+        >>> result = hub.get_export(model_id="your_model_id", format="torchscript")
    """
    assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
    r = requests.post(
@@ -160,7 +158,7 @@ def check_dataset(path: str, task: str) -> None:
        >>> check_dataset("path/to/dota8.zip", task="obb")  # OBB dataset
        >>> check_dataset("path/to/imagenet10.zip", task="classify")  # classification dataset
 
-
+    Notes:
        Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
        i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
    """
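A minimal sketch combining the HUB helpers touched in this diff (`login`, `export_fmts_hub`, `get_export`, `logout`); the API key and model ID are placeholders:

```python
# Placeholders only: api_key and model_id are not real values.
from ultralytics import hub

hub.login(api_key="your_api_key_here")   # authenticate; optionally saves the key to SETTINGS
print(hub.export_fmts_hub())             # HUB-supported export formats
result = hub.get_export(model_id="your_model_id", format="torchscript")  # dict JSON response
hub.logout()                             # remove the API key from the settings file
```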
ultralytics/hub/auth.py
CHANGED
@@ -21,6 +21,19 @@ class Auth:
        id_token (str | bool): Token used for identity verification, initialized as False.
        api_key (str | bool): API key for authentication, initialized as False.
        model_key (bool): Placeholder for model key, initialized as False.
+
+    Methods:
+        authenticate: Attempt to authenticate with the server using either id_token or API key.
+        auth_with_cookies: Attempt to fetch authentication via cookies and set id_token.
+        get_auth_header: Get the authentication header for making API requests.
+        request_api_key: Prompt the user to input their API key.
+
+    Examples:
+        Initialize Auth with an API key
+        >>> auth = Auth(api_key="your_api_key_here")
+
+        Initialize Auth without API key (will prompt for input)
+        >>> auth = Auth()
    """
 
    id_token = api_key = model_key = False
@@ -71,7 +84,15 @@ class Auth:
            LOGGER.info(f"{PREFIX}Get API key from {API_KEY_URL} and then run 'yolo login API_KEY'")
 
    def request_api_key(self, max_attempts: int = 3) -> bool:
-        """
+        """
+        Prompt the user to input their API key.
+
+        Args:
+            max_attempts (int): Maximum number of authentication attempts.
+
+        Returns:
+            (bool): True if authentication is successful, False otherwise.
+        """
        import getpass
 
        for attempts in range(max_attempts):
@@ -134,4 +155,3 @@ class Auth:
            return {"authorization": f"Bearer {self.id_token}"}
        elif self.api_key:
            return {"x-api-key": self.api_key}
-        # else returns None
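Based on the Methods and Examples added to the `Auth` docstring above, a minimal sketch; the key value is a placeholder, and constructing `Auth` may prompt for input or contact the HUB server:

```python
# Sketch based on the Auth docstring Examples above; the key is a placeholder.
from ultralytics.hub.auth import Auth

auth = Auth(api_key="your_api_key_here")  # authenticate with an explicit key
headers = auth.get_auth_header()          # {"x-api-key": ...} or {"authorization": "Bearer ..."}
```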
ultralytics/hub/google/__init__.py
CHANGED
@@ -31,7 +31,7 @@ class GCPRegions:
    """
 
    def __init__(self):
-        """
+        """Initialize the GCPRegions class with predefined Google Cloud Platform regions and their details."""
        self.regions = {
            "asia-east1": (1, "Taiwan", "China"),
            "asia-east2": (2, "Hong Kong", "China"),
@@ -74,11 +74,11 @@ class GCPRegions:
        }
 
    def tier1(self) -> List[str]:
-        """
+        """Return a list of GCP regions classified as tier 1 based on predefined criteria."""
        return [region for region, info in self.regions.items() if info[0] == 1]
 
    def tier2(self) -> List[str]:
-        """
+        """Return a list of GCP regions classified as tier 2 based on predefined criteria."""
        return [region for region, info in self.regions.items() if info[0] == 2]
 
    @staticmethod
@@ -87,19 +87,19 @@ class GCPRegions:
        Ping a specified GCP region and measure network latency statistics.
 
        Args:
-
-
+            region (str): The GCP region identifier to ping (e.g., 'us-central1').
+            attempts (int, optional): Number of ping attempts to make for calculating statistics.
 
        Returns:
-
-
-
-
-
+            region (str): The GCP region identifier that was pinged.
+            mean_latency (float): Mean latency in milliseconds, or infinity if all pings failed.
+            std_dev (float): Standard deviation of latencies in milliseconds, or infinity if all pings failed.
+            min_latency (float): Minimum latency in milliseconds, or infinity if all pings failed.
+            max_latency (float): Maximum latency in milliseconds, or infinity if all pings failed.
 
        Examples:
-
-
+            >>> region, mean, std, min_lat, max_lat = GCPRegions._ping_region("us-central1", attempts=3)
+            >>> print(f"Region {region} has mean latency: {mean:.2f}ms")
        """
        url = f"https://{region}-docker.pkg.dev"
        latencies = []
@@ -107,7 +107,7 @@ class GCPRegions:
            try:
                start_time = time.time()
                _ = requests.head(url, timeout=5)
-                latency = (time.time() - start_time) * 1000  #
+                latency = (time.time() - start_time) * 1000  # Convert latency to milliseconds
                if latency != float("inf"):
                    latencies.append(latency)
            except requests.RequestException:
@@ -126,17 +126,17 @@ class GCPRegions:
        attempts: int = 1,
    ) -> List[Tuple[str, float, float, float, float]]:
        """
-
+        Determine the GCP regions with the lowest latency based on ping tests.
 
        Args:
-            top (int): Number of top regions to return.
-            verbose (bool): If True, prints detailed latency information for all tested regions.
-            tier (int | None): Filter regions by tier (1 or 2). If None, all regions are tested.
-            attempts (int): Number of ping attempts per region.
+            top (int, optional): Number of top regions to return.
+            verbose (bool, optional): If True, prints detailed latency information for all tested regions.
+            tier (int | None, optional): Filter regions by tier (1 or 2). If None, all regions are tested.
+            attempts (int, optional): Number of ping attempts per region.
 
        Returns:
            (List[Tuple[str, float, float, float, float]]): List of tuples containing region information and
-
+                latency statistics. Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency).
 
        Examples:
            >>> regions = GCPRegions()
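A minimal sketch tying together the GCPRegions docstrings updated above. The method whose Args/Returns are edited in the last hunk is assumed to be `lowest_latency` (its name is not visible in this excerpt), and pinging regions issues real network requests:

```python
# Sketch; lowest_latency as the method name is an assumption, and real network
# requests are made when pinging regions.
from ultralytics.hub.google import GCPRegions

regions = GCPRegions()
print(regions.tier1())  # tier-1 region identifiers
results = regions.lowest_latency(top=3, verbose=True, tier=1, attempts=2)
region, mean_latency, std_dev, min_latency, max_latency = results[0]
print(f"{region}: {mean_latency:.2f} ms")
```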
|