ultralytics 8.1.29__py3-none-any.whl → 8.3.62__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +122 -0
- tests/test_cuda.py +155 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +216 -0
- tests/test_integrations.py +150 -0
- tests/test_python.py +615 -0
- tests/test_solutions.py +94 -0
- ultralytics/__init__.py +11 -8
- ultralytics/cfg/__init__.py +569 -131
- ultralytics/cfg/datasets/Argoverse.yaml +2 -1
- ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
- ultralytics/cfg/datasets/ImageNet.yaml +2 -1
- ultralytics/cfg/datasets/Objects365.yaml +5 -4
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
- ultralytics/cfg/datasets/VOC.yaml +3 -2
- ultralytics/cfg/datasets/VisDrone.yaml +6 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco-pose.yaml +7 -6
- ultralytics/cfg/datasets/coco.yaml +3 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
- ultralytics/cfg/datasets/coco128.yaml +4 -3
- ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco8.yaml +3 -2
- ultralytics/cfg/datasets/crack-seg.yaml +3 -2
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8.yaml +3 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1236 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
- ultralytics/cfg/datasets/package-seg.yaml +5 -4
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
- ultralytics/cfg/datasets/xView.yaml +2 -1
- ultralytics/cfg/default.yaml +14 -11
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5.yaml +5 -2
- ultralytics/cfg/models/v6/yolov6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8.yaml +5 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/solutions/default.yaml +24 -0
- ultralytics/cfg/trackers/botsort.yaml +8 -5
- ultralytics/cfg/trackers/bytetrack.yaml +8 -5
- ultralytics/data/__init__.py +14 -3
- ultralytics/data/annotator.py +37 -15
- ultralytics/data/augment.py +1783 -289
- ultralytics/data/base.py +62 -27
- ultralytics/data/build.py +36 -8
- ultralytics/data/converter.py +196 -36
- ultralytics/data/dataset.py +233 -94
- ultralytics/data/loaders.py +199 -96
- ultralytics/data/split_dota.py +39 -29
- ultralytics/data/utils.py +110 -40
- ultralytics/engine/__init__.py +1 -1
- ultralytics/engine/exporter.py +569 -242
- ultralytics/engine/model.py +604 -252
- ultralytics/engine/predictor.py +22 -11
- ultralytics/engine/results.py +1228 -218
- ultralytics/engine/trainer.py +190 -129
- ultralytics/engine/tuner.py +18 -18
- ultralytics/engine/validator.py +18 -15
- ultralytics/hub/__init__.py +31 -13
- ultralytics/hub/auth.py +11 -7
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/hub/session.py +128 -94
- ultralytics/hub/utils.py +20 -21
- ultralytics/models/__init__.py +4 -2
- ultralytics/models/fastsam/__init__.py +2 -3
- ultralytics/models/fastsam/model.py +26 -4
- ultralytics/models/fastsam/predict.py +127 -63
- ultralytics/models/fastsam/utils.py +1 -44
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +21 -10
- ultralytics/models/nas/predict.py +3 -6
- ultralytics/models/nas/val.py +4 -4
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +6 -8
- ultralytics/models/rtdetr/train.py +6 -2
- ultralytics/models/rtdetr/val.py +3 -3
- ultralytics/models/sam/__init__.py +3 -3
- ultralytics/models/sam/amg.py +29 -23
- ultralytics/models/sam/build.py +211 -13
- ultralytics/models/sam/model.py +91 -30
- ultralytics/models/sam/modules/__init__.py +1 -1
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +381 -53
- ultralytics/models/sam/modules/encoders.py +515 -324
- ultralytics/models/sam/modules/memory_attention.py +237 -0
- ultralytics/models/sam/modules/sam.py +969 -21
- ultralytics/models/sam/modules/tiny_encoder.py +425 -154
- ultralytics/models/sam/modules/transformer.py +159 -60
- ultralytics/models/sam/modules/utils.py +293 -0
- ultralytics/models/sam/predict.py +1263 -132
- ultralytics/models/utils/__init__.py +1 -1
- ultralytics/models/utils/loss.py +36 -24
- ultralytics/models/utils/ops.py +3 -7
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +7 -8
- ultralytics/models/yolo/classify/train.py +17 -22
- ultralytics/models/yolo/classify/val.py +8 -4
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +3 -5
- ultralytics/models/yolo/detect/train.py +11 -4
- ultralytics/models/yolo/detect/val.py +90 -52
- ultralytics/models/yolo/model.py +14 -9
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +41 -23
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +3 -5
- ultralytics/models/yolo/pose/train.py +2 -2
- ultralytics/models/yolo/pose/val.py +51 -17
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -5
- ultralytics/models/yolo/segment/train.py +2 -2
- ultralytics/models/yolo/segment/val.py +60 -19
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +109 -0
- ultralytics/nn/__init__.py +1 -1
- ultralytics/nn/autobackend.py +228 -93
- ultralytics/nn/modules/__init__.py +39 -14
- ultralytics/nn/modules/activation.py +21 -0
- ultralytics/nn/modules/block.py +526 -66
- ultralytics/nn/modules/conv.py +24 -7
- ultralytics/nn/modules/head.py +177 -34
- ultralytics/nn/modules/transformer.py +6 -5
- ultralytics/nn/modules/utils.py +1 -2
- ultralytics/nn/tasks.py +225 -77
- ultralytics/solutions/__init__.py +30 -1
- ultralytics/solutions/ai_gym.py +96 -143
- ultralytics/solutions/analytics.py +247 -0
- ultralytics/solutions/distance_calculation.py +78 -135
- ultralytics/solutions/heatmap.py +93 -247
- ultralytics/solutions/object_counter.py +184 -259
- ultralytics/solutions/parking_management.py +246 -0
- ultralytics/solutions/queue_management.py +112 -0
- ultralytics/solutions/region_counter.py +116 -0
- ultralytics/solutions/security_alarm.py +144 -0
- ultralytics/solutions/solutions.py +178 -0
- ultralytics/solutions/speed_estimation.py +86 -174
- ultralytics/solutions/streamlit_inference.py +190 -0
- ultralytics/solutions/trackzone.py +68 -0
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +32 -13
- ultralytics/trackers/bot_sort.py +61 -28
- ultralytics/trackers/byte_tracker.py +83 -51
- ultralytics/trackers/track.py +21 -6
- ultralytics/trackers/utils/__init__.py +1 -1
- ultralytics/trackers/utils/gmc.py +62 -48
- ultralytics/trackers/utils/kalman_filter.py +166 -35
- ultralytics/trackers/utils/matching.py +40 -21
- ultralytics/utils/__init__.py +511 -239
- ultralytics/utils/autobatch.py +40 -22
- ultralytics/utils/benchmarks.py +266 -85
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/clearml.py +7 -6
- ultralytics/utils/callbacks/comet.py +39 -17
- ultralytics/utils/callbacks/dvc.py +1 -1
- ultralytics/utils/callbacks/hub.py +16 -16
- ultralytics/utils/callbacks/mlflow.py +28 -24
- ultralytics/utils/callbacks/neptune.py +6 -2
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +18 -18
- ultralytics/utils/callbacks/wb.py +27 -20
- ultralytics/utils/checks.py +160 -100
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +40 -34
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +72 -38
- ultralytics/utils/instance.py +41 -19
- ultralytics/utils/loss.py +83 -55
- ultralytics/utils/metrics.py +61 -56
- ultralytics/utils/ops.py +94 -89
- ultralytics/utils/patches.py +30 -14
- ultralytics/utils/plotting.py +600 -269
- ultralytics/utils/tal.py +67 -26
- ultralytics/utils/torch_utils.py +302 -102
- ultralytics/utils/triton.py +2 -1
- ultralytics/utils/tuner.py +21 -12
- ultralytics-8.3.62.dist-info/METADATA +370 -0
- ultralytics-8.3.62.dist-info/RECORD +241 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -472
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -268
- ultralytics/data/explorer/utils.py +0 -166
- ultralytics/models/fastsam/prompt.py +0 -357
- ultralytics-8.1.29.dist-info/METADATA +0 -373
- ultralytics-8.1.29.dist-info/RECORD +0 -197
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
ultralytics/hub/session.py
CHANGED
@@ -1,17 +1,19 @@
|
|
1
|
-
# Ultralytics
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
|
+
import shutil
|
3
4
|
import threading
|
4
5
|
import time
|
5
6
|
from http import HTTPStatus
|
6
7
|
from pathlib import Path
|
8
|
+
from urllib.parse import parse_qs, urlparse
|
7
9
|
|
8
10
|
import requests
|
9
11
|
|
10
|
-
from ultralytics.hub.utils import
|
11
|
-
from ultralytics.utils import LOGGER, SETTINGS, __version__, checks, emojis
|
12
|
+
from ultralytics.hub.utils import HELP_MSG, HUB_WEB_ROOT, PREFIX, TQDM
|
13
|
+
from ultralytics.utils import IS_COLAB, LOGGER, SETTINGS, __version__, checks, emojis
|
12
14
|
from ultralytics.utils.errors import HUBModelError
|
13
15
|
|
14
|
-
AGENT_NAME = f"python-{__version__}-colab" if
|
16
|
+
AGENT_NAME = f"python-{__version__}-colab" if IS_COLAB else f"python-{__version__}-local"
|
15
17
|
|
16
18
|
|
17
19
|
class HUBTrainingSession:
|
@@ -19,16 +21,12 @@ class HUBTrainingSession:
|
|
19
21
|
HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
|
20
22
|
|
21
23
|
Attributes:
|
22
|
-
agent_id (str): Identifier for the instance communicating with the server.
|
23
24
|
model_id (str): Identifier for the YOLO model being trained.
|
24
25
|
model_url (str): URL for the model in Ultralytics HUB.
|
25
|
-
api_url (str): API URL for the model in Ultralytics HUB.
|
26
|
-
auth_header (dict): Authentication header for the Ultralytics HUB API requests.
|
27
26
|
rate_limits (dict): Rate limits for different API calls (in seconds).
|
28
27
|
timers (dict): Timers for rate limiting.
|
29
28
|
metrics_queue (dict): Queue for the model's metrics.
|
30
29
|
model (dict): Model data fetched from Ultralytics HUB.
|
31
|
-
alive (bool): Indicates if the heartbeat loop is active.
|
32
30
|
"""
|
33
31
|
|
34
32
|
def __init__(self, identifier):
|
@@ -46,14 +44,14 @@ class HUBTrainingSession:
|
|
46
44
|
"""
|
47
45
|
from hub_sdk import HUBClient
|
48
46
|
|
49
|
-
self.rate_limits = {
|
50
|
-
"metrics": 3.0,
|
51
|
-
"ckpt": 900.0,
|
52
|
-
"heartbeat": 300.0,
|
53
|
-
} # rate limits (seconds)
|
47
|
+
self.rate_limits = {"metrics": 3, "ckpt": 900, "heartbeat": 300} # rate limits (seconds)
|
54
48
|
self.metrics_queue = {} # holds metrics for each epoch until upload
|
55
49
|
self.metrics_upload_failed_queue = {} # holds metrics for each epoch if upload failed
|
56
50
|
self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py
|
51
|
+
self.model = None
|
52
|
+
self.model_url = None
|
53
|
+
self.model_file = None
|
54
|
+
self.train_args = None
|
57
55
|
|
58
56
|
# Parse input
|
59
57
|
api_key, model_id, self.filename = self._parse_identifier(identifier)
|
@@ -65,10 +63,31 @@ class HUBTrainingSession:
|
|
65
63
|
# Initialize client
|
66
64
|
self.client = HUBClient(credentials)
|
67
65
|
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
66
|
+
# Load models
|
67
|
+
try:
|
68
|
+
if model_id:
|
69
|
+
self.load_model(model_id) # load existing model
|
70
|
+
else:
|
71
|
+
self.model = self.client.model() # load empty model
|
72
|
+
except Exception:
|
73
|
+
if identifier.startswith(f"{HUB_WEB_ROOT}/models/") and not self.client.authenticated:
|
74
|
+
LOGGER.warning(
|
75
|
+
f"{PREFIX}WARNING ⚠️ Please log in using 'yolo login API_KEY'. "
|
76
|
+
"You can find your API Key at: https://hub.ultralytics.com/settings?tab=api+keys."
|
77
|
+
)
|
78
|
+
|
79
|
+
@classmethod
|
80
|
+
def create_session(cls, identifier, args=None):
|
81
|
+
"""Class method to create an authenticated HUBTrainingSession or return None."""
|
82
|
+
try:
|
83
|
+
session = cls(identifier)
|
84
|
+
if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # not a HUB model URL
|
85
|
+
session.create_model(args)
|
86
|
+
assert session.model.id, "HUB model not loaded correctly"
|
87
|
+
return session
|
88
|
+
# PermissionError and ModuleNotFoundError indicate hub-sdk not installed
|
89
|
+
except (PermissionError, ModuleNotFoundError, AssertionError):
|
90
|
+
return None
|
72
91
|
|
73
92
|
def load_model(self, model_id):
|
74
93
|
"""Loads an existing model from Ultralytics HUB using the provided model identifier."""
|
@@ -77,10 +96,14 @@ class HUBTrainingSession:
|
|
77
96
|
raise ValueError(emojis("❌ The specified HUB model does not exist")) # TODO: improve error handling
|
78
97
|
|
79
98
|
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
|
99
|
+
if self.model.is_trained():
|
100
|
+
print(emojis(f"Loading trained HUB model {self.model_url} 🚀"))
|
101
|
+
url = self.model.get_weights_url("best") # download URL with auth
|
102
|
+
self.model_file = checks.check_file(url, download_dir=Path(SETTINGS["weights_dir"]) / "hub" / self.model.id)
|
103
|
+
return
|
80
104
|
|
105
|
+
# Set training args and start heartbeats for HUB to monitor agent
|
81
106
|
self._set_train_args()
|
82
|
-
|
83
|
-
# Start heartbeats for HUB to monitor agent
|
84
107
|
self.model.start_heartbeat(self.rate_limits["heartbeat"])
|
85
108
|
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
|
86
109
|
|
@@ -92,14 +115,12 @@ class HUBTrainingSession:
|
|
92
115
|
"epochs": model_args.get("epochs", 300),
|
93
116
|
"imageSize": model_args.get("imgsz", 640),
|
94
117
|
"patience": model_args.get("patience", 100),
|
95
|
-
"device": model_args.get("device", ""),
|
96
|
-
"cache": model_args.get("cache", "ram"),
|
118
|
+
"device": str(model_args.get("device", "")), # convert None to string
|
119
|
+
"cache": str(model_args.get("cache", "ram")), # convert True, False, None to string
|
97
120
|
},
|
98
121
|
"dataset": {"name": model_args.get("data")},
|
99
122
|
"lineage": {
|
100
|
-
"architecture": {
|
101
|
-
"name": self.filename.replace(".pt", "").replace(".yaml", ""),
|
102
|
-
},
|
123
|
+
"architecture": {"name": self.filename.replace(".pt", "").replace(".yaml", "")},
|
103
124
|
"parent": {},
|
104
125
|
},
|
105
126
|
"meta": {"name": self.filename},
|
@@ -113,7 +134,7 @@ class HUBTrainingSession:
|
|
113
134
|
# Model could not be created
|
114
135
|
# TODO: improve error handling
|
115
136
|
if not self.model.id:
|
116
|
-
return
|
137
|
+
return None
|
117
138
|
|
118
139
|
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
|
119
140
|
|
@@ -122,14 +143,14 @@ class HUBTrainingSession:
|
|
122
143
|
|
123
144
|
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
|
124
145
|
|
125
|
-
|
146
|
+
@staticmethod
|
147
|
+
def _parse_identifier(identifier):
|
126
148
|
"""
|
127
149
|
Parses the given identifier to determine the type of identifier and extract relevant components.
|
128
150
|
|
129
151
|
The method supports different identifier formats:
|
130
|
-
- A HUB URL
|
131
|
-
-
|
132
|
-
- An identifier that is solely a model ID of a fixed length
|
152
|
+
- A HUB model URL https://hub.ultralytics.com/models/MODEL
|
153
|
+
- A HUB model URL with API Key https://hub.ultralytics.com/models/MODEL?api_key=APIKEY
|
133
154
|
- A local filename that ends with '.pt' or '.yaml'
|
134
155
|
|
135
156
|
Args:
|
@@ -141,67 +162,46 @@ class HUBTrainingSession:
|
|
141
162
|
Raises:
|
142
163
|
HUBModelError: If the identifier format is not recognized.
|
143
164
|
"""
|
144
|
-
|
145
|
-
# Initialize variables
|
146
165
|
api_key, model_id, filename = None, None, None
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
model_id =
|
166
|
+
if Path(identifier).suffix in {".pt", ".yaml"}:
|
167
|
+
filename = identifier
|
168
|
+
elif identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
|
169
|
+
parsed_url = urlparse(identifier)
|
170
|
+
model_id = Path(parsed_url.path).stem # handle possible final backslash robustly
|
171
|
+
query_params = parse_qs(parsed_url.query) # dictionary, i.e. {"api_key": ["API_KEY_HERE"]}
|
172
|
+
api_key = query_params.get("api_key", [None])[0]
|
152
173
|
else:
|
153
|
-
|
154
|
-
parts = identifier.split("_")
|
155
|
-
|
156
|
-
# Check if identifier is in the format of API key and model ID
|
157
|
-
if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
|
158
|
-
api_key, model_id = parts
|
159
|
-
# Check if identifier is a single model ID
|
160
|
-
elif len(parts) == 1 and len(parts[0]) == 20:
|
161
|
-
model_id = parts[0]
|
162
|
-
# Check if identifier is a local filename
|
163
|
-
elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
|
164
|
-
filename = identifier
|
165
|
-
else:
|
166
|
-
raise HUBModelError(
|
167
|
-
f"model='{identifier}' could not be parsed. Check format is correct. "
|
168
|
-
f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
|
169
|
-
)
|
170
|
-
|
174
|
+
raise HUBModelError(f"model='{identifier} invalid, correct format is {HUB_WEB_ROOT}/models/MODEL_ID")
|
171
175
|
return api_key, model_id, filename
|
172
176
|
|
173
|
-
def _set_train_args(self
|
174
|
-
"""
|
175
|
-
|
176
|
-
# Model is already trained
|
177
|
-
raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))
|
177
|
+
def _set_train_args(self):
|
178
|
+
"""
|
179
|
+
Initializes training arguments and creates a model entry on the Ultralytics HUB.
|
178
180
|
|
181
|
+
This method sets up training arguments based on the model's state and updates them with any additional
|
182
|
+
arguments provided. It handles different states of the model, such as whether it's resumable, pretrained,
|
183
|
+
or requires specific file setup.
|
184
|
+
|
185
|
+
Raises:
|
186
|
+
ValueError: If the model is already trained, if required dataset information is missing, or if there are
|
187
|
+
issues with the provided training arguments.
|
188
|
+
"""
|
179
189
|
if self.model.is_resumable():
|
180
190
|
# Model has saved weights
|
181
191
|
self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
|
182
192
|
self.model_file = self.model.get_weights_url("last")
|
183
193
|
else:
|
184
194
|
# Model has no saved weights
|
185
|
-
|
186
|
-
|
187
|
-
return {
|
188
|
-
"batch": config["batchSize"],
|
189
|
-
"epochs": config["epochs"],
|
190
|
-
"imgsz": config["imageSize"],
|
191
|
-
"patience": config["patience"],
|
192
|
-
"device": config["device"],
|
193
|
-
"cache": config["cache"],
|
194
|
-
"data": self.model.get_dataset_url(),
|
195
|
-
}
|
196
|
-
|
197
|
-
self.train_args = get_train_args(self.model.data.get("config"))
|
195
|
+
self.train_args = self.model.data.get("train_args") # new response
|
196
|
+
|
198
197
|
# Set the model file as either a *.pt or *.yaml file
|
199
198
|
self.model_file = (
|
200
199
|
self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
|
201
200
|
)
|
202
201
|
|
203
|
-
if not self.train_args
|
204
|
-
|
202
|
+
if "data" not in self.train_args:
|
203
|
+
# RF bug - datasets are sometimes not exported
|
204
|
+
raise ValueError("Dataset may still be processing. Please wait a minute and try again.")
|
205
205
|
|
206
206
|
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
|
207
207
|
self.model_id = self.model.id
|
@@ -214,12 +214,16 @@ class HUBTrainingSession:
|
|
214
214
|
thread=True,
|
215
215
|
verbose=True,
|
216
216
|
progress_total=None,
|
217
|
+
stream_response=None,
|
217
218
|
*args,
|
218
219
|
**kwargs,
|
219
220
|
):
|
221
|
+
"""Attempts to execute `request_func` with retries, timeout handling, optional threading, and progress."""
|
222
|
+
|
220
223
|
def retry_request():
|
221
224
|
"""Attempts to call `request_func` with retries, timeout, and optional threading."""
|
222
225
|
t0 = time.time() # Record the start time for the timeout
|
226
|
+
response = None
|
223
227
|
for i in range(retry + 1):
|
224
228
|
if (time.time() - t0) > timeout:
|
225
229
|
LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
|
@@ -233,6 +237,8 @@ class HUBTrainingSession:
|
|
233
237
|
|
234
238
|
if progress_total:
|
235
239
|
self._show_upload_progress(progress_total, response)
|
240
|
+
elif stream_response:
|
241
|
+
self._iterate_content(response)
|
236
242
|
|
237
243
|
if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
|
238
244
|
# if request related to metrics upload
|
@@ -255,7 +261,7 @@ class HUBTrainingSession:
|
|
255
261
|
|
256
262
|
# if request related to metrics upload and exceed retries
|
257
263
|
if response is None and kwargs.get("metrics"):
|
258
|
-
self.metrics_upload_failed_queue.update(kwargs.get("metrics"
|
264
|
+
self.metrics_upload_failed_queue.update(kwargs.get("metrics"))
|
259
265
|
|
260
266
|
return response
|
261
267
|
|
@@ -266,7 +272,8 @@ class HUBTrainingSession:
|
|
266
272
|
# If running in the main thread, call retry_request directly
|
267
273
|
return retry_request()
|
268
274
|
|
269
|
-
|
275
|
+
@staticmethod
|
276
|
+
def _should_retry(status_code):
|
270
277
|
"""Determines if a request should be retried based on the HTTP status code."""
|
271
278
|
retry_codes = {
|
272
279
|
HTTPStatus.REQUEST_TIMEOUT,
|
@@ -323,24 +330,37 @@ class HUBTrainingSession:
|
|
323
330
|
map (float): Mean average precision of the model.
|
324
331
|
final (bool): Indicates if the model is the final model after training.
|
325
332
|
"""
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
333
|
+
weights = Path(weights)
|
334
|
+
if not weights.is_file():
|
335
|
+
last = weights.with_name(f"last{weights.suffix}")
|
336
|
+
if final and last.is_file():
|
337
|
+
LOGGER.warning(
|
338
|
+
f"{PREFIX} WARNING ⚠️ Model 'best.pt' not found, copying 'last.pt' to 'best.pt' and uploading. "
|
339
|
+
"This often happens when resuming training in transient environments like Google Colab. "
|
340
|
+
"For more reliable training, consider using Ultralytics HUB Cloud. "
|
341
|
+
"Learn more at https://docs.ultralytics.com/hub/cloud-training."
|
342
|
+
)
|
343
|
+
shutil.copy(last, weights) # copy last.pt to best.pt
|
344
|
+
else:
|
345
|
+
LOGGER.warning(f"{PREFIX} WARNING ⚠️ Model upload issue. Missing model {weights}.")
|
346
|
+
return
|
347
|
+
|
348
|
+
self.request_queue(
|
349
|
+
self.model.upload_model,
|
350
|
+
epoch=epoch,
|
351
|
+
weights=str(weights),
|
352
|
+
is_best=is_best,
|
353
|
+
map=map,
|
354
|
+
final=final,
|
355
|
+
retry=10,
|
356
|
+
timeout=3600,
|
357
|
+
thread=not final,
|
358
|
+
progress_total=weights.stat().st_size if final else None, # only show progress if final
|
359
|
+
stream_response=True,
|
360
|
+
)
|
361
|
+
|
362
|
+
@staticmethod
|
363
|
+
def _show_upload_progress(content_length: int, response: requests.Response) -> None:
|
344
364
|
"""
|
345
365
|
Display a progress bar to track the upload progress of a file download.
|
346
366
|
|
@@ -354,3 +374,17 @@ class HUBTrainingSession:
|
|
354
374
|
with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
|
355
375
|
for data in response.iter_content(chunk_size=1024):
|
356
376
|
pbar.update(len(data))
|
377
|
+
|
378
|
+
@staticmethod
|
379
|
+
def _iterate_content(response: requests.Response) -> None:
|
380
|
+
"""
|
381
|
+
Process the streamed HTTP response data.
|
382
|
+
|
383
|
+
Args:
|
384
|
+
response (requests.Response): The response object from the file download request.
|
385
|
+
|
386
|
+
Returns:
|
387
|
+
None
|
388
|
+
"""
|
389
|
+
for _ in response.iter_content(chunk_size=1024):
|
390
|
+
pass # Do nothing with data chunks
|
ultralytics/hub/utils.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1
|
-
# Ultralytics
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
3
|
import os
|
4
4
|
import platform
|
5
5
|
import random
|
6
|
-
import sys
|
7
6
|
import threading
|
8
7
|
import time
|
9
8
|
from pathlib import Path
|
@@ -11,7 +10,11 @@ from pathlib import Path
|
|
11
10
|
import requests
|
12
11
|
|
13
12
|
from ultralytics.utils import (
|
13
|
+
ARGV,
|
14
14
|
ENVIRONMENT,
|
15
|
+
IS_COLAB,
|
16
|
+
IS_GIT_DIR,
|
17
|
+
IS_PIP_PACKAGE,
|
15
18
|
LOGGER,
|
16
19
|
ONLINE,
|
17
20
|
RANK,
|
@@ -22,9 +25,6 @@ from ultralytics.utils import (
|
|
22
25
|
__version__,
|
23
26
|
colorstr,
|
24
27
|
get_git_origin_url,
|
25
|
-
is_colab,
|
26
|
-
is_git_dir,
|
27
|
-
is_pip_package,
|
28
28
|
)
|
29
29
|
from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES
|
30
30
|
|
@@ -48,30 +48,29 @@ def request_with_credentials(url: str) -> any:
|
|
48
48
|
Raises:
|
49
49
|
OSError: If the function is not run in a Google Colab environment.
|
50
50
|
"""
|
51
|
-
if not
|
51
|
+
if not IS_COLAB:
|
52
52
|
raise OSError("request_with_credentials() must run in a Colab environment")
|
53
53
|
from google.colab import output # noqa
|
54
54
|
from IPython import display # noqa
|
55
55
|
|
56
56
|
display.display(
|
57
57
|
display.Javascript(
|
58
|
-
"""
|
59
|
-
window._hub_tmp = new Promise((resolve, reject) => {
|
58
|
+
f"""
|
59
|
+
window._hub_tmp = new Promise((resolve, reject) => {{
|
60
60
|
const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000)
|
61
|
-
fetch("
|
61
|
+
fetch("{url}", {{
|
62
62
|
method: 'POST',
|
63
63
|
credentials: 'include'
|
64
|
-
})
|
64
|
+
}})
|
65
65
|
.then((response) => resolve(response.json()))
|
66
|
-
.then((json) => {
|
66
|
+
.then((json) => {{
|
67
67
|
clearTimeout(timeout);
|
68
|
-
}).catch((err) => {
|
68
|
+
}}).catch((err) => {{
|
69
69
|
clearTimeout(timeout);
|
70
70
|
reject(err);
|
71
|
-
});
|
72
|
-
});
|
71
|
+
}});
|
72
|
+
}});
|
73
73
|
"""
|
74
|
-
% url
|
75
74
|
)
|
76
75
|
)
|
77
76
|
return output.eval_js("_hub_tmp")
|
@@ -171,7 +170,7 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos
|
|
171
170
|
class Events:
|
172
171
|
"""
|
173
172
|
A class for collecting anonymous event analytics. Event analytics are enabled when sync=True in settings and
|
174
|
-
disabled when sync=False. Run 'yolo settings' to see and update settings
|
173
|
+
disabled when sync=False. Run 'yolo settings' to see and update settings.
|
175
174
|
|
176
175
|
Attributes:
|
177
176
|
url (str): The URL to send anonymous events.
|
@@ -185,11 +184,11 @@ class Events:
|
|
185
184
|
def __init__(self):
|
186
185
|
"""Initializes the Events object with default values for events, rate_limit, and metadata."""
|
187
186
|
self.events = [] # events list
|
188
|
-
self.rate_limit =
|
187
|
+
self.rate_limit = 30.0 # rate limit (seconds)
|
189
188
|
self.t = 0.0 # rate limit timer (seconds)
|
190
189
|
self.metadata = {
|
191
|
-
"cli": Path(
|
192
|
-
"install": "git" if
|
190
|
+
"cli": Path(ARGV[0]).name == "yolo",
|
191
|
+
"install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other",
|
193
192
|
"python": ".".join(platform.python_version_tuple()[:2]), # i.e. 3.10
|
194
193
|
"version": __version__,
|
195
194
|
"env": ENVIRONMENT,
|
@@ -198,10 +197,10 @@ class Events:
|
|
198
197
|
}
|
199
198
|
self.enabled = (
|
200
199
|
SETTINGS["sync"]
|
201
|
-
and RANK in
|
200
|
+
and RANK in {-1, 0}
|
202
201
|
and not TESTS_RUNNING
|
203
202
|
and ONLINE
|
204
|
-
and (
|
203
|
+
and (IS_PIP_PACKAGE or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
|
205
204
|
)
|
206
205
|
|
207
206
|
def __call__(self, cfg):
|
ultralytics/models/__init__.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1
|
-
# Ultralytics
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
|
+
from .fastsam import FastSAM
|
4
|
+
from .nas import NAS
|
3
5
|
from .rtdetr import RTDETR
|
4
6
|
from .sam import SAM
|
5
7
|
from .yolo import YOLO, YOLOWorld
|
6
8
|
|
7
|
-
__all__ = "YOLO", "RTDETR", "SAM", "YOLOWorld" # allow simpler import
|
9
|
+
__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld" # allow simpler import
|
@@ -1,8 +1,7 @@
|
|
1
|
-
# Ultralytics
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
3
|
from .model import FastSAM
|
4
4
|
from .predict import FastSAMPredictor
|
5
|
-
from .prompt import FastSAMPrompt
|
6
5
|
from .val import FastSAMValidator
|
7
6
|
|
8
|
-
__all__ = "FastSAMPredictor", "FastSAM", "
|
7
|
+
__all__ = "FastSAMPredictor", "FastSAM", "FastSAMValidator"
|
@@ -1,8 +1,9 @@
|
|
1
|
-
# Ultralytics
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
3
|
from pathlib import Path
|
4
4
|
|
5
5
|
from ultralytics.engine.model import Model
|
6
|
+
|
6
7
|
from .predict import FastSAMPredictor
|
7
8
|
from .val import FastSAMValidator
|
8
9
|
|
@@ -15,8 +16,8 @@ class FastSAM(Model):
|
|
15
16
|
```python
|
16
17
|
from ultralytics import FastSAM
|
17
18
|
|
18
|
-
model = FastSAM(
|
19
|
-
results = model.predict(
|
19
|
+
model = FastSAM("last.pt")
|
20
|
+
results = model.predict("ultralytics/assets/bus.jpg")
|
20
21
|
```
|
21
22
|
"""
|
22
23
|
|
@@ -24,9 +25,30 @@ class FastSAM(Model):
|
|
24
25
|
"""Call the __init__ method of the parent class (YOLO) with the updated default model."""
|
25
26
|
if str(model) == "FastSAM.pt":
|
26
27
|
model = "FastSAM-x.pt"
|
27
|
-
assert Path(model).suffix not in
|
28
|
+
assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
|
28
29
|
super().__init__(model=model, task="segment")
|
29
30
|
|
31
|
+
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, texts=None, **kwargs):
|
32
|
+
"""
|
33
|
+
Perform segmentation prediction on image or video source.
|
34
|
+
|
35
|
+
Supports prompted segmentation with bounding boxes, points, labels, and texts.
|
36
|
+
|
37
|
+
Args:
|
38
|
+
source (str | PIL.Image | numpy.ndarray): Input source.
|
39
|
+
stream (bool): Enable real-time streaming.
|
40
|
+
bboxes (list): Bounding box coordinates for prompted segmentation.
|
41
|
+
points (list): Points for prompted segmentation.
|
42
|
+
labels (list): Labels for prompted segmentation.
|
43
|
+
texts (list): Texts for prompted segmentation.
|
44
|
+
**kwargs (Any): Additional keyword arguments.
|
45
|
+
|
46
|
+
Returns:
|
47
|
+
(list): Model predictions.
|
48
|
+
"""
|
49
|
+
prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)
|
50
|
+
return super().predict(source, stream, prompts=prompts, **kwargs)
|
51
|
+
|
30
52
|
@property
|
31
53
|
def task_map(self):
|
32
54
|
"""Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
|