ultralytics 8.0.237__py3-none-any.whl → 8.0.239__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. Click here for more details.

Files changed (137) hide show
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  4. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  5. ultralytics/cfg/datasets/dota8.yaml +34 -0
  6. ultralytics/data/__init__.py +9 -2
  7. ultralytics/data/annotator.py +4 -4
  8. ultralytics/data/augment.py +186 -169
  9. ultralytics/data/base.py +54 -48
  10. ultralytics/data/build.py +34 -23
  11. ultralytics/data/converter.py +242 -70
  12. ultralytics/data/dataset.py +117 -95
  13. ultralytics/data/explorer/__init__.py +5 -0
  14. ultralytics/data/explorer/explorer.py +170 -97
  15. ultralytics/data/explorer/gui/__init__.py +1 -0
  16. ultralytics/data/explorer/gui/dash.py +146 -76
  17. ultralytics/data/explorer/utils.py +87 -25
  18. ultralytics/data/loaders.py +75 -62
  19. ultralytics/data/split_dota.py +44 -36
  20. ultralytics/data/utils.py +160 -142
  21. ultralytics/engine/exporter.py +348 -292
  22. ultralytics/engine/model.py +102 -66
  23. ultralytics/engine/predictor.py +74 -55
  24. ultralytics/engine/results.py +63 -40
  25. ultralytics/engine/trainer.py +192 -144
  26. ultralytics/engine/tuner.py +66 -59
  27. ultralytics/engine/validator.py +31 -26
  28. ultralytics/hub/__init__.py +54 -31
  29. ultralytics/hub/auth.py +28 -25
  30. ultralytics/hub/session.py +282 -133
  31. ultralytics/hub/utils.py +64 -42
  32. ultralytics/models/__init__.py +1 -1
  33. ultralytics/models/fastsam/__init__.py +1 -1
  34. ultralytics/models/fastsam/model.py +6 -6
  35. ultralytics/models/fastsam/predict.py +3 -2
  36. ultralytics/models/fastsam/prompt.py +55 -48
  37. ultralytics/models/fastsam/val.py +1 -1
  38. ultralytics/models/nas/__init__.py +1 -1
  39. ultralytics/models/nas/model.py +9 -8
  40. ultralytics/models/nas/predict.py +8 -6
  41. ultralytics/models/nas/val.py +11 -9
  42. ultralytics/models/rtdetr/__init__.py +1 -1
  43. ultralytics/models/rtdetr/model.py +11 -9
  44. ultralytics/models/rtdetr/train.py +18 -16
  45. ultralytics/models/rtdetr/val.py +25 -19
  46. ultralytics/models/sam/__init__.py +1 -1
  47. ultralytics/models/sam/amg.py +13 -14
  48. ultralytics/models/sam/build.py +44 -42
  49. ultralytics/models/sam/model.py +6 -6
  50. ultralytics/models/sam/modules/decoders.py +6 -4
  51. ultralytics/models/sam/modules/encoders.py +37 -35
  52. ultralytics/models/sam/modules/sam.py +5 -4
  53. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  54. ultralytics/models/sam/modules/transformer.py +3 -2
  55. ultralytics/models/sam/predict.py +39 -27
  56. ultralytics/models/utils/loss.py +99 -95
  57. ultralytics/models/utils/ops.py +34 -31
  58. ultralytics/models/yolo/__init__.py +1 -1
  59. ultralytics/models/yolo/classify/__init__.py +1 -1
  60. ultralytics/models/yolo/classify/predict.py +8 -6
  61. ultralytics/models/yolo/classify/train.py +37 -31
  62. ultralytics/models/yolo/classify/val.py +26 -24
  63. ultralytics/models/yolo/detect/__init__.py +1 -1
  64. ultralytics/models/yolo/detect/predict.py +8 -6
  65. ultralytics/models/yolo/detect/train.py +47 -37
  66. ultralytics/models/yolo/detect/val.py +100 -82
  67. ultralytics/models/yolo/model.py +31 -25
  68. ultralytics/models/yolo/obb/__init__.py +1 -1
  69. ultralytics/models/yolo/obb/predict.py +13 -12
  70. ultralytics/models/yolo/obb/train.py +3 -3
  71. ultralytics/models/yolo/obb/val.py +80 -58
  72. ultralytics/models/yolo/pose/__init__.py +1 -1
  73. ultralytics/models/yolo/pose/predict.py +17 -12
  74. ultralytics/models/yolo/pose/train.py +28 -25
  75. ultralytics/models/yolo/pose/val.py +91 -64
  76. ultralytics/models/yolo/segment/__init__.py +1 -1
  77. ultralytics/models/yolo/segment/predict.py +10 -8
  78. ultralytics/models/yolo/segment/train.py +16 -15
  79. ultralytics/models/yolo/segment/val.py +90 -68
  80. ultralytics/nn/__init__.py +26 -6
  81. ultralytics/nn/autobackend.py +144 -112
  82. ultralytics/nn/modules/__init__.py +96 -13
  83. ultralytics/nn/modules/block.py +28 -7
  84. ultralytics/nn/modules/conv.py +41 -23
  85. ultralytics/nn/modules/head.py +67 -59
  86. ultralytics/nn/modules/transformer.py +49 -32
  87. ultralytics/nn/modules/utils.py +20 -15
  88. ultralytics/nn/tasks.py +215 -141
  89. ultralytics/solutions/ai_gym.py +59 -47
  90. ultralytics/solutions/distance_calculation.py +22 -15
  91. ultralytics/solutions/heatmap.py +76 -54
  92. ultralytics/solutions/object_counter.py +46 -39
  93. ultralytics/solutions/speed_estimation.py +13 -16
  94. ultralytics/trackers/__init__.py +1 -1
  95. ultralytics/trackers/basetrack.py +1 -0
  96. ultralytics/trackers/bot_sort.py +2 -1
  97. ultralytics/trackers/byte_tracker.py +10 -7
  98. ultralytics/trackers/track.py +7 -7
  99. ultralytics/trackers/utils/gmc.py +25 -25
  100. ultralytics/trackers/utils/kalman_filter.py +85 -42
  101. ultralytics/trackers/utils/matching.py +8 -7
  102. ultralytics/utils/__init__.py +173 -151
  103. ultralytics/utils/autobatch.py +10 -10
  104. ultralytics/utils/benchmarks.py +76 -86
  105. ultralytics/utils/callbacks/__init__.py +1 -1
  106. ultralytics/utils/callbacks/base.py +29 -29
  107. ultralytics/utils/callbacks/clearml.py +51 -43
  108. ultralytics/utils/callbacks/comet.py +81 -66
  109. ultralytics/utils/callbacks/dvc.py +33 -26
  110. ultralytics/utils/callbacks/hub.py +44 -26
  111. ultralytics/utils/callbacks/mlflow.py +31 -24
  112. ultralytics/utils/callbacks/neptune.py +35 -25
  113. ultralytics/utils/callbacks/raytune.py +9 -4
  114. ultralytics/utils/callbacks/tensorboard.py +16 -11
  115. ultralytics/utils/callbacks/wb.py +39 -33
  116. ultralytics/utils/checks.py +189 -141
  117. ultralytics/utils/dist.py +15 -12
  118. ultralytics/utils/downloads.py +112 -96
  119. ultralytics/utils/errors.py +1 -1
  120. ultralytics/utils/files.py +11 -11
  121. ultralytics/utils/instance.py +22 -22
  122. ultralytics/utils/loss.py +117 -67
  123. ultralytics/utils/metrics.py +224 -158
  124. ultralytics/utils/ops.py +39 -29
  125. ultralytics/utils/patches.py +3 -3
  126. ultralytics/utils/plotting.py +217 -120
  127. ultralytics/utils/tal.py +19 -13
  128. ultralytics/utils/torch_utils.py +138 -109
  129. ultralytics/utils/triton.py +12 -10
  130. ultralytics/utils/tuner.py +49 -47
  131. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
  132. ultralytics-8.0.239.dist-info/RECORD +188 -0
  133. ultralytics-8.0.237.dist-info/RECORD +0 -187
  134. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  135. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  136. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  137. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
