dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
- dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -9
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +13 -10
- tests/test_engine.py +9 -9
- tests/test_exports.py +65 -13
- tests/test_integrations.py +13 -13
- tests/test_python.py +125 -69
- tests/test_solutions.py +161 -152
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +86 -92
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +4 -2
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
- ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
- ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
- ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
- ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
- ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
- ultralytics/cfg/models/26/yolo26.yaml +52 -0
- ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
- ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +5 -6
- ultralytics/data/augment.py +300 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +108 -87
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +36 -45
- ultralytics/engine/exporter.py +351 -263
- ultralytics/engine/model.py +186 -225
- ultralytics/engine/predictor.py +45 -54
- ultralytics/engine/results.py +198 -325
- ultralytics/engine/trainer.py +165 -106
- ultralytics/engine/tuner.py +41 -43
- ultralytics/engine/validator.py +55 -38
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +18 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +10 -13
- ultralytics/models/yolo/classify/train.py +12 -33
- ultralytics/models/yolo/classify/val.py +30 -29
- ultralytics/models/yolo/detect/predict.py +9 -12
- ultralytics/models/yolo/detect/train.py +17 -23
- ultralytics/models/yolo/detect/val.py +77 -59
- ultralytics/models/yolo/model.py +43 -60
- ultralytics/models/yolo/obb/predict.py +7 -16
- ultralytics/models/yolo/obb/train.py +14 -17
- ultralytics/models/yolo/obb/val.py +40 -37
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +13 -16
- ultralytics/models/yolo/pose/val.py +39 -58
- ultralytics/models/yolo/segment/predict.py +17 -21
- ultralytics/models/yolo/segment/train.py +7 -10
- ultralytics/models/yolo/segment/val.py +95 -47
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +36 -44
- ultralytics/models/yolo/yoloe/train_seg.py +11 -11
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +159 -85
- ultralytics/nn/modules/__init__.py +68 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +260 -224
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +831 -299
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +180 -195
- ultralytics/nn/text_model.py +45 -69
- ultralytics/optim/__init__.py +5 -0
- ultralytics/optim/muon.py +338 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +13 -19
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +6 -7
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +8 -14
- ultralytics/solutions/instance_segmentation.py +6 -9
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +34 -32
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +77 -76
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +21 -37
- ultralytics/trackers/track.py +4 -7
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +124 -124
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +57 -71
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +423 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +25 -31
- ultralytics/utils/callbacks/wb.py +16 -14
- ultralytics/utils/checks.py +127 -85
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +9 -12
- ultralytics/utils/downloads.py +25 -33
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +246 -0
- ultralytics/utils/export/imx.py +117 -63
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +26 -30
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +601 -215
- ultralytics/utils/metrics.py +128 -156
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +117 -166
- ultralytics/utils/patches.py +75 -21
- ultralytics/utils/plotting.py +75 -80
- ultralytics/utils/tal.py +125 -59
- ultralytics/utils/torch_utils.py +53 -79
- ultralytics/utils/tqdm.py +24 -21
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +19 -10
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
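The headline additions in this list are the YOLO26 model configs (ultralytics/cfg/models/26/), the SAM3 modules (ultralytics/models/sam/sam3/), and a new optimizer package (ultralytics/optim/muon.py). Below is a minimal sketch of exercising the new model family, based on the yolo26n.pt / coco8.yaml usage that appears in the docstring diffs further down; the n-scale config name yolo26n.yaml follows the usual Ultralytics scale-suffix convention and is an assumption here, as is auto-download availability of the yolo26n.pt checkpoint.

from ultralytics import YOLO

# Build the nano variant from the new 8.4.x config (scale suffix assumed), or load pretrained weights.
model = YOLO("yolo26n.yaml")      # from ultralytics/cfg/models/26/yolo26.yaml at scale "n"
# model = YOLO("yolo26n.pt")      # pretrained checkpoint, as used in the tuner/validator docstrings below
model.train(data="coco8.yaml", epochs=10, imgsz=640)
metrics = model.val()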
ultralytics/engine/tuner.py
CHANGED

@@ -8,9 +8,9 @@ that yield the best model performance. This is particularly crucial in deep lear
  where small changes in hyperparameters can lead to significant differences in model accuracy and efficiency.

  Examples:
- Tune hyperparameters for
+ Tune hyperparameters for YOLO26n on COCO8 at imgsz=640 and epochs=10 for 300 tuning iterations.
  >>> from ultralytics import YOLO
- >>> model = YOLO("
+ >>> model = YOLO("yolo26n.pt")
  >>> model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False)
  """

@@ -34,12 +34,11 @@ from ultralytics.utils.plotting import plot_tune_results
  class Tuner:
- """
- A class for hyperparameter tuning of YOLO models.
+ """A class for hyperparameter tuning of YOLO models.

  The class evolves YOLO model hyperparameters over a given number of iterations by mutating them according to the
- search space and retraining the model to evaluate their performance. Supports both local CSV storage and
+ search space and retraining the model to evaluate their performance. Supports both local CSV storage and distributed
+ MongoDB Atlas coordination for multi-machine hyperparameter optimization.

  Attributes:
  space (dict[str, tuple]): Hyperparameter search space containing bounds and scaling factors for mutation.

@@ -56,9 +55,9 @@ class Tuner:
  __call__: Execute the hyperparameter evolution across multiple iterations.

  Examples:
- Tune hyperparameters for
+ Tune hyperparameters for YOLO26n on COCO8 at imgsz=640 and epochs=10 for 300 tuning iterations.
  >>> from ultralytics import YOLO
- >>> model = YOLO("
+ >>> model = YOLO("yolo26n.pt")
  >>> model.tune(
  >>>     data="coco8.yaml",
  >>>     epochs=10,

@@ -83,8 +82,7 @@ class Tuner:
  """

  def __init__(self, args=DEFAULT_CFG, _callbacks: list | None = None):
- """
- Initialize the Tuner with configurations.
+ """Initialize the Tuner with configurations.

  Args:
  args (dict): Configuration for hyperparameter evolution.

@@ -92,15 +90,15 @@ class Tuner:
  """
  self.space = args.pop("space", None) or {  # key: (min, max, gain(optional))
  # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
- "lr0": (1e-5, 1e-
- "lrf": (0.
+ "lr0": (1e-5, 1e-2),  # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
+ "lrf": (0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
  "momentum": (0.7, 0.98, 0.3),  # SGD momentum/Adam beta1
  "weight_decay": (0.0, 0.001),  # optimizer weight decay 5e-4
  "warmup_epochs": (0.0, 5.0),  # warmup epochs (fractions ok)
  "warmup_momentum": (0.0, 0.95),  # warmup initial momentum
  "box": (1.0, 20.0),  # box loss gain
  "cls": (0.1, 4.0),  # cls loss gain (scale with pixels)
- "dfl": (0.4,
+ "dfl": (0.4, 12.0),  # dfl loss gain
  "hsv_h": (0.0, 0.1),  # image HSV-Hue augmentation (fraction)
  "hsv_s": (0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
  "hsv_v": (0.0, 0.9),  # image HSV-Value augmentation (fraction)

@@ -142,8 +140,7 @@ class Tuner:
  )

  def _connect(self, uri: str = "mongodb+srv://username:password@cluster.mongodb.net/", max_retries: int = 3):
- """
- Create MongoDB client with exponential backoff retry on connection failures.
+ """Create MongoDB client with exponential backoff retry on connection failures.

  Args:
  uri (str): MongoDB connection string with credentials and cluster information.

@@ -183,12 +180,10 @@ class Tuner:
  time.sleep(wait_time)

  def _init_mongodb(self, mongodb_uri="", mongodb_db="", mongodb_collection=""):
- """
- Initialize MongoDB connection for distributed tuning.
+ """Initialize MongoDB connection for distributed tuning.

- Connects to MongoDB Atlas for distributed hyperparameter optimization across multiple machines.
- from all workers for evolution.
+ Connects to MongoDB Atlas for distributed hyperparameter optimization across multiple machines. Each worker
+ saves results to a shared collection and reads the latest best hyperparameters from all workers for evolution.

  Args:
  mongodb_uri (str): MongoDB connection string, e.g. 'mongodb+srv://username:password@cluster.mongodb.net/'.

@@ -206,8 +201,7 @@ class Tuner:
  LOGGER.info(f"{self.prefix}Using MongoDB Atlas for distributed tuning")

  def _get_mongodb_results(self, n: int = 5) -> list:
- """
- Get top N results from MongoDB sorted by fitness.
+ """Get top N results from MongoDB sorted by fitness.

  Args:
  n (int): Number of top results to retrieve.

@@ -221,8 +215,7 @@ class Tuner:
  return []

  def _save_to_mongodb(self, fitness: float, hyperparameters: dict[str, float], metrics: dict, iteration: int):
- """
- Save results to MongoDB with proper type conversion.
+ """Save results to MongoDB with proper type conversion.

  Args:
  fitness (float): Fitness score achieved with these hyperparameters.

@@ -233,7 +226,7 @@ class Tuner:
  try:
  self.collection.insert_one(
  {
- "fitness":
+ "fitness": fitness,
  "hyperparameters": {k: (v.item() if hasattr(v, "item") else v) for k, v in hyperparameters.items()},
  "metrics": metrics,
  "timestamp": datetime.now(),

@@ -244,8 +237,7 @@ class Tuner:
  LOGGER.warning(f"{self.prefix}MongoDB save failed: {e}")

  def _sync_mongodb_to_csv(self):
- """
- Sync MongoDB results to CSV for plotting compatibility.
+ """Sync MongoDB results to CSV for plotting compatibility.

  Downloads all results from MongoDB and writes them to the local CSV file in chronological order. This enables
  the existing plotting functions to work seamlessly with distributed MongoDB data.

@@ -257,19 +249,20 @@ class Tuner:
  return

  # Write to CSV
- headers = ",".join(["fitness"
+ headers = ",".join(["fitness", *list(self.space.keys())]) + "\n"
  with open(self.tune_csv, "w", encoding="utf-8") as f:
  f.write(headers)
  for result in all_results:
  fitness = result["fitness"]
- hyp_values = [result["hyperparameters"]
- log_row = [round(fitness, 5)
+ hyp_values = [result["hyperparameters"].get(k, self.args.get(k)) for k in self.space.keys()]
+ log_row = [round(fitness, 5), *hyp_values]
  f.write(",".join(map(str, log_row)) + "\n")

  except Exception as e:
  LOGGER.warning(f"{self.prefix}MongoDB to CSV sync failed: {e}")

+ @staticmethod
+ def _crossover(x: np.ndarray, alpha: float = 0.2, k: int = 9) -> np.ndarray:
  """BLX-α crossover from up to top-k parents (x[:,0]=fitness, rest=genes)."""
  k = min(k, len(x))
  # fitness weights (shifted to >0); fallback to uniform if degenerate

@@ -280,6 +273,8 @@ class Tuner:
  parents_mat = np.stack([x[i][1:] for i in idxs], 0)  # (k, ng) strip fitness
  lo, hi = parents_mat.min(0), parents_mat.max(0)
  span = hi - lo
+ # given a small value when span is zero to avoid no mutation
+ span = np.where(span == 0, np.random.uniform(0.01, 0.1, span.shape), span)
  return np.random.uniform(lo - alpha * span, hi + alpha * span)

  def _mutate(

@@ -288,11 +283,9 @@ class Tuner:
  mutation: float = 0.5,
  sigma: float = 0.2,
  ) -> dict[str, float]:
- """
- Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
+ """Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.

  Args:
- parent (str): Parent selection method (kept for API compatibility, unused in BLX mode).
  n (int): Number of top parents to consider.
  mutation (float): Probability of a parameter mutation in any given iteration.
  sigma (float): Standard deviation for Gaussian random number generator.

@@ -304,10 +297,14 @@ class Tuner:
  # Try MongoDB first if available
  if self.mongodb:
- results
- if results:
+ if results := self._get_mongodb_results(n):
  # MongoDB already sorted by fitness DESC, so results[0] is best
- x = np.array(
+ x = np.array(
+     [
+         [r["fitness"]] + [r["hyperparameters"].get(k, self.args.get(k)) for k in self.space.keys()]
+         for r in results
+     ]
+ )
  elif self.collection.name in self.collection.database.list_collection_names():  # Tuner started elsewhere
  x = np.array([[0.0] + [getattr(self.args, k) for k in self.space.keys()]])

@@ -344,13 +341,14 @@ class Tuner:
  # Update types
  if "close_mosaic" in hyp:
- hyp["close_mosaic"] =
+ hyp["close_mosaic"] = round(hyp["close_mosaic"])
+ if "epochs" in hyp:
+     hyp["epochs"] = round(hyp["epochs"])

  return hyp

- def __call__(self,
- """
- Execute the hyperparameter evolution process when the Tuner instance is called.
+ def __call__(self, iterations: int = 10, cleanup: bool = True):
+ """Execute the hyperparameter evolution process when the Tuner instance is called.

  This method iterates through the specified number of iterations, performing the following steps:
  1. Sync MongoDB results to CSV (if using distributed mode)

@@ -360,7 +358,6 @@ class Tuner:
  5. Track the best performing configuration across all iterations

  Args:
- model (Model | None, optional): A pre-initialized YOLO model to be used for training.
  iterations (int): The number of generations to run the evolution for.
  cleanup (bool): Whether to delete iteration weights to reduce storage space during tuning.
  """

@@ -389,6 +386,7 @@ class Tuner:
  metrics = {}
  train_args = {**vars(self.args), **mutated_hyp}
  save_dir = get_save_dir(get_cfg(train_args))
+ train_args["save_dir"] = str(save_dir)  # pass save_dir to subprocess to ensure same path is used
  weights_dir = save_dir / "weights"
  try:
  # Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang)

@@ -421,7 +419,7 @@ class Tuner:
  else:
  # Save to CSV only if no MongoDB
  log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
- headers = "" if self.tune_csv.exists() else (",".join(["fitness"
+ headers = "" if self.tune_csv.exists() else (",".join(["fitness", *list(self.space.keys())]) + "\n")
  with open(self.tune_csv, "a", encoding="utf-8") as f:
  f.write(headers + ",".join(map(str, log_row)) + "\n")
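The _crossover hunks above add a BLX-α recombination step with a floor on zero-width gene spans, so identical parents can still produce a mutated child. Below is a standalone sketch of that idea for reference only; blx_alpha_crossover, the fitness-weighting details, and the toy history array are illustrative and not part of the package, since the parent-selection code is only partially visible in this diff.

import numpy as np


def blx_alpha_crossover(x: np.ndarray, alpha: float = 0.2, k: int = 9) -> np.ndarray:
    """BLX-alpha crossover sketch: x[:, 0] holds fitness, remaining columns hold genes."""
    k = min(k, len(x))
    # Fitness-proportional parent sampling, shifted to be strictly positive.
    w = x[:, 0] - x[:, 0].min() + 1e-6
    idxs = np.random.choice(len(x), size=k, replace=False, p=w / w.sum())
    parents = x[idxs, 1:]  # drop the fitness column, keep genes
    lo, hi = parents.min(0), parents.max(0)
    span = hi - lo
    # Floor zero-width spans with a small random value so identical parents can still vary.
    span = np.where(span == 0, np.random.uniform(0.01, 0.1, span.shape), span)
    # Sample each child gene uniformly from the alpha-expanded parent range.
    return np.random.uniform(lo - alpha * span, hi + alpha * span)


# Example: rows of [fitness, lr0, momentum] from previous tuning iterations.
history = np.array([
    [0.52, 0.010, 0.90],
    [0.48, 0.008, 0.92],
    [0.55, 0.012, 0.90],
    [0.50, 0.010, 0.95],
])
child = blx_alpha_crossover(history, alpha=0.2, k=3)  # candidate [lr0, momentum] for the next trial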
ultralytics/engine/validator.py
CHANGED

@@ -3,24 +3,24 @@
  Check a model's accuracy on a test or val split of a dataset.

  Usage:
- $ yolo mode=val model=
+ $ yolo mode=val model=yolo26n.pt data=coco8.yaml imgsz=640

  Usage - formats:
- $ yolo mode=val model=
+ $ yolo mode=val model=yolo26n.pt # PyTorch
+ yolo26n.torchscript # TorchScript
+ yolo26n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
+ yolo26n_openvino_model # OpenVINO
+ yolo26n.engine # TensorRT
+ yolo26n.mlpackage # CoreML (macOS-only)
+ yolo26n_saved_model # TensorFlow SavedModel
+ yolo26n.pb # TensorFlow GraphDef
+ yolo26n.tflite # TensorFlow Lite
+ yolo26n_edgetpu.tflite # TensorFlow Edge TPU
+ yolo26n_paddle_model # PaddlePaddle
+ yolo26n.mnn # MNN
+ yolo26n_ncnn_model # NCNN
+ yolo26n_imx_model # Sony IMX
+ yolo26n_rknn_model # Rockchip RKNN
  """

  import json

@@ -29,26 +29,26 @@ from pathlib import Path
  import numpy as np
  import torch
+ import torch.distributed as dist

  from ultralytics.cfg import get_cfg, get_save_dir
  from ultralytics.data.utils import check_cls_dataset, check_det_dataset
  from ultralytics.nn.autobackend import AutoBackend
- from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
+ from ultralytics.utils import LOGGER, RANK, TQDM, callbacks, colorstr, emojis
  from ultralytics.utils.checks import check_imgsz
  from ultralytics.utils.ops import Profile
  from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_inference_mode, unwrap_model


  class BaseValidator:
- """
- A base class for creating validators.
+ """A base class for creating validators.

  This class provides the foundation for validation processes, including model evaluation, metric computation, and
  result visualization.

  Attributes:
  args (SimpleNamespace): Configuration for the validator.
- dataloader (DataLoader):
+ dataloader (DataLoader): DataLoader to use for validation.
  model (nn.Module): Model to validate.
  data (dict): Data dictionary containing dataset information.
  device (torch.device): Device to use for validation.

@@ -61,8 +61,8 @@ class BaseValidator:
  nc (int): Number of classes.
  iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
  jdict (list): List to store JSON validation results.
- speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
+ speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective batch
+ processing times in milliseconds.
  save_dir (Path): Directory to save results.
  plots (dict): Dictionary to store plots for visualization.
  callbacks (dict): Dictionary to store various callback functions.

@@ -92,11 +92,10 @@ class BaseValidator:
  """

  def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
- """
- Initialize a BaseValidator instance.
+ """Initialize a BaseValidator instance.

  Args:
- dataloader (torch.utils.data.DataLoader, optional):
+ dataloader (torch.utils.data.DataLoader, optional): DataLoader to be used for validation.
  save_dir (Path, optional): Directory to save results.
  args (SimpleNamespace, optional): Configuration for the validator.
  _callbacks (dict, optional): Dictionary to store various callback functions.

@@ -130,8 +129,7 @@ class BaseValidator:
  @smart_inference_mode()
  def __call__(self, trainer=None, model=None):
- """
- Execute validation process, running inference on dataloader and computing performance metrics.
+ """Execute validation process, running inference on dataloader and computing performance metrics.

  Args:
  trainer (object, optional): Trainer object that contains the model to validate.

@@ -160,7 +158,7 @@ class BaseValidator:
  callbacks.add_integration_callbacks(self)
  model = AutoBackend(
  model=model or self.args.model,
- device=select_device(self.args.device),
+ device=select_device(self.args.device) if RANK == -1 else torch.device("cuda", RANK),
  dnn=self.args.dnn,
  data=self.args.data,
  fp16=self.args.half,

@@ -223,21 +221,34 @@ class BaseValidator:
  preds = self.postprocess(preds)

  self.update_metrics(preds, batch)
- if self.args.plots and batch_i < 3:
+ if self.args.plots and batch_i < 3 and RANK in {-1, 0}:
  self.plot_val_samples(batch, batch_i)
  self.plot_predictions(batch, preds, batch_i)

  self.run_callbacks("on_val_batch_end")
- self.
+ stats = {}
+ self.gather_stats()
+ if RANK in {-1, 0}:
+     stats = self.get_stats()
+     self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
+     self.finalize_metrics()
+     self.print_results()
+     self.run_callbacks("on_val_end")

  if self.training:
  model.float()
+ # Reduce loss across all GPUs
+ loss = self.loss.clone().detach()
+ if trainer.world_size > 1:
+     dist.reduce(loss, dst=0, op=dist.ReduceOp.AVG)
+ if RANK > 0:
+     return
+ results = {**stats, **trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val")}
  return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
  else:
+ if RANK > 0:
+     return stats
  LOGGER.info(
  "Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per image".format(
  *tuple(self.speed.values())

@@ -255,8 +266,7 @@ class BaseValidator:
  def match_predictions(
  self, pred_classes: torch.Tensor, true_classes: torch.Tensor, iou: torch.Tensor, use_scipy: bool = False
  ) -> torch.Tensor:
- """
- Match predictions to ground truth objects using IoU.
+ """Match predictions to ground truth objects using IoU.

  Args:
  pred_classes (torch.Tensor): Predicted class indices of shape (N,).

@@ -336,6 +346,10 @@ class BaseValidator:
  """Return statistics about the model's performance."""
  return {}

+ def gather_stats(self):
+     """Gather statistics from all the GPUs during DDP training to GPU 0."""
+     pass

  def print_results(self):
  """Print the results of the model's predictions."""
  pass

@@ -350,7 +364,10 @@ class BaseValidator:
  return []

  def on_plot(self, name, data=None):
- """Register plots for visualization."""
+ """Register plots for visualization, deduplicating by type."""
+ plot_type = data.get("type") if data else None
+ if plot_type and any((v.get("data") or {}).get("type") == plot_type for v in self.plots.values()):
+     return  # Skip duplicate plot types
  self.plots[Path(name)] = {"data": data, "timestamp": time.time()}

  def plot_val_samples(self, batch, ni):
ultralytics/hub/__init__.py
CHANGED

@@ -1,5 +1,7 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+ from __future__ import annotations

  from ultralytics.data.utils import HUBDatasetStats
  from ultralytics.hub.auth import Auth
  from ultralytics.hub.session import HUBTrainingSession

@@ -7,29 +9,28 @@ from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
  from ultralytics.utils import LOGGER, SETTINGS, checks

  __all__ = (
- "PREFIX",
  "HUB_WEB_ROOT",
+ "PREFIX",
  "HUBTrainingSession",
- "
- "logout",
- "reset_model",
+ "check_dataset",
  "export_fmts_hub",
  "export_model",
  "get_export",
- "
+ "login",
+ "logout",
+ "reset_model",
  )


- def login(api_key: str = None, save: bool = True) -> bool:
- """
- Log in to the Ultralytics HUB API using the provided API key.
+ def login(api_key: str | None = None, save: bool = True) -> bool:
+ """Log in to the Ultralytics HUB API using the provided API key.

  The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY
  environment variable if successfully authenticated.

  Args:
- api_key (str, optional): API key to use for authentication. If not provided, it will be retrieved from
+ api_key (str, optional): API key to use for authentication. If not provided, it will be retrieved from SETTINGS
+ or HUB_API_KEY environment variable.
  save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.

  Returns:

@@ -85,12 +86,11 @@ def export_fmts_hub():
  """Return a list of HUB-supported export formats."""
  from ultralytics.engine.exporter import export_formats

- return list(export_formats()["Argument"][1:])
+ return [*list(export_formats()["Argument"][1:]), "ultralytics_tflite", "ultralytics_coreml"]


  def export_model(model_id: str = "", format: str = "torchscript"):
- """
- Export a model to a specified format for deployment via the Ultralytics HUB API.
+ """Export a model to a specified format for deployment via the Ultralytics HUB API.

  Args:
  model_id (str): The ID of the model to export. An empty string will use the default model.

@@ -115,13 +115,11 @@ def export_model(model_id: str = "", format: str = "torchscript"):
  def get_export(model_id: str = "", format: str = "torchscript"):
- """
- Retrieve an exported model in the specified format from Ultralytics HUB using the model ID.
+ """Retrieve an exported model in the specified format from Ultralytics HUB using the model ID.

  Args:
  model_id (str): The ID of the model to retrieve from Ultralytics HUB.
- format (str): The export format to retrieve. Must be one of the supported formats returned by
- export_fmts_hub().
+ format (str): The export format to retrieve. Must be one of the supported formats returned by export_fmts_hub().

  Returns:
  (dict): JSON response containing the exported model information.

@@ -146,8 +144,7 @@ def get_export(model_id: str = "", format: str = "torchscript"):
  def check_dataset(path: str, task: str) -> None:
- """
- Check HUB dataset Zip file for errors before upload.
+ """Check HUB dataset Zip file for errors before upload.

  Args:
  path (str): Path to data.zip (with data.yaml inside data.zip).
ultralytics/hub/auth.py
CHANGED

@@ -7,8 +7,7 @@ API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys"
  class Auth:
- """
- Manages authentication processes including API key handling, cookie-based authentication, and header generation.
+ """Manages authentication processes including API key handling, cookie-based authentication, and header generation.

  The class supports different methods of authentication:
  1. Directly using an API key.

@@ -37,8 +36,7 @@ class Auth:
  id_token = api_key = model_key = False

  def __init__(self, api_key: str = "", verbose: bool = False):
- """
- Initialize Auth class and authenticate user.
+ """Initialize Auth class and authenticate user.

  Handles API key validation, Google Colab authentication, and new key requests. Updates SETTINGS upon successful
  authentication.

@@ -82,8 +80,7 @@ class Auth:
  LOGGER.info(f"{PREFIX}Get API key from {API_KEY_URL} and then run 'yolo login API_KEY'")

  def request_api_key(self, max_attempts: int = 3) -> bool:
- """
- Prompt the user to input their API key.
+ """Prompt the user to input their API key.

  Args:
  max_attempts (int): Maximum number of authentication attempts.

@@ -102,8 +99,7 @@ class Auth:
  raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))

  def authenticate(self) -> bool:
- """
- Attempt to authenticate with the server using either id_token or API key.
+ """Attempt to authenticate with the server using either id_token or API key.

  Returns:
  (bool): True if authentication is successful, False otherwise.

@@ -123,8 +119,7 @@ class Auth:
  return False

  def auth_with_cookies(self) -> bool:
- """
- Attempt to fetch authentication via cookies and set id_token.
+ """Attempt to fetch authentication via cookies and set id_token.

  User must be logged in to HUB and running in a supported browser.

@@ -145,8 +140,7 @@ class Auth:
  return False

  def get_auth_header(self):
- """
- Get the authentication header for making API requests.
+ """Get the authentication header for making API requests.

  Returns:
  (dict | None): The authentication header if id_token or API key is set, None otherwise.

ultralytics/hub/google/__init__.py
CHANGED

@@ -8,11 +8,10 @@ import time
  class GCPRegions:
- """
- A class for managing and analyzing Google Cloud Platform (GCP) regions.
+ """A class for managing and analyzing Google Cloud Platform (GCP) regions.

- This class provides functionality to initialize, categorize, and analyze GCP regions based on their
+ This class provides functionality to initialize, categorize, and analyze GCP regions based on their geographical
+ location, tier classification, and network latency.

  Attributes:
  regions (dict[str, tuple[int, str, str]]): A dictionary of GCP regions with their tier, city, and country.

@@ -82,8 +81,7 @@ class GCPRegions:
  @staticmethod
  def _ping_region(region: str, attempts: int = 1) -> tuple[str, float, float, float, float]:
- """
- Ping a specified GCP region and measure network latency statistics.
+ """Ping a specified GCP region and measure network latency statistics.

  Args:
  region (str): The GCP region identifier to ping (e.g., 'us-central1').

@@ -126,8 +124,7 @@ class GCPRegions:
  tier: int | None = None,
  attempts: int = 1,
  ) -> list[tuple[str, float, float, float, float]]:
- """
- Determine the GCP regions with the lowest latency based on ping tests.
+ """Determine the GCP regions with the lowest latency based on ping tests.

  Args:
  top (int, optional): Number of top regions to return.

@@ -136,8 +133,8 @@ class GCPRegions:
  attempts (int, optional): Number of ping attempts per region.

  Returns:
- (list[tuple[str, float, float, float, float]]): List of tuples containing region information and
+ (list[tuple[str, float, float, float, float]]): List of tuples containing region information and latency
+ statistics. Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency).

  Examples:
  >>> regions = GCPRegions()
|