dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/engine/validator.py
CHANGED
|
@@ -29,19 +29,19 @@ from pathlib import Path
|
|
|
29
29
|
|
|
30
30
|
import numpy as np
|
|
31
31
|
import torch
|
|
32
|
+
import torch.distributed as dist
|
|
32
33
|
|
|
33
34
|
from ultralytics.cfg import get_cfg, get_save_dir
|
|
34
35
|
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
|
35
36
|
from ultralytics.nn.autobackend import AutoBackend
|
|
36
|
-
from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
|
|
37
|
+
from ultralytics.utils import LOGGER, RANK, TQDM, callbacks, colorstr, emojis
|
|
37
38
|
from ultralytics.utils.checks import check_imgsz
|
|
38
39
|
from ultralytics.utils.ops import Profile
|
|
39
|
-
from ultralytics.utils.torch_utils import
|
|
40
|
+
from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_inference_mode, unwrap_model
|
|
40
41
|
|
|
41
42
|
|
|
42
43
|
class BaseValidator:
|
|
43
|
-
"""
|
|
44
|
-
A base class for creating validators.
|
|
44
|
+
"""A base class for creating validators.
|
|
45
45
|
|
|
46
46
|
This class provides the foundation for validation processes, including model evaluation, metric computation, and
|
|
47
47
|
result visualization.
|
|
@@ -49,7 +49,6 @@ class BaseValidator:
|
|
|
49
49
|
Attributes:
|
|
50
50
|
args (SimpleNamespace): Configuration for the validator.
|
|
51
51
|
dataloader (DataLoader): Dataloader to use for validation.
|
|
52
|
-
pbar (tqdm): Progress bar to update during validation.
|
|
53
52
|
model (nn.Module): Model to validate.
|
|
54
53
|
data (dict): Data dictionary containing dataset information.
|
|
55
54
|
device (torch.device): Device to use for validation.
|
|
@@ -62,11 +61,13 @@ class BaseValidator:
|
|
|
62
61
|
nc (int): Number of classes.
|
|
63
62
|
iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
|
|
64
63
|
jdict (list): List to store JSON validation results.
|
|
65
|
-
speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
|
|
66
|
-
|
|
64
|
+
speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective batch
|
|
65
|
+
processing times in milliseconds.
|
|
67
66
|
save_dir (Path): Directory to save results.
|
|
68
67
|
plots (dict): Dictionary to store plots for visualization.
|
|
69
68
|
callbacks (dict): Dictionary to store various callback functions.
|
|
69
|
+
stride (int): Model stride for padding calculations.
|
|
70
|
+
loss (torch.Tensor): Accumulated loss during training validation.
|
|
70
71
|
|
|
71
72
|
Methods:
|
|
72
73
|
__call__: Execute validation process, running inference on dataloader and computing performance metrics.
|
|
@@ -81,30 +82,28 @@ class BaseValidator:
|
|
|
81
82
|
update_metrics: Update metrics based on predictions and batch.
|
|
82
83
|
finalize_metrics: Finalize and return all metrics.
|
|
83
84
|
get_stats: Return statistics about the model's performance.
|
|
84
|
-
check_stats: Check statistics.
|
|
85
85
|
print_results: Print the results of the model's predictions.
|
|
86
86
|
get_desc: Get description of the YOLO model.
|
|
87
|
-
on_plot: Register plots
|
|
87
|
+
on_plot: Register plots for visualization.
|
|
88
88
|
plot_val_samples: Plot validation samples during training.
|
|
89
89
|
plot_predictions: Plot YOLO model predictions on batch images.
|
|
90
90
|
pred_to_json: Convert predictions to JSON format.
|
|
91
91
|
eval_json: Evaluate and return JSON format of prediction statistics.
|
|
92
92
|
"""
|
|
93
93
|
|
|
94
|
-
def __init__(self, dataloader=None, save_dir=None,
|
|
95
|
-
"""
|
|
96
|
-
Initialize a BaseValidator instance.
|
|
94
|
+
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
|
|
95
|
+
"""Initialize a BaseValidator instance.
|
|
97
96
|
|
|
98
97
|
Args:
|
|
99
98
|
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
|
|
100
99
|
save_dir (Path, optional): Directory to save results.
|
|
101
|
-
pbar (tqdm.tqdm, optional): Progress bar for displaying progress.
|
|
102
100
|
args (SimpleNamespace, optional): Configuration for the validator.
|
|
103
101
|
_callbacks (dict, optional): Dictionary to store various callback functions.
|
|
104
102
|
"""
|
|
103
|
+
import torchvision # noqa (import here so torchvision import time not recorded in postprocess time)
|
|
104
|
+
|
|
105
105
|
self.args = get_cfg(overrides=args)
|
|
106
106
|
self.dataloader = dataloader
|
|
107
|
-
self.pbar = pbar
|
|
108
107
|
self.stride = None
|
|
109
108
|
self.data = None
|
|
110
109
|
self.device = None
|
|
@@ -122,7 +121,7 @@ class BaseValidator:
|
|
|
122
121
|
self.save_dir = save_dir or get_save_dir(self.args)
|
|
123
122
|
(self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
|
124
123
|
if self.args.conf is None:
|
|
125
|
-
self.args.conf = 0.001 #
|
|
124
|
+
self.args.conf = 0.01 if self.args.task == "obb" else 0.001 # reduce OBB val memory usage
|
|
126
125
|
self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)
|
|
127
126
|
|
|
128
127
|
self.plots = {}
|
|
@@ -130,15 +129,14 @@ class BaseValidator:
|
|
|
130
129
|
|
|
131
130
|
@smart_inference_mode()
|
|
132
131
|
def __call__(self, trainer=None, model=None):
|
|
133
|
-
"""
|
|
134
|
-
Execute validation process, running inference on dataloader and computing performance metrics.
|
|
132
|
+
"""Execute validation process, running inference on dataloader and computing performance metrics.
|
|
135
133
|
|
|
136
134
|
Args:
|
|
137
135
|
trainer (object, optional): Trainer object that contains the model to validate.
|
|
138
136
|
model (nn.Module, optional): Model to validate if not using a trainer.
|
|
139
137
|
|
|
140
138
|
Returns:
|
|
141
|
-
|
|
139
|
+
(dict): Dictionary containing validation statistics.
|
|
142
140
|
"""
|
|
143
141
|
self.training = trainer is not None
|
|
144
142
|
augment = self.args.augment and (not self.training)
|
|
@@ -148,8 +146,9 @@ class BaseValidator:
|
|
|
148
146
|
# Force FP16 val during training
|
|
149
147
|
self.args.half = self.device.type != "cpu" and trainer.amp
|
|
150
148
|
model = trainer.ema.ema or trainer.model
|
|
149
|
+
if trainer.args.compile and hasattr(model, "_orig_mod"):
|
|
150
|
+
model = model._orig_mod # validate non-compiled original model to avoid issues
|
|
151
151
|
model = model.half() if self.args.half else model.float()
|
|
152
|
-
# self.model = model
|
|
153
152
|
self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
|
|
154
153
|
self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
|
|
155
154
|
model.eval()
|
|
@@ -158,24 +157,21 @@ class BaseValidator:
|
|
|
158
157
|
LOGGER.warning("validating an untrained model YAML will result in 0 mAP.")
|
|
159
158
|
callbacks.add_integration_callbacks(self)
|
|
160
159
|
model = AutoBackend(
|
|
161
|
-
|
|
162
|
-
device=select_device(self.args.device,
|
|
160
|
+
model=model or self.args.model,
|
|
161
|
+
device=select_device(self.args.device) if RANK == -1 else torch.device("cuda", RANK),
|
|
163
162
|
dnn=self.args.dnn,
|
|
164
163
|
data=self.args.data,
|
|
165
164
|
fp16=self.args.half,
|
|
166
165
|
)
|
|
167
|
-
# self.model = model
|
|
168
166
|
self.device = model.device # update device
|
|
169
167
|
self.args.half = model.fp16 # update half
|
|
170
|
-
stride, pt, jit
|
|
168
|
+
stride, pt, jit = model.stride, model.pt, model.jit
|
|
171
169
|
imgsz = check_imgsz(self.args.imgsz, stride=stride)
|
|
172
|
-
if
|
|
173
|
-
self.args.batch = model.batch_size
|
|
174
|
-
elif not (pt or jit or getattr(model, "dynamic", False)):
|
|
170
|
+
if not (pt or jit or getattr(model, "dynamic", False)):
|
|
175
171
|
self.args.batch = model.metadata.get("batch", 1) # export.py models default to batch-size 1
|
|
176
172
|
LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})")
|
|
177
173
|
|
|
178
|
-
if str(self.args.data).
|
|
174
|
+
if str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"}:
|
|
179
175
|
self.data = check_det_dataset(self.args.data)
|
|
180
176
|
elif self.args.task == "classify":
|
|
181
177
|
self.data = check_cls_dataset(self.args.data, split=self.args.split)
|
|
@@ -184,12 +180,14 @@ class BaseValidator:
|
|
|
184
180
|
|
|
185
181
|
if self.device.type in {"cpu", "mps"}:
|
|
186
182
|
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
|
|
187
|
-
if not (pt or getattr(model, "dynamic", False)):
|
|
183
|
+
if not (pt or (getattr(model, "dynamic", False) and not model.imx)):
|
|
188
184
|
self.args.rect = False
|
|
189
185
|
self.stride = model.stride # used in get_dataloader() for padding
|
|
190
186
|
self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)
|
|
191
187
|
|
|
192
188
|
model.eval()
|
|
189
|
+
if self.args.compile:
|
|
190
|
+
model = attempt_compile(model, device=self.device)
|
|
193
191
|
model.warmup(imgsz=(1 if pt else self.args.batch, self.data["channels"], imgsz, imgsz)) # warmup
|
|
194
192
|
|
|
195
193
|
self.run_callbacks("on_val_start")
|
|
@@ -200,7 +198,7 @@ class BaseValidator:
|
|
|
200
198
|
Profile(device=self.device),
|
|
201
199
|
)
|
|
202
200
|
bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
|
|
203
|
-
self.init_metrics(
|
|
201
|
+
self.init_metrics(unwrap_model(model))
|
|
204
202
|
self.jdict = [] # empty before each val
|
|
205
203
|
for batch_i, batch in enumerate(bar):
|
|
206
204
|
self.run_callbacks("on_val_batch_start")
|
|
@@ -223,22 +221,34 @@ class BaseValidator:
|
|
|
223
221
|
preds = self.postprocess(preds)
|
|
224
222
|
|
|
225
223
|
self.update_metrics(preds, batch)
|
|
226
|
-
if self.args.plots and batch_i < 3:
|
|
224
|
+
if self.args.plots and batch_i < 3 and RANK in {-1, 0}:
|
|
227
225
|
self.plot_val_samples(batch, batch_i)
|
|
228
226
|
self.plot_predictions(batch, preds, batch_i)
|
|
229
227
|
|
|
230
228
|
self.run_callbacks("on_val_batch_end")
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
self.
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
229
|
+
|
|
230
|
+
stats = {}
|
|
231
|
+
self.gather_stats()
|
|
232
|
+
if RANK in {-1, 0}:
|
|
233
|
+
stats = self.get_stats()
|
|
234
|
+
self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
|
|
235
|
+
self.finalize_metrics()
|
|
236
|
+
self.print_results()
|
|
237
|
+
self.run_callbacks("on_val_end")
|
|
238
|
+
|
|
237
239
|
if self.training:
|
|
238
240
|
model.float()
|
|
239
|
-
|
|
241
|
+
# Reduce loss across all GPUs
|
|
242
|
+
loss = self.loss.clone().detach()
|
|
243
|
+
if trainer.world_size > 1:
|
|
244
|
+
dist.reduce(loss, dst=0, op=dist.ReduceOp.AVG)
|
|
245
|
+
if RANK > 0:
|
|
246
|
+
return
|
|
247
|
+
results = {**stats, **trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val")}
|
|
240
248
|
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
|
|
241
249
|
else:
|
|
250
|
+
if RANK > 0:
|
|
251
|
+
return stats
|
|
242
252
|
LOGGER.info(
|
|
243
253
|
"Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per image".format(
|
|
244
254
|
*tuple(self.speed.values())
|
|
@@ -256,14 +266,13 @@ class BaseValidator:
|
|
|
256
266
|
def match_predictions(
|
|
257
267
|
self, pred_classes: torch.Tensor, true_classes: torch.Tensor, iou: torch.Tensor, use_scipy: bool = False
|
|
258
268
|
) -> torch.Tensor:
|
|
259
|
-
"""
|
|
260
|
-
Match predictions to ground truth objects using IoU.
|
|
269
|
+
"""Match predictions to ground truth objects using IoU.
|
|
261
270
|
|
|
262
271
|
Args:
|
|
263
272
|
pred_classes (torch.Tensor): Predicted class indices of shape (N,).
|
|
264
273
|
true_classes (torch.Tensor): Target class indices of shape (M,).
|
|
265
274
|
iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground truth.
|
|
266
|
-
use_scipy (bool): Whether to use scipy for matching (more precise).
|
|
275
|
+
use_scipy (bool, optional): Whether to use scipy for matching (more precise).
|
|
267
276
|
|
|
268
277
|
Returns:
|
|
269
278
|
(torch.Tensor): Correct tensor of shape (N, 10) for 10 IoU thresholds.
|
|
@@ -292,7 +301,6 @@ class BaseValidator:
|
|
|
292
301
|
if matches.shape[0] > 1:
|
|
293
302
|
matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
|
|
294
303
|
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
|
295
|
-
# matches = matches[matches[:, 2].argsort()[::-1]]
|
|
296
304
|
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
|
297
305
|
correct[matches[:, 1].astype(int), i] = True
|
|
298
306
|
return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
|
|
@@ -330,7 +338,7 @@ class BaseValidator:
|
|
|
330
338
|
"""Update metrics based on predictions and batch."""
|
|
331
339
|
pass
|
|
332
340
|
|
|
333
|
-
def finalize_metrics(self
|
|
341
|
+
def finalize_metrics(self):
|
|
334
342
|
"""Finalize and return all metrics."""
|
|
335
343
|
pass
|
|
336
344
|
|
|
@@ -338,8 +346,8 @@ class BaseValidator:
|
|
|
338
346
|
"""Return statistics about the model's performance."""
|
|
339
347
|
return {}
|
|
340
348
|
|
|
341
|
-
def
|
|
342
|
-
"""
|
|
349
|
+
def gather_stats(self):
|
|
350
|
+
"""Gather statistics from all the GPUs during DDP training to GPU 0."""
|
|
343
351
|
pass
|
|
344
352
|
|
|
345
353
|
def print_results(self):
|
|
@@ -356,10 +364,9 @@ class BaseValidator:
|
|
|
356
364
|
return []
|
|
357
365
|
|
|
358
366
|
def on_plot(self, name, data=None):
|
|
359
|
-
"""Register plots
|
|
367
|
+
"""Register plots for visualization."""
|
|
360
368
|
self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
|
|
361
369
|
|
|
362
|
-
# TODO: may need to put these following functions into callback
|
|
363
370
|
def plot_val_samples(self, batch, ni):
|
|
364
371
|
"""Plot validation samples during training."""
|
|
365
372
|
pass
|
ultralytics/hub/__init__.py
CHANGED
|
@@ -1,31 +1,29 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
|
-
import
|
|
3
|
+
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from ultralytics.data.utils import HUBDatasetStats
|
|
6
6
|
from ultralytics.hub.auth import Auth
|
|
7
7
|
from ultralytics.hub.session import HUBTrainingSession
|
|
8
|
-
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
|
|
8
|
+
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
|
|
9
9
|
from ultralytics.utils import LOGGER, SETTINGS, checks
|
|
10
10
|
|
|
11
11
|
__all__ = (
|
|
12
|
-
"PREFIX",
|
|
13
12
|
"HUB_WEB_ROOT",
|
|
13
|
+
"PREFIX",
|
|
14
14
|
"HUBTrainingSession",
|
|
15
|
-
"
|
|
16
|
-
"logout",
|
|
17
|
-
"reset_model",
|
|
15
|
+
"check_dataset",
|
|
18
16
|
"export_fmts_hub",
|
|
19
17
|
"export_model",
|
|
20
18
|
"get_export",
|
|
21
|
-
"
|
|
22
|
-
"
|
|
19
|
+
"login",
|
|
20
|
+
"logout",
|
|
21
|
+
"reset_model",
|
|
23
22
|
)
|
|
24
23
|
|
|
25
24
|
|
|
26
|
-
def login(api_key: str = None, save: bool = True) -> bool:
|
|
27
|
-
"""
|
|
28
|
-
Log in to the Ultralytics HUB API using the provided API key.
|
|
25
|
+
def login(api_key: str | None = None, save: bool = True) -> bool:
|
|
26
|
+
"""Log in to the Ultralytics HUB API using the provided API key.
|
|
29
27
|
|
|
30
28
|
The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY
|
|
31
29
|
environment variable if successfully authenticated.
|
|
@@ -68,19 +66,15 @@ def login(api_key: str = None, save: bool = True) -> bool:
|
|
|
68
66
|
|
|
69
67
|
|
|
70
68
|
def logout():
|
|
71
|
-
"""
|
|
72
|
-
Log out of Ultralytics HUB by removing the API key from the settings file. To log in again, use 'yolo login'.
|
|
73
|
-
|
|
74
|
-
Examples:
|
|
75
|
-
>>> from ultralytics import hub
|
|
76
|
-
>>> hub.logout()
|
|
77
|
-
"""
|
|
69
|
+
"""Log out of Ultralytics HUB by removing the API key from the settings file."""
|
|
78
70
|
SETTINGS["api_key"] = ""
|
|
79
71
|
LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo login'.")
|
|
80
72
|
|
|
81
73
|
|
|
82
74
|
def reset_model(model_id: str = ""):
|
|
83
75
|
"""Reset a trained model to an untrained state."""
|
|
76
|
+
import requests # scoped as slow import
|
|
77
|
+
|
|
84
78
|
r = requests.post(f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key})
|
|
85
79
|
if r.status_code == 200:
|
|
86
80
|
LOGGER.info(f"{PREFIX}Model reset successfully")
|
|
@@ -89,15 +83,14 @@ def reset_model(model_id: str = ""):
|
|
|
89
83
|
|
|
90
84
|
|
|
91
85
|
def export_fmts_hub():
|
|
92
|
-
"""
|
|
86
|
+
"""Return a list of HUB-supported export formats."""
|
|
93
87
|
from ultralytics.engine.exporter import export_formats
|
|
94
88
|
|
|
95
|
-
return list(export_formats()["Argument"][1:])
|
|
89
|
+
return [*list(export_formats()["Argument"][1:]), "ultralytics_tflite", "ultralytics_coreml"]
|
|
96
90
|
|
|
97
91
|
|
|
98
92
|
def export_model(model_id: str = "", format: str = "torchscript"):
|
|
99
|
-
"""
|
|
100
|
-
Export a model to a specified format for deployment via the Ultralytics HUB API.
|
|
93
|
+
"""Export a model to a specified format for deployment via the Ultralytics HUB API.
|
|
101
94
|
|
|
102
95
|
Args:
|
|
103
96
|
model_id (str): The ID of the model to export. An empty string will use the default model.
|
|
@@ -111,6 +104,8 @@ def export_model(model_id: str = "", format: str = "torchscript"):
|
|
|
111
104
|
>>> from ultralytics import hub
|
|
112
105
|
>>> hub.export_model(model_id="your_model_id", format="torchscript")
|
|
113
106
|
"""
|
|
107
|
+
import requests # scoped as slow import
|
|
108
|
+
|
|
114
109
|
assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
|
|
115
110
|
r = requests.post(
|
|
116
111
|
f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key}
|
|
@@ -120,20 +115,24 @@ def export_model(model_id: str = "", format: str = "torchscript"):
|
|
|
120
115
|
|
|
121
116
|
|
|
122
117
|
def get_export(model_id: str = "", format: str = "torchscript"):
|
|
123
|
-
"""
|
|
124
|
-
Retrieve an exported model in the specified format from Ultralytics HUB using the model ID.
|
|
118
|
+
"""Retrieve an exported model in the specified format from Ultralytics HUB using the model ID.
|
|
125
119
|
|
|
126
120
|
Args:
|
|
127
121
|
model_id (str): The ID of the model to retrieve from Ultralytics HUB.
|
|
128
122
|
format (str): The export format to retrieve. Must be one of the supported formats returned by export_fmts_hub().
|
|
129
123
|
|
|
124
|
+
Returns:
|
|
125
|
+
(dict): JSON response containing the exported model information.
|
|
126
|
+
|
|
130
127
|
Raises:
|
|
131
128
|
AssertionError: If the specified format is not supported or if the API request fails.
|
|
132
129
|
|
|
133
130
|
Examples:
|
|
134
131
|
>>> from ultralytics import hub
|
|
135
|
-
>>> hub.get_export(model_id="your_model_id", format="torchscript")
|
|
132
|
+
>>> result = hub.get_export(model_id="your_model_id", format="torchscript")
|
|
136
133
|
"""
|
|
134
|
+
import requests # scoped as slow import
|
|
135
|
+
|
|
137
136
|
assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
|
|
138
137
|
r = requests.post(
|
|
139
138
|
f"{HUB_API_ROOT}/get-export",
|
|
@@ -145,8 +144,7 @@ def get_export(model_id: str = "", format: str = "torchscript"):
|
|
|
145
144
|
|
|
146
145
|
|
|
147
146
|
def check_dataset(path: str, task: str) -> None:
|
|
148
|
-
"""
|
|
149
|
-
Check HUB dataset Zip file for errors before upload.
|
|
147
|
+
"""Check HUB dataset Zip file for errors before upload.
|
|
150
148
|
|
|
151
149
|
Args:
|
|
152
150
|
path (str): Path to data.zip (with data.yaml inside data.zip).
|
|
@@ -160,7 +158,7 @@ def check_dataset(path: str, task: str) -> None:
|
|
|
160
158
|
>>> check_dataset("path/to/dota8.zip", task="obb") # OBB dataset
|
|
161
159
|
>>> check_dataset("path/to/imagenet10.zip", task="classify") # classification dataset
|
|
162
160
|
|
|
163
|
-
|
|
161
|
+
Notes:
|
|
164
162
|
Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
|
|
165
163
|
i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
|
|
166
164
|
"""
|
ultralytics/hub/auth.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
|
-
import requests
|
|
4
|
-
|
|
5
3
|
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials
|
|
6
4
|
from ultralytics.utils import IS_COLAB, LOGGER, SETTINGS, emojis
|
|
7
5
|
|
|
@@ -9,8 +7,7 @@ API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys"
|
|
|
9
7
|
|
|
10
8
|
|
|
11
9
|
class Auth:
|
|
12
|
-
"""
|
|
13
|
-
Manages authentication processes including API key handling, cookie-based authentication, and header generation.
|
|
10
|
+
"""Manages authentication processes including API key handling, cookie-based authentication, and header generation.
|
|
14
11
|
|
|
15
12
|
The class supports different methods of authentication:
|
|
16
13
|
1. Directly using an API key.
|
|
@@ -21,13 +18,25 @@ class Auth:
|
|
|
21
18
|
id_token (str | bool): Token used for identity verification, initialized as False.
|
|
22
19
|
api_key (str | bool): API key for authentication, initialized as False.
|
|
23
20
|
model_key (bool): Placeholder for model key, initialized as False.
|
|
21
|
+
|
|
22
|
+
Methods:
|
|
23
|
+
authenticate: Attempt to authenticate with the server using either id_token or API key.
|
|
24
|
+
auth_with_cookies: Attempt to fetch authentication via cookies and set id_token.
|
|
25
|
+
get_auth_header: Get the authentication header for making API requests.
|
|
26
|
+
request_api_key: Prompt the user to input their API key.
|
|
27
|
+
|
|
28
|
+
Examples:
|
|
29
|
+
Initialize Auth with an API key
|
|
30
|
+
>>> auth = Auth(api_key="your_api_key_here")
|
|
31
|
+
|
|
32
|
+
Initialize Auth without API key (will prompt for input)
|
|
33
|
+
>>> auth = Auth()
|
|
24
34
|
"""
|
|
25
35
|
|
|
26
36
|
id_token = api_key = model_key = False
|
|
27
37
|
|
|
28
38
|
def __init__(self, api_key: str = "", verbose: bool = False):
|
|
29
|
-
"""
|
|
30
|
-
Initialize Auth class and authenticate user.
|
|
39
|
+
"""Initialize Auth class and authenticate user.
|
|
31
40
|
|
|
32
41
|
Handles API key validation, Google Colab authentication, and new key requests. Updates SETTINGS upon successful
|
|
33
42
|
authentication.
|
|
@@ -37,7 +46,7 @@ class Auth:
|
|
|
37
46
|
verbose (bool): Enable verbose logging.
|
|
38
47
|
"""
|
|
39
48
|
# Split the input API key in case it contains a combined key_model and keep only the API key part
|
|
40
|
-
api_key = api_key.split("_")[0]
|
|
49
|
+
api_key = api_key.split("_", 1)[0]
|
|
41
50
|
|
|
42
51
|
# Set API key attribute as value passed or SETTINGS API key if none passed
|
|
43
52
|
self.api_key = api_key or SETTINGS.get("api_key", "")
|
|
@@ -71,24 +80,32 @@ class Auth:
|
|
|
71
80
|
LOGGER.info(f"{PREFIX}Get API key from {API_KEY_URL} and then run 'yolo login API_KEY'")
|
|
72
81
|
|
|
73
82
|
def request_api_key(self, max_attempts: int = 3) -> bool:
|
|
74
|
-
"""Prompt the user to input their API key.
|
|
83
|
+
"""Prompt the user to input their API key.
|
|
84
|
+
|
|
85
|
+
Args:
|
|
86
|
+
max_attempts (int): Maximum number of authentication attempts.
|
|
87
|
+
|
|
88
|
+
Returns:
|
|
89
|
+
(bool): True if authentication is successful, False otherwise.
|
|
90
|
+
"""
|
|
75
91
|
import getpass
|
|
76
92
|
|
|
77
93
|
for attempts in range(max_attempts):
|
|
78
94
|
LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}")
|
|
79
95
|
input_key = getpass.getpass(f"Enter API key from {API_KEY_URL} ")
|
|
80
|
-
self.api_key = input_key.split("_")[0] # remove model id if present
|
|
96
|
+
self.api_key = input_key.split("_", 1)[0] # remove model id if present
|
|
81
97
|
if self.authenticate():
|
|
82
98
|
return True
|
|
83
99
|
raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))
|
|
84
100
|
|
|
85
101
|
def authenticate(self) -> bool:
|
|
86
|
-
"""
|
|
87
|
-
Attempt to authenticate with the server using either id_token or API key.
|
|
102
|
+
"""Attempt to authenticate with the server using either id_token or API key.
|
|
88
103
|
|
|
89
104
|
Returns:
|
|
90
105
|
(bool): True if authentication is successful, False otherwise.
|
|
91
106
|
"""
|
|
107
|
+
import requests # scoped as slow import
|
|
108
|
+
|
|
92
109
|
try:
|
|
93
110
|
if header := self.get_auth_header():
|
|
94
111
|
r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
|
|
@@ -102,8 +119,7 @@ class Auth:
|
|
|
102
119
|
return False
|
|
103
120
|
|
|
104
121
|
def auth_with_cookies(self) -> bool:
|
|
105
|
-
"""
|
|
106
|
-
Attempt to fetch authentication via cookies and set id_token.
|
|
122
|
+
"""Attempt to fetch authentication via cookies and set id_token.
|
|
107
123
|
|
|
108
124
|
User must be logged in to HUB and running in a supported browser.
|
|
109
125
|
|
|
@@ -124,8 +140,7 @@ class Auth:
|
|
|
124
140
|
return False
|
|
125
141
|
|
|
126
142
|
def get_auth_header(self):
|
|
127
|
-
"""
|
|
128
|
-
Get the authentication header for making API requests.
|
|
143
|
+
"""Get the authentication header for making API requests.
|
|
129
144
|
|
|
130
145
|
Returns:
|
|
131
146
|
(dict | None): The authentication header if id_token or API key is set, None otherwise.
|
|
@@ -134,4 +149,3 @@ class Auth:
|
|
|
134
149
|
return {"authorization": f"Bearer {self.id_token}"}
|
|
135
150
|
elif self.api_key:
|
|
136
151
|
return {"x-api-key": self.api_key}
|
|
137
|
-
# else returns None
|
|
@@ -1,22 +1,20 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
3
5
|
import concurrent.futures
|
|
4
6
|
import statistics
|
|
5
7
|
import time
|
|
6
|
-
from typing import List, Optional, Tuple
|
|
7
|
-
|
|
8
|
-
import requests
|
|
9
8
|
|
|
10
9
|
|
|
11
10
|
class GCPRegions:
|
|
12
|
-
"""
|
|
13
|
-
A class for managing and analyzing Google Cloud Platform (GCP) regions.
|
|
11
|
+
"""A class for managing and analyzing Google Cloud Platform (GCP) regions.
|
|
14
12
|
|
|
15
|
-
This class provides functionality to initialize, categorize, and analyze GCP regions based on their
|
|
16
|
-
|
|
13
|
+
This class provides functionality to initialize, categorize, and analyze GCP regions based on their geographical
|
|
14
|
+
location, tier classification, and network latency.
|
|
17
15
|
|
|
18
16
|
Attributes:
|
|
19
|
-
regions (
|
|
17
|
+
regions (dict[str, tuple[int, str, str]]): A dictionary of GCP regions with their tier, city, and country.
|
|
20
18
|
|
|
21
19
|
Methods:
|
|
22
20
|
tier1: Returns a list of tier 1 GCP regions.
|
|
@@ -31,7 +29,7 @@ class GCPRegions:
|
|
|
31
29
|
"""
|
|
32
30
|
|
|
33
31
|
def __init__(self):
|
|
34
|
-
"""
|
|
32
|
+
"""Initialize the GCPRegions class with predefined Google Cloud Platform regions and their details."""
|
|
35
33
|
self.regions = {
|
|
36
34
|
"asia-east1": (1, "Taiwan", "China"),
|
|
37
35
|
"asia-east2": (2, "Hong Kong", "China"),
|
|
@@ -73,41 +71,42 @@ class GCPRegions:
|
|
|
73
71
|
"us-west4": (2, "Las Vegas", "United States"),
|
|
74
72
|
}
|
|
75
73
|
|
|
76
|
-
def tier1(self) ->
|
|
77
|
-
"""
|
|
74
|
+
def tier1(self) -> list[str]:
|
|
75
|
+
"""Return a list of GCP regions classified as tier 1 based on predefined criteria."""
|
|
78
76
|
return [region for region, info in self.regions.items() if info[0] == 1]
|
|
79
77
|
|
|
80
|
-
def tier2(self) ->
|
|
81
|
-
"""
|
|
78
|
+
def tier2(self) -> list[str]:
|
|
79
|
+
"""Return a list of GCP regions classified as tier 2 based on predefined criteria."""
|
|
82
80
|
return [region for region, info in self.regions.items() if info[0] == 2]
|
|
83
81
|
|
|
84
82
|
@staticmethod
|
|
85
|
-
def _ping_region(region: str, attempts: int = 1) ->
|
|
86
|
-
"""
|
|
87
|
-
Ping a specified GCP region and measure network latency statistics.
|
|
83
|
+
def _ping_region(region: str, attempts: int = 1) -> tuple[str, float, float, float, float]:
|
|
84
|
+
"""Ping a specified GCP region and measure network latency statistics.
|
|
88
85
|
|
|
89
86
|
Args:
|
|
90
|
-
|
|
91
|
-
|
|
87
|
+
region (str): The GCP region identifier to ping (e.g., 'us-central1').
|
|
88
|
+
attempts (int, optional): Number of ping attempts to make for calculating statistics.
|
|
92
89
|
|
|
93
90
|
Returns:
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
91
|
+
region (str): The GCP region identifier that was pinged.
|
|
92
|
+
mean_latency (float): Mean latency in milliseconds, or infinity if all pings failed.
|
|
93
|
+
std_dev (float): Standard deviation of latencies in milliseconds, or infinity if all pings failed.
|
|
94
|
+
min_latency (float): Minimum latency in milliseconds, or infinity if all pings failed.
|
|
95
|
+
max_latency (float): Maximum latency in milliseconds, or infinity if all pings failed.
|
|
99
96
|
|
|
100
97
|
Examples:
|
|
101
|
-
|
|
102
|
-
|
|
98
|
+
>>> region, mean, std, min_lat, max_lat = GCPRegions._ping_region("us-central1", attempts=3)
|
|
99
|
+
>>> print(f"Region {region} has mean latency: {mean:.2f}ms")
|
|
103
100
|
"""
|
|
101
|
+
import requests # scoped as slow import
|
|
102
|
+
|
|
104
103
|
url = f"https://{region}-docker.pkg.dev"
|
|
105
104
|
latencies = []
|
|
106
105
|
for _ in range(attempts):
|
|
107
106
|
try:
|
|
108
107
|
start_time = time.time()
|
|
109
108
|
_ = requests.head(url, timeout=5)
|
|
110
|
-
latency = (time.time() - start_time) * 1000 #
|
|
109
|
+
latency = (time.time() - start_time) * 1000 # Convert latency to milliseconds
|
|
111
110
|
if latency != float("inf"):
|
|
112
111
|
latencies.append(latency)
|
|
113
112
|
except requests.RequestException:
|
|
@@ -122,21 +121,20 @@ class GCPRegions:
|
|
|
122
121
|
self,
|
|
123
122
|
top: int = 1,
|
|
124
123
|
verbose: bool = False,
|
|
125
|
-
tier:
|
|
124
|
+
tier: int | None = None,
|
|
126
125
|
attempts: int = 1,
|
|
127
|
-
) ->
|
|
128
|
-
"""
|
|
129
|
-
Determines the GCP regions with the lowest latency based on ping tests.
|
|
126
|
+
) -> list[tuple[str, float, float, float, float]]:
|
|
127
|
+
"""Determine the GCP regions with the lowest latency based on ping tests.
|
|
130
128
|
|
|
131
129
|
Args:
|
|
132
|
-
top (int): Number of top regions to return.
|
|
133
|
-
verbose (bool): If True, prints detailed latency information for all tested regions.
|
|
134
|
-
tier (int | None): Filter regions by tier (1 or 2). If None, all regions are tested.
|
|
135
|
-
attempts (int): Number of ping attempts per region.
|
|
130
|
+
top (int, optional): Number of top regions to return.
|
|
131
|
+
verbose (bool, optional): If True, prints detailed latency information for all tested regions.
|
|
132
|
+
tier (int | None, optional): Filter regions by tier (1 or 2). If None, all regions are tested.
|
|
133
|
+
attempts (int, optional): Number of ping attempts per region.
|
|
136
134
|
|
|
137
135
|
Returns:
|
|
138
|
-
(
|
|
139
|
-
|
|
136
|
+
(list[tuple[str, float, float, float, float]]): List of tuples containing region information and latency
|
|
137
|
+
statistics. Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency).
|
|
140
138
|
|
|
141
139
|
Examples:
|
|
142
140
|
>>> regions = GCPRegions()
|