@@ -1,26 +1,24 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
2
 
3
- import signal
4
- import sys
3
+ import threading
4
+ import time
5
+ from http import HTTPStatus
5
6
  from pathlib import Path
6
- from time import sleep
7
7
 
8
8
  import requests
9
+ from hub_sdk import HUB_WEB_ROOT, HUBClient
9
10
 
10
- from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, smart_request
11
- from ultralytics.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
11
+ from ultralytics.hub.utils import HELP_MSG, PREFIX, TQDM
12
+ from ultralytics.utils import LOGGER, SETTINGS, __version__, checks, emojis, is_colab
12
13
  from ultralytics.utils.errors import HUBModelError
13
14
 
14
- AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
15
+ AGENT_NAME = f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local"
15
16
 
16
17
 
17
18
  class HUBTrainingSession:
18
19
  """
19
20
  HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
20
21
 
21
- Args:
22
- url (str): Model identifier used to initialize the HUB training session.
23
-
24
22
  Attributes:
25
23
  agent_id (str): Identifier for the instance communicating with the server.
26
24
  model_id (str): Identifier for the YOLO model being trained.
@@ -34,110 +32,271 @@ class HUBTrainingSession:
34
32
  alive (bool): Indicates if the heartbeat loop is active.
35
33
  """
36
34
 
37
- def __init__(self, url):
35
+ def __init__(self, identifier):
38
36
  """
39
37
  Initialize the HUBTrainingSession with the provided model identifier.
40
38
 
41
39
  Args:
42
- url (str): Model identifier used to initialize the HUB training session.
43
- It can be a URL string or a model key with specific format.
40
+ identifier (str): Model identifier used to initialize the HUB training session.
41
+ It can be a URL string or a model key with specific format.
44
42
 
45
43
  Raises:
46
44
  ValueError: If the provided model identifier is invalid.
47
45
  ConnectionError: If connecting with global API key is not supported.
48
46
  """
49
-
50
- from ultralytics.hub.auth import Auth
47
+ self.rate_limits = {
48
+ "metrics": 3.0,
49
+ "ckpt": 900.0,
50
+ "heartbeat": 300.0,
51
+ } # rate limits (seconds)
52
+ self.metrics_queue = {} # holds metrics for each epoch until upload
53
+ self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py
51
54
 
52
55
  # Parse input
53
- if url.startswith(f'{HUB_WEB_ROOT}/models/'):
54
- url = url.split(f'{HUB_WEB_ROOT}/models/')[-1]
55
- if [len(x) for x in url.split('_')] == [42, 20]:
56
- key, model_id = url.split('_')
57
- elif len(url) == 20:
58
- key, model_id = '', url
56
+ api_key, model_id, self.filename = self._parse_identifier(identifier)
57
+
58
+ # Get credentials
59
+ active_key = api_key or SETTINGS.get("api_key")
60
+ credentials = {"api_key": active_key} if active_key else None # set credentials
61
+
62
+ # Initialize client
63
+ self.client = HUBClient(credentials)
64
+
65
+ if model_id:
66
+ self.load_model(model_id) # load existing model
59
67
  else:
60
- raise HUBModelError(f"model='{url}' not found. Check format is correct, i.e. "
61
- f"model='{HUB_WEB_ROOT}/models/MODEL_ID' and try again.")
62
-
63
- # Authorize
64
- auth = Auth(key)
65
- self.agent_id = None # identifies which instance is communicating with server
66
- self.model_id = model_id
67
- self.model_url = f'{HUB_WEB_ROOT}/models/{model_id}'
68
- self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
69
- self.auth_header = auth.get_auth_header()
70
- self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
71
- self.timers = {} # rate limit timers (seconds)
72
- self.metrics_queue = {} # metrics queue
73
- self.model = self._get_model()
74
- self.alive = True
75
- self._start_heartbeat() # start heartbeats
76
- self._register_signal_handlers()
77
- LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
78
-
79
- def _register_signal_handlers(self):
80
- """Register signal handlers for SIGTERM and SIGINT signals to gracefully handle termination."""
81
- signal.signal(signal.SIGTERM, self._handle_signal)
82
- signal.signal(signal.SIGINT, self._handle_signal)
83
-
84
- def _handle_signal(self, signum, frame):
68
+ self.model = self.client.model() # load empty model
69
+
70
+ def load_model(self, model_id):
71
+ # Initialize model
72
+ self.model = self.client.model(model_id)
73
+ self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
74
+
75
+ self._set_train_args()
76
+
77
+ # Start heartbeats for HUB to monitor agent
78
+ self.model.start_heartbeat(self.rate_limits["heartbeat"])
79
+ LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
80
+
81
+ def create_model(self, model_args):
82
+ # Initialize model
83
+ payload = {
84
+ "config": {
85
+ "batchSize": model_args.get("batch", -1),
86
+ "epochs": model_args.get("epochs", 300),
87
+ "imageSize": model_args.get("imgsz", 640),
88
+ "patience": model_args.get("patience", 100),
89
+ "device": model_args.get("device", ""),
90
+ "cache": model_args.get("cache", "ram"),
91
+ },
92
+ "dataset": {"name": model_args.get("data")},
93
+ "lineage": {
94
+ "architecture": {
95
+ "name": self.filename.replace(".pt", "").replace(".yaml", ""),
96
+ },
97
+ "parent": {},
98
+ },
99
+ "meta": {"name": self.filename},
100
+ }
101
+
102
+ if self.filename.endswith(".pt"):
103
+ payload["lineage"]["parent"]["name"] = self.filename
104
+
105
+ self.model.create_model(payload)
106
+
107
+ # Model could not be created
108
+ # TODO: improve error handling
109
+ if not self.model.id:
110
+ return
111
+
112
+ self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
113
+
114
+ # Start heartbeats for HUB to monitor agent
115
+ self.model.start_heartbeat(self.rate_limits["heartbeat"])
116
+
117
+ LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
118
+
119
+ def _parse_identifier(self, identifier):
120
+ """
121
+ Parses the given identifier to determine the type of identifier and extract relevant components.
122
+
123
+ The method supports different identifier formats:
124
+ - A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/'
125
+ - An identifier containing an API key and a model ID separated by an underscore
126
+ - An identifier that is solely a model ID of a fixed length
127
+ - A local filename that ends with '.pt' or '.yaml'
128
+
129
+ Args:
130
+ identifier (str): The identifier string to be parsed.
131
+
132
+ Returns:
133
+ (tuple): A tuple containing the API key, model ID, and filename as applicable.
134
+
135
+ Raises:
136
+ HUBModelError: If the identifier format is not recognized.
85
137
  """
86
- Handle kill signals and prevent heartbeats from being sent on Colab after termination.
87
138
 
88
- This method does not use frame, it is included as it is passed by signal.
139
+ # Initialize variables
140
+ api_key, model_id, filename = None, None, None
141
+
142
+ # Check if identifier is a HUB URL
143
+ if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
144
+ # Extract the model_id after the HUB_WEB_ROOT URL
145
+ model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
146
+ else:
147
+ # Split the identifier based on underscores only if it's not a HUB URL
148
+ parts = identifier.split("_")
149
+
150
+ # Check if identifier is in the format of API key and model ID
151
+ if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
152
+ api_key, model_id = parts
153
+ # Check if identifier is a single model ID
154
+ elif len(parts) == 1 and len(parts[0]) == 20:
155
+ model_id = parts[0]
156
+ # Check if identifier is a local filename
157
+ elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
158
+ filename = identifier
159
+ else:
160
+ raise HUBModelError(
161
+ f"model='{identifier}' could not be parsed. Check format is correct. "
162
+ f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
163
+ )
164
+
165
+ return api_key, model_id, filename
166
+
167
+ def _set_train_args(self, **kwargs):
168
+ if self.model.is_trained():
169
+ # Model is already trained
170
+ raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))
171
+
172
+ if self.model.is_resumable():
173
+ # Model has saved weights
174
+ self.train_args = {"data": self.model.get_dataset_url(), "resume": True}
175
+ self.model_file = self.model.get_weights_url("last")
176
+ else:
177
+ # Model has no saved weights
178
+ def get_train_args(config):
179
+ return {
180
+ "batch": config["batchSize"],
181
+ "epochs": config["epochs"],
182
+ "imgsz": config["imageSize"],
183
+ "patience": config["patience"],
184
+ "device": config["device"],
185
+ "cache": config["cache"],
186
+ "data": self.model.get_dataset_url(),
187
+ }
188
+
189
+ self.train_args = get_train_args(self.model.data.get("config"))
190
+ # Set the model file as either a *.pt or *.yaml file
191
+ self.model_file = (
192
+ self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
193
+ )
194
+
195
+ if not self.train_args.get("data"):
196
+ raise ValueError("Dataset may still be processing. Please wait a minute and try again.") # RF fix
197
+
198
+ self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
199
+ self.model_id = self.model.id
200
+
201
+ def request_queue(
202
+ self,
203
+ request_func,
204
+ retry=3,
205
+ timeout=30,
206
+ thread=True,
207
+ verbose=True,
208
+ progress_total=None,
209
+ *args,
210
+ **kwargs,
211
+ ):
212
+ def retry_request():
213
+ t0 = time.time() # Record the start time for the timeout
214
+ for i in range(retry + 1):
215
+ if (time.time() - t0) > timeout:
216
+ LOGGER.warning(f"{PREFIX}Timeout for request reached. {HELP_MSG}")
217
+ break # Timeout reached, exit loop
218
+
219
+ response = request_func(*args, **kwargs)
220
+ if progress_total:
221
+ self._show_upload_progress(progress_total, response)
222
+
223
+ if response is None:
224
+ LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}")
225
+ time.sleep(2**i) # Exponential backoff before retrying
226
+ continue # Skip further processing and retry
227
+
228
+ if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
229
+ return response # Success, no need to retry
230
+
231
+ if i == 0:
232
+ # Initial attempt, check status code and provide messages
233
+ message = self._get_failure_message(response, retry, timeout)
234
+
235
+ if verbose:
236
+ LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})")
237
+
238
+ if not self._should_retry(response.status_code):
239
+ LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code}")
240
+ break # Not an error that should be retried, exit loop
241
+
242
+ time.sleep(2**i) # Exponential backoff for retries
243
+
244
+ return response
245
+
246
+ if thread:
247
+ # Start a new thread to run the retry_request function
248
+ threading.Thread(target=retry_request, daemon=True).start()
249
+ else:
250
+ # If running in the main thread, call retry_request directly
251
+ return retry_request()
252
+
253
+ def _should_retry(self, status_code):
254
+ # Status codes that trigger retries
255
+ retry_codes = {
256
+ HTTPStatus.REQUEST_TIMEOUT,
257
+ HTTPStatus.BAD_GATEWAY,
258
+ HTTPStatus.GATEWAY_TIMEOUT,
259
+ }
260
+ return True if status_code in retry_codes else False
261
+
262
+ def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
89
263
  """
90
- if self.alive is True:
91
- LOGGER.info(f'{PREFIX}Kill signal received! ❌')
92
- self._stop_heartbeat()
93
- sys.exit(signum)
264
+ Generate a retry message based on the response status code.
94
265
 
95
- def _stop_heartbeat(self):
96
- """Terminate the heartbeat loop."""
97
- self.alive = False
266
+ Args:
267
+ response: The HTTP response object.
268
+ retry: The number of retry attempts allowed.
269
+ timeout: The maximum timeout duration.
270
+
271
+ Returns:
272
+ str: The retry message.
273
+ """
274
+ if self._should_retry(response.status_code):
275
+ return f"Retrying {retry}x for {timeout}s." if retry else ""
276
+ elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS: # rate limit
277
+ headers = response.headers
278
+ return (
279
+ f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
280
+ f"Please retry after {headers['Retry-After']}s."
281
+ )
282
+ else:
283
+ try:
284
+ return response.json().get("message", "No JSON message.")
285
+ except AttributeError:
286
+ return "Unable to read JSON."
98
287
 
99
288
  def upload_metrics(self):
100
289
  """Upload model metrics to Ultralytics HUB."""
101
- payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'}
102
- smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2)
103
-
104
- def _get_model(self):
105
- """Fetch and return model data from Ultralytics HUB."""
106
- api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
107
-
108
- try:
109
- response = smart_request('get', api_url, headers=self.auth_header, thread=False, code=0)
110
- data = response.json().get('data', None)
111
-
112
- if data.get('status', None) == 'trained':
113
- raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))
114
-
115
- if not data.get('data', None):
116
- raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
117
- self.model_id = data['id']
118
-
119
- if data['status'] == 'new': # new model to start training
120
- self.train_args = {
121
- 'batch': data['batch_size'], # note HUB argument is slightly different
122
- 'epochs': data['epochs'],
123
- 'imgsz': data['imgsz'],
124
- 'patience': data['patience'],
125
- 'device': data['device'],
126
- 'cache': data['cache'],
127
- 'data': data['data']}
128
- self.model_file = data.get('cfg') or data.get('weights') # cfg for pretrained=False
129
- self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
130
- elif data['status'] == 'training': # existing model to resume training
131
- self.train_args = {'data': data['data'], 'resume': True}
132
- self.model_file = data['resume']
133
-
134
- return data
135
- except requests.exceptions.ConnectionError as e:
136
- raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e
137
- except Exception:
138
- raise
139
-
140
- def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
290
+ return self.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True)
291
+
292
+ def upload_model(
293
+ self,
294
+ epoch: int,
295
+ weights: str,
296
+ is_best: bool = False,
297
+ map: float = 0.0,
298
+ final: bool = False,
299
+ ) -> None:
141
300
  """
142
301
  Upload a model checkpoint to Ultralytics HUB.
143
302
 
@@ -149,43 +308,33 @@ class HUBTrainingSession:
149
308
  final (bool): Indicates if the model is the final model after training.
150
309
  """
151
310
  if Path(weights).is_file():
152
- with open(weights, 'rb') as f:
153
- file = f.read()
311
+ progress_total = Path(weights).stat().st_size if final else None # Only show progress if final
312
+ self.request_queue(
313
+ self.model.upload_model,
314
+ epoch=epoch,
315
+ weights=weights,
316
+ is_best=is_best,
317
+ map=map,
318
+ final=final,
319
+ retry=10,
320
+ timeout=3600,
321
+ thread=not final,
322
+ progress_total=progress_total,
323
+ )
154
324
  else:
155
- LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.')
156
- file = None
157
- url = f'{self.api_url}/upload'
158
- # url = 'http://httpbin.org/post' # for debug
159
- data = {'epoch': epoch}
160
- if final:
161
- data.update({'type': 'final', 'map': map})
162
- filesize = Path(weights).stat().st_size
163
- smart_request('post',
164
- url,
165
- data=data,
166
- files={'best.pt': file},
167
- headers=self.auth_header,
168
- retry=10,
169
- timeout=3600,
170
- thread=False,
171
- progress=filesize,
172
- code=4)
173
- else:
174
- data.update({'type': 'epoch', 'isBest': bool(is_best)})
175
- smart_request('post', url, data=data, files={'last.pt': file}, headers=self.auth_header, code=3)
176
-
177
- @threaded
178
- def _start_heartbeat(self):
179
- """Begin a threaded heartbeat loop to report the agent's status to Ultralytics HUB."""
180
- while self.alive:
181
- r = smart_request('post',
182
- f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
183
- json={
184
- 'agent': AGENT_NAME,
185
- 'agentId': self.agent_id},
186
- headers=self.auth_header,
187
- retry=0,
188
- code=5,
189
- thread=False) # already in a thread
190
- self.agent_id = r.json().get('data', {}).get('agentId', None)
191
- sleep(self.rate_limits['heartbeat'])
325
+ LOGGER.warning(f"{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.")
326
+
327
+ def _show_upload_progress(self, content_length: int, response: requests.Response) -> None:
328
+ """
329
+ Display a progress bar to track the upload progress of a file download.
330
+
331
+ Args:
332
+ content_length (int): The total size of the content to be downloaded in bytes.
333
+ response (requests.Response): The response object from the file download request.
334
+
335
+ Returns:
336
+ (None)
337
+ """
338
+ with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar:
339
+ for data in response.iter_content(chunk_size=1024):
340
+ pbar.update(len(data))
ultralytics/hub/utils.py CHANGED
@@ -1,6 +1,5 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
2
 
3
- import os
4
3
  import platform
5
4
  import random
6
5
  import sys
@@ -10,14 +9,26 @@ from pathlib import Path
10
9
 
11
10
  import requests
12
11
 
13
- from ultralytics.utils import (ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING, TQDM, TryExcept, __version__,
14
- colorstr, get_git_origin_url, is_colab, is_git_dir, is_pip_package)
12
+ from ultralytics.utils import (
13
+ ENVIRONMENT,
14
+ LOGGER,
15
+ ONLINE,
16
+ RANK,
17
+ SETTINGS,
18
+ TESTS_RUNNING,
19
+ TQDM,
20
+ TryExcept,
21
+ __version__,
22
+ colorstr,
23
+ get_git_origin_url,
24
+ is_colab,
25
+ is_git_dir,
26
+ is_pip_package,
27
+ )
15
28
  from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES
16
29
 
17
- PREFIX = colorstr('Ultralytics HUB: ')
18
- HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
19
- HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.com')
20
- HUB_WEB_ROOT = os.environ.get('ULTRALYTICS_HUB_WEB', 'https://hub.ultralytics.com')
30
+ PREFIX = colorstr("Ultralytics HUB: ")
31
+ HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance."
21
32
 
22
33
 
23
34
  def request_with_credentials(url: str) -> any:
@@ -34,11 +45,13 @@ def request_with_credentials(url: str) -> any:
34
45
  OSError: If the function is not run in a Google Colab environment.
35
46
  """
36
47
  if not is_colab():
37
- raise OSError('request_with_credentials() must run in a Colab environment')
48
+ raise OSError("request_with_credentials() must run in a Colab environment")
38
49
  from google.colab import output # noqa
39
50
  from IPython import display # noqa
51
+
40
52
  display.display(
41
- display.Javascript("""
53
+ display.Javascript(
54
+ """
42
55
  window._hub_tmp = new Promise((resolve, reject) => {
43
56
  const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000)
44
57
  fetch("%s", {
@@ -53,8 +66,11 @@ def request_with_credentials(url: str) -> any:
53
66
  reject(err);
54
67
  });
55
68
  });
56
- """ % url))
57
- return output.eval_js('_hub_tmp')
69
+ """
70
+ % url
71
+ )
72
+ )
73
+ return output.eval_js("_hub_tmp")
58
74
 
59
75
 
60
76
  def requests_with_progress(method, url, **kwargs):
@@ -74,13 +90,13 @@ def requests_with_progress(method, url, **kwargs):
74
90
  content length.
75
91
  - If 'progress' is a number then progress bar will display assuming content length = progress.
76
92
  """
77
- progress = kwargs.pop('progress', False)
93
+ progress = kwargs.pop("progress", False)
78
94
  if not progress:
79
95
  return requests.request(method, url, **kwargs)
80
96
  response = requests.request(method, url, stream=True, **kwargs)
81
- total = int(response.headers.get('content-length', 0) if isinstance(progress, bool) else progress) # total size
97
+ total = int(response.headers.get("content-length", 0) if isinstance(progress, bool) else progress) # total size
82
98
  try:
83
- pbar = TQDM(total=total, unit='B', unit_scale=True, unit_divisor=1024)
99
+ pbar = TQDM(total=total, unit="B", unit_scale=True, unit_divisor=1024)
84
100
  for data in response.iter_content(chunk_size=1024):
85
101
  pbar.update(len(data))
86
102
  pbar.close()
@@ -121,25 +137,27 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos
121
137
  if r.status_code < 300: # return codes in the 2xx range are generally considered "good" or "successful"
122
138
  break
123
139
  try:
124
- m = r.json().get('message', 'No JSON message.')
140
+ m = r.json().get("message", "No JSON message.")
125
141
  except AttributeError:
126
- m = 'Unable to read JSON.'
142
+ m = "Unable to read JSON."
127
143
  if i == 0:
128
144
  if r.status_code in retry_codes:
129
- m += f' Retrying {retry}x for {timeout}s.' if retry else ''
145
+ m += f" Retrying {retry}x for {timeout}s." if retry else ""
130
146
  elif r.status_code == 429: # rate limit
131
147
  h = r.headers # response headers
132
- m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \
148
+ m = (
149
+ f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). "
133
150
  f"Please retry after {h['Retry-After']}s."
151
+ )
134
152
  if verbose:
135
- LOGGER.warning(f'{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})')
153
+ LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})")
136
154
  if r.status_code not in retry_codes:
137
155
  return r
138
- time.sleep(2 ** i) # exponential standoff
156
+ time.sleep(2**i) # exponential standoff
139
157
  return r
140
158
 
141
159
  args = method, url
142
- kwargs['progress'] = progress
160
+ kwargs["progress"] = progress
143
161
  if thread:
144
162
  threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
145
163
  else:
@@ -158,7 +176,7 @@ class Events:
158
176
  enabled (bool): A flag to enable or disable Events based on certain conditions.
159
177
  """
160
178
 
161
- url = 'https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw'
179
+ url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw"
162
180
 
163
181
  def __init__(self):
164
182
  """Initializes the Events object with default values for events, rate_limit, and metadata."""
@@ -166,19 +184,21 @@ class Events:
166
184
  self.rate_limit = 60.0 # rate limit (seconds)
167
185
  self.t = 0.0 # rate limit timer (seconds)
168
186
  self.metadata = {
169
- 'cli': Path(sys.argv[0]).name == 'yolo',
170
- 'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
171
- 'python': '.'.join(platform.python_version_tuple()[:2]), # i.e. 3.10
172
- 'version': __version__,
173
- 'env': ENVIRONMENT,
174
- 'session_id': round(random.random() * 1E15),
175
- 'engagement_time_msec': 1000}
176
- self.enabled = \
177
- SETTINGS['sync'] and \
178
- RANK in (-1, 0) and \
179
- not TESTS_RUNNING and \
180
- ONLINE and \
181
- (is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git')
187
+ "cli": Path(sys.argv[0]).name == "yolo",
188
+ "install": "git" if is_git_dir() else "pip" if is_pip_package() else "other",
189
+ "python": ".".join(platform.python_version_tuple()[:2]), # i.e. 3.10
190
+ "version": __version__,
191
+ "env": ENVIRONMENT,
192
+ "session_id": round(random.random() * 1e15),
193
+ "engagement_time_msec": 1000,
194
+ }
195
+ self.enabled = (
196
+ SETTINGS["sync"]
197
+ and RANK in (-1, 0)
198
+ and not TESTS_RUNNING
199
+ and ONLINE
200
+ and (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
201
+ )
182
202
 
183
203
  def __call__(self, cfg):
184
204
  """
@@ -194,11 +214,13 @@ class Events:
194
214
  # Attempt to add to events
195
215
  if len(self.events) < 25: # Events list limited to 25 events (drop any events past this)
196
216
  params = {
197
- **self.metadata, 'task': cfg.task,
198
- 'model': cfg.model if cfg.model in GITHUB_ASSETS_NAMES else 'custom'}
199
- if cfg.mode == 'export':
200
- params['format'] = cfg.format
201
- self.events.append({'name': cfg.mode, 'params': params})
217
+ **self.metadata,
218
+ "task": cfg.task,
219
+ "model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else "custom",
220
+ }
221
+ if cfg.mode == "export":
222
+ params["format"] = cfg.format
223
+ self.events.append({"name": cfg.mode, "params": params})
202
224
 
203
225
  # Check rate limit
204
226
  t = time.time()
@@ -207,10 +229,10 @@ class Events:
207
229
  return
208
230
 
209
231
  # Time is over rate limiter, send now
210
- data = {'client_id': SETTINGS['uuid'], 'events': self.events} # SHA-256 anonymized UUID hash and events list
232
+ data = {"client_id": SETTINGS["uuid"], "events": self.events} # SHA-256 anonymized UUID hash and events list
211
233
 
212
234
  # POST equivalent to requests.post(self.url, json=data)
213
- smart_request('post', self.url, json=data, retry=0, verbose=False)
235
+ smart_request("post", self.url, json=data, retry=0, verbose=False)
214
236
 
215
237
  # Reset events and rate limit timer
216
238
  self.events = []