ultralytics 8.0.194__py3-none-any.whl → 8.0.196__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ultralytics might be problematic.
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +5 -6
- ultralytics/data/augment.py +234 -29
- ultralytics/data/base.py +2 -1
- ultralytics/data/build.py +9 -3
- ultralytics/data/converter.py +5 -2
- ultralytics/data/dataset.py +16 -2
- ultralytics/data/loaders.py +111 -7
- ultralytics/data/utils.py +3 -3
- ultralytics/engine/exporter.py +1 -3
- ultralytics/engine/model.py +16 -9
- ultralytics/engine/predictor.py +10 -6
- ultralytics/engine/results.py +18 -8
- ultralytics/engine/trainer.py +19 -31
- ultralytics/engine/tuner.py +20 -20
- ultralytics/engine/validator.py +3 -4
- ultralytics/hub/__init__.py +2 -2
- ultralytics/hub/auth.py +18 -3
- ultralytics/hub/session.py +1 -0
- ultralytics/hub/utils.py +1 -3
- ultralytics/models/fastsam/model.py +2 -1
- ultralytics/models/fastsam/predict.py +10 -7
- ultralytics/models/fastsam/prompt.py +15 -1
- ultralytics/models/nas/model.py +3 -1
- ultralytics/models/rtdetr/model.py +4 -6
- ultralytics/models/rtdetr/predict.py +2 -1
- ultralytics/models/rtdetr/train.py +2 -1
- ultralytics/models/rtdetr/val.py +1 -0
- ultralytics/models/sam/amg.py +12 -6
- ultralytics/models/sam/model.py +5 -6
- ultralytics/models/sam/modules/decoders.py +5 -1
- ultralytics/models/sam/modules/encoders.py +15 -12
- ultralytics/models/sam/modules/tiny_encoder.py +38 -2
- ultralytics/models/sam/modules/transformer.py +2 -4
- ultralytics/models/sam/predict.py +8 -4
- ultralytics/models/utils/loss.py +35 -8
- ultralytics/models/utils/ops.py +14 -18
- ultralytics/models/yolo/classify/predict.py +1 -0
- ultralytics/models/yolo/classify/train.py +4 -2
- ultralytics/models/yolo/classify/val.py +1 -0
- ultralytics/models/yolo/detect/train.py +4 -3
- ultralytics/models/yolo/model.py +2 -4
- ultralytics/models/yolo/pose/predict.py +1 -0
- ultralytics/models/yolo/segment/predict.py +2 -0
- ultralytics/models/yolo/segment/val.py +1 -1
- ultralytics/nn/autobackend.py +54 -43
- ultralytics/nn/modules/__init__.py +13 -9
- ultralytics/nn/modules/block.py +11 -5
- ultralytics/nn/modules/conv.py +16 -7
- ultralytics/nn/modules/head.py +6 -3
- ultralytics/nn/modules/transformer.py +47 -15
- ultralytics/nn/modules/utils.py +6 -4
- ultralytics/nn/tasks.py +61 -21
- ultralytics/trackers/bot_sort.py +53 -6
- ultralytics/trackers/byte_tracker.py +71 -15
- ultralytics/trackers/track.py +0 -1
- ultralytics/trackers/utils/gmc.py +23 -0
- ultralytics/trackers/utils/kalman_filter.py +6 -6
- ultralytics/utils/__init__.py +32 -19
- ultralytics/utils/autobatch.py +1 -3
- ultralytics/utils/benchmarks.py +14 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/comet.py +11 -3
- ultralytics/utils/callbacks/dvc.py +9 -0
- ultralytics/utils/callbacks/neptune.py +5 -6
- ultralytics/utils/callbacks/wb.py +1 -0
- ultralytics/utils/checks.py +13 -9
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +7 -3
- ultralytics/utils/files.py +3 -3
- ultralytics/utils/instance.py +12 -3
- ultralytics/utils/loss.py +97 -22
- ultralytics/utils/metrics.py +35 -34
- ultralytics/utils/ops.py +10 -9
- ultralytics/utils/patches.py +9 -7
- ultralytics/utils/plotting.py +4 -3
- ultralytics/utils/torch_utils.py +8 -6
- ultralytics/utils/triton.py +87 -0
- {ultralytics-8.0.194.dist-info → ultralytics-8.0.196.dist-info}/METADATA +1 -1
- {ultralytics-8.0.194.dist-info → ultralytics-8.0.196.dist-info}/RECORD +84 -83
- {ultralytics-8.0.194.dist-info → ultralytics-8.0.196.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.194.dist-info → ultralytics-8.0.196.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.194.dist-info → ultralytics-8.0.196.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.194.dist-info → ultralytics-8.0.196.dist-info}/top_level.txt +0 -0
ultralytics/trackers/byte_tracker.py CHANGED
@@ -8,10 +8,43 @@ from .utils.kalman_filter import KalmanFilterXYAH


 class STrack(BaseTrack):
+    """
+    Single object tracking representation that uses Kalman filtering for state estimation.
+
+    This class is responsible for storing all the information regarding individual tracklets and performs state updates
+    and predictions based on Kalman filter.
+
+    Attributes:
+        shared_kalman (KalmanFilterXYAH): Shared Kalman filter that is used across all STrack instances for prediction.
+        _tlwh (np.ndarray): Private attribute to store top-left corner coordinates and width and height of bounding box.
+        kalman_filter (KalmanFilterXYAH): Instance of Kalman filter used for this particular object track.
+        mean (np.ndarray): Mean state estimate vector.
+        covariance (np.ndarray): Covariance of state estimate.
+        is_activated (bool): Boolean flag indicating if the track has been activated.
+        score (float): Confidence score of the track.
+        tracklet_len (int): Length of the tracklet.
+        cls (any): Class label for the object.
+        idx (int): Index or identifier for the object.
+        frame_id (int): Current frame ID.
+        start_frame (int): Frame where the object was first detected.
+
+    Methods:
+        predict(): Predict the next state of the object using Kalman filter.
+        multi_predict(stracks): Predict the next states for multiple tracks.
+        multi_gmc(stracks, H): Update multiple track states using a homography matrix.
+        activate(kalman_filter, frame_id): Activate a new tracklet.
+        re_activate(new_track, frame_id, new_id): Reactivate a previously lost tracklet.
+        update(new_track, frame_id): Update the state of a matched track.
+        convert_coords(tlwh): Convert bounding box to x-y-angle-height format.
+        tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
+        tlbr_to_tlwh(tlbr): Convert tlbr bounding box to tlwh format.
+        tlwh_to_tlbr(tlwh): Convert tlwh bounding box to tlbr format.
+    """
+
     shared_kalman = KalmanFilterXYAH()

     def __init__(self, tlwh, score, cls):
-        """
+        """Initialize new STrack instance."""
         self._tlwh = np.asarray(self.tlbr_to_tlwh(tlwh[:-1]), dtype=np.float32)
         self.kalman_filter = None
         self.mean, self.covariance = None, None
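As a side note on the conversion helpers listed in the new STrack docstring, here is a minimal self-contained sketch of the tlbr ↔ tlwh round trip (plain NumPy, illustrative only; not the packaged code):

```python
import numpy as np

def tlbr_to_tlwh(tlbr):
    """(min x, min y, max x, max y) -> (top-left x, top-left y, width, height)."""
    ret = np.asarray(tlbr, dtype=np.float32).copy()
    ret[2:] -= ret[:2]  # subtract the top-left corner from the bottom-right corner
    return ret

def tlwh_to_tlbr(tlwh):
    """(top-left x, top-left y, width, height) -> (min x, min y, max x, max y)."""
    ret = np.asarray(tlwh, dtype=np.float32).copy()
    ret[2:] += ret[:2]  # add the top-left corner back onto width and height
    return ret

print(tlwh_to_tlbr(tlbr_to_tlwh([10, 20, 40, 80])))  # [10. 20. 40. 80.]
```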
@@ -92,10 +125,11 @@ class STrack(BaseTrack):

     def update(self, new_track, frame_id):
         """
-        Update a matched track
-
-        :
-
+        Update the state of a matched track.
+
+        Args:
+            new_track (STrack): The new track containing updated information.
+            frame_id (int): The ID of the current frame.
         """
         self.frame_id = frame_id
         self.tracklet_len += 1
@@ -116,9 +150,7 @@ class STrack(BaseTrack):

     @property
     def tlwh(self):
-        """Get current position in bounding box format
-        width, height)`.
-        """
+        """Get current position in bounding box format (top left x, top left y, width, height)."""
         if self.mean is None:
             return self._tlwh.copy()
         ret = self.mean[:4].copy()
@@ -128,17 +160,15 @@ class STrack(BaseTrack):

     @property
     def tlbr(self):
-        """Convert bounding box to format
-        `(top left, bottom right)`.
-        """
+        """Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right)."""
         ret = self.tlwh.copy()
         ret[2:] += ret[:2]
         return ret

     @staticmethod
     def tlwh_to_xyah(tlwh):
-        """Convert bounding box to format
-        height
+        """Convert bounding box to format (center x, center y, aspect ratio, height), where the aspect ratio is width /
+        height.
         """
         ret = np.asarray(tlwh).copy()
         ret[:2] += ret[2:] / 2
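The xyah format that feeds the Kalman filter follows directly from the tlwh_to_xyah docstring above; a quick sketch of the math, mirroring the two lines of the method shown plus the aspect-ratio division:

```python
import numpy as np

def tlwh_to_xyah(tlwh):
    """(top-left x, top-left y, w, h) -> (center x, center y, w/h aspect ratio, h)."""
    ret = np.asarray(tlwh, dtype=np.float32).copy()
    ret[:2] += ret[2:] / 2  # shift the top-left corner to the box center
    ret[2] /= ret[3]        # width -> width/height aspect ratio
    return ret

print(tlwh_to_xyah([10, 20, 30, 60]))  # [25.  50.   0.5 60. ]
```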
@@ -165,6 +195,33 @@ class STrack(BaseTrack):


 class BYTETracker:
+    """
+    BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
+
+    The class is responsible for initializing, updating, and managing the tracks for detected objects in a video
+    sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for
+    predicting the new object locations, and performs data association.
+
+    Attributes:
+        tracked_stracks (list[STrack]): List of successfully activated tracks.
+        lost_stracks (list[STrack]): List of lost tracks.
+        removed_stracks (list[STrack]): List of removed tracks.
+        frame_id (int): The current frame ID.
+        args (namespace): Command-line arguments.
+        max_time_lost (int): The maximum frames for a track to be considered as 'lost'.
+        kalman_filter (object): Kalman Filter object.
+
+    Methods:
+        update(results, img=None): Updates object tracker with new detections.
+        get_kalmanfilter(): Returns a Kalman filter object for tracking bounding boxes.
+        init_track(dets, scores, cls, img=None): Initialize object tracking with detections.
+        get_dists(tracks, detections): Calculates the distance between tracks and detections.
+        multi_predict(tracks): Predicts the location of tracks.
+        reset_id(): Resets the ID counter of STrack.
+        joint_stracks(tlista, tlistb): Combines two lists of stracks.
+        sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list.
+        remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IOU.
+    """

     def __init__(self, args, frame_rate=30):
         """Initialize a YOLOv8 object to track objects with given arguments and frame rate."""
@@ -234,8 +291,7 @@ class BYTETracker:
         else:
             track.re_activate(det, self.frame_id, new_id=False)
             refind_stracks.append(track)
-        # Step 3: Second association, with low score detection boxes
-        # association the untrack to the low score detections
+        # Step 3: Second association, with low score detection boxes association the untrack to the low score detections
         detections_second = self.init_track(dets_second, scores_second, cls_second, img)
         r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
         # TODO
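The "second association" comment above is the heart of ByteTrack: detections are split by confidence, high-score boxes are matched first, and unmatched tracks then get a second pass against low-score boxes. A hedged sketch of just the splitting step (the threshold names follow the tracker-config convention and are illustrative):

```python
import numpy as np

def split_detections(scores, track_high_thresh=0.5, track_low_thresh=0.1):
    """Partition detection indices into first- and second-association pools."""
    scores = np.asarray(scores)
    first = np.where(scores > track_high_thresh)[0]  # confident boxes, matched first
    second = np.where((scores > track_low_thresh) & (scores <= track_high_thresh))[0]  # Step 3 pool
    return first, second

first, second = split_detections([0.9, 0.3, 0.05, 0.7])
print(first, second)  # [0 3] [1]  (0.05 falls below the low threshold and is discarded)
```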
ultralytics/trackers/track.py CHANGED
@@ -60,7 +60,6 @@ def register_tracker(model, persist):
     Args:
         model (object): The model object to register tracking callbacks for.
         persist (bool): Whether to persist the trackers if they already exist.
-
     """
     model.add_callback('on_predict_start', partial(on_predict_start, persist=persist))
     model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end)
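register_tracker wires the tracker in through the model's callback system, as the two add_callback calls show. The same public hook can carry user callbacks; a minimal sketch (weights and image path illustrative):

```python
from ultralytics import YOLO

def on_predict_start(predictor):
    """Runs once when prediction starts; receives the predictor engine object."""
    print(f'Starting prediction with {type(predictor).__name__}')

model = YOLO('yolov8n.pt')                           # any detection weights
model.add_callback('on_predict_start', on_predict_start)
model.predict('bus.jpg')                             # triggers the callback before inference
```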
ultralytics/trackers/utils/gmc.py CHANGED
@@ -9,6 +9,29 @@ from ultralytics.utils import LOGGER


 class GMC:
+    """
+    Generalized Motion Compensation (GMC) class for tracking and object detection in video frames.
+
+    This class provides methods for tracking and detecting objects based on several tracking algorithms including ORB,
+    SIFT, ECC, and Sparse Optical Flow. It also supports downscaling of frames for computational efficiency.
+
+    Attributes:
+        method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
+        downscale (int): Factor by which to downscale the frames for processing.
+        prevFrame (np.array): Stores the previous frame for tracking.
+        prevKeyPoints (list): Stores the keypoints from the previous frame.
+        prevDescriptors (np.array): Stores the descriptors from the previous frame.
+        initializedFirstFrame (bool): Flag to indicate if the first frame has been processed.
+
+    Methods:
+        __init__(self, method='sparseOptFlow', downscale=2): Initializes a GMC object with the specified method
+            and downscale factor.
+        apply(self, raw_frame, detections=None): Applies the chosen method to a raw frame and optionally uses
+            provided detections.
+        applyEcc(self, raw_frame, detections=None): Applies the ECC algorithm to a raw frame.
+        applyFeatures(self, raw_frame, detections=None): Applies feature-based methods like ORB or SIFT to a raw frame.
+        applySparseOptFlow(self, raw_frame, detections=None): Applies the Sparse Optical Flow method to a raw frame.
+    """

     def __init__(self, method='sparseOptFlow', downscale=2):
         """Initialize a video tracker with specified parameters."""
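For a feel of what an ECC-style method like GMC.applyEcc computes, here is a hedged OpenCV sketch estimating a Euclidean warp between two grayscale frames with cv2.findTransformECC (parameter values illustrative; the class above additionally downscales frames first):

```python
import cv2
import numpy as np

def estimate_ecc_warp(prev_gray, curr_gray, iters=100, eps=1e-5):
    """Estimate a 2x3 Euclidean transform aligning prev_gray to curr_gray."""
    warp = np.eye(2, 3, dtype=np.float32)  # identity initial guess
    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, iters, eps)
    # inputMask=None, gaussFiltSize=1
    _, warp = cv2.findTransformECC(prev_gray, curr_gray, warp, cv2.MOTION_EUCLIDEAN, criteria, None, 1)
    return warp  # apply to boxes/points to compensate camera motion
```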
ultralytics/trackers/utils/kalman_filter.py CHANGED
@@ -8,8 +8,8 @@ class KalmanFilterXYAH:
     """
     For bytetrack. A simple Kalman filter for tracking bounding boxes in image space.

-    The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y),
-    aspect ratio a, height h, and their respective velocities.
+    The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect
+    ratio a, height h, and their respective velocities.

     Object motion follows a constant velocity model. The bounding box location (x, y, a, h) is taken as direct
     observation of the state space (linear observation model).
@@ -182,8 +182,8 @@ class KalmanFilterXYAH:
     def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'):
         """
         Compute gating distance between state distribution and measurements. A suitable distance threshold can be
-        obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of
-        freedom, otherwise 2.
+        obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of freedom,
+        otherwise 2.

         Parameters
         ----------
@@ -223,8 +223,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
     """
     For BoT-SORT. A simple Kalman filter for tracking bounding boxes in image space.

-    The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y),
-    width w, height h, and their respective velocities.
+    The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y), width
+    w, height h, and their respective velocities.

     Object motion follows a constant velocity model. The bounding box location (x, y, w, h) is taken as direct
     observation of the state space (linear observation model).
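The constant-velocity model both docstrings describe reduces to simple block matrices; a hedged NumPy sketch of the 8-D state transition and the 4-D observation model (dt folded into the velocity coupling, values illustrative):

```python
import numpy as np

ndim, dt = 4, 1.0

# State transition: each position component gains velocity * dt per step.
F = np.eye(2 * ndim)
for i in range(ndim):
    F[i, ndim + i] = dt

# Observation model: we measure (x, y, a, h) directly, never the velocities.
H = np.eye(ndim, 2 * ndim)

x = np.array([100, 50, 0.5, 80, 2, -1, 0, 0], dtype=float)  # (x, y, a, h, vx, vy, va, vh)
print(F @ x)        # predicted state:       [102.  49.   0.5 80.   2.  -1.   0.   0. ]
print(H @ (F @ x))  # predicted measurement: [102.  49.   0.5 80. ]
```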
ultralytics/utils/__init__.py CHANGED
@@ -117,6 +117,7 @@ class TQDM(tqdm_original):
     """

     def __init__(self, *args, **kwargs):
+        """Initialize custom Ultralytics tqdm class with different default arguments."""
         # Set new default values (these can still be overridden when calling TQDM)
         kwargs['disable'] = not VERBOSE or kwargs.get('disable', False)  # logical 'and' with default value if passed
         kwargs.setdefault('bar_format', TQDM_BAR_FORMAT)  # override default value if passed
@@ -124,8 +125,7 @@ class TQDM(tqdm_original):


 class SimpleClass:
-    """
-    Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
+    """Ultralytics SimpleClass is a base class providing helpful string representation, error reporting, and attribute
     access methods for easier debugging and usage.
     """
@@ -154,8 +154,7 @@ class SimpleClass:


 class IterableSimpleNamespace(SimpleNamespace):
-    """
-    Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
+    """Ultralytics IterableSimpleNamespace is an extension class of SimpleNamespace that adds iterable functionality and
     enables usage with dict() and for loops.
     """
@@ -256,8 +255,8 @@ class EmojiFilter(logging.Filter):
     """
     A custom logging filter class for removing emojis in log messages.

-    This filter is particularly useful for ensuring compatibility with Windows terminals
-    that may not support the display of emojis in log messages.
+    This filter is particularly useful for ensuring compatibility with Windows terminals that may not support the
+    display of emojis in log messages.
     """

     def filter(self, record):
@@ -275,9 +274,9 @@ if WINDOWS:  # emoji-safe logging

 class ThreadingLocked:
     """
-    A decorator class for ensuring thread-safe execution of a function or method.
-
-
+    A decorator class for ensuring thread-safe execution of a function or method. This class can be used as a decorator
+    to make sure that if the decorated function is called from multiple threads, only one thread at a time will be able
+    to execute the function.

     Attributes:
         lock (threading.Lock): A lock object used to manage access to the decorated function.
@@ -294,13 +293,16 @@ class ThreadingLocked:
     """

     def __init__(self):
+        """Initializes the decorator class for thread-safe execution of a function or method."""
         self.lock = threading.Lock()

     def __call__(self, f):
+        """Run thread-safe execution of function or method."""
         from functools import wraps

         @wraps(f)
         def decorated(*args, **kwargs):
+            """Applies thread-safety to the decorated function or method."""
             with self.lock:
                 return f(*args, **kwargs)
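The ThreadingLocked hunk above is self-explanatory; for readers who want the pattern without the class, an equivalent function-decorator sketch:

```python
import threading
from functools import wraps

def threading_locked(f):
    """Allow only one thread at a time into the decorated function."""
    lock = threading.Lock()

    @wraps(f)
    def decorated(*args, **kwargs):
        with lock:  # serialize concurrent callers
            return f(*args, **kwargs)
    return decorated

@threading_locked
def write_shared_state():
    ...
```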
@@ -424,8 +426,7 @@ def is_kaggle():

 def is_jupyter():
     """
-    Check if the current script is running inside a Jupyter Notebook.
-    Verified on Colab, Jupyterlab, Kaggle, Paperspace.
+    Check if the current script is running inside a Jupyter Notebook. Verified on Colab, Jupyterlab, Kaggle, Paperspace.

     Returns:
         (bool): True if running inside a Jupyter Notebook, False otherwise.
@@ -529,8 +530,8 @@ def is_github_actions_ci() -> bool:

 def is_git_dir():
     """
-    Determines whether the current file is part of a git repository.
-    If the current file is not part of a git repository, returns None.
+    Determines whether the current file is part of a git repository. If the current file is not part of a git
+    repository, returns None.

     Returns:
         (bool): True if current file is part of a git repository.
@@ -540,8 +541,8 @@ def is_git_dir():

 def get_git_dir():
     """
-    Determines whether the current file is part of a git repository and if so, returns the repository root directory.
-    If the current file is not part of a git repository, returns None.
+    Determines whether the current file is part of a git repository and if so, returns the repository root directory. If
+    the current file is not part of a git repository, returns None.

     Returns:
         (Path | None): Git root directory if found or None if not found.
@@ -578,7 +579,8 @@ def get_git_branch():


 def get_default_args(func):
-    """
+    """
+    Returns a dictionary of default arguments for a function.

     Args:
         func (callable): The function to inspect.
@@ -705,12 +707,16 @@ def remove_colorstr(input_string):
     >>> remove_colorstr(colorstr('blue', 'bold', 'hello world'))
     >>> 'hello world'
     """
-    ansi_escape = re.compile(r'\x1B(?:[@-Z
+    ansi_escape = re.compile(r'\x1B(?:[@-Z\\\-_]|\[[0-9]*[ -/]*[@-~])')
     return ansi_escape.sub('', input_string)


 class TryExcept(contextlib.ContextDecorator):
-    """
+    """
+    YOLOv8 TryExcept class.
+
+    Use as @TryExcept() decorator or 'with TryExcept():' context manager.
+    """

     def __init__(self, msg='', verbose=True):
         """Initialize TryExcept class with optional message and verbosity settings."""
@@ -729,7 +735,11 @@ class TryExcept(contextlib.ContextDecorator):


 def threaded(func):
-    """
+    """
+    Multi-threads a target function and returns thread.
+
+    Use as @threaded decorator.
+    """

     def wrapper(*args, **kwargs):
         """Multi-threads a given function and returns the thread."""
@@ -824,6 +834,9 @@ class SettingsManager(dict):
     """

     def __init__(self, file=SETTINGS_YAML, version='0.0.4'):
+        """Initialize the SettingsManager with default settings, load and validate current settings from the YAML
+        file.
+        """
         import copy
         import hashlib
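The repaired regex in remove_colorstr above matches standard ANSI escape sequences; a quick check of its behavior, using the pattern exactly as it appears in the diff:

```python
import re

ansi_escape = re.compile(r'\x1B(?:[@-Z\\\-_]|\[[0-9]*[ -/]*[@-~])')
colored = '\x1b[34m\x1b[1mhello world\x1b[0m'
print(ansi_escape.sub('', colored))  # hello world
```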
ultralytics/utils/autobatch.py CHANGED
@@ -1,7 +1,5 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-"""
-Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch.
-"""
+"""Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch."""

 from copy import deepcopy
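The autobatch module's goal, per its docstring, is to pick a batch size that uses only a fraction of CUDA memory. A hedged sketch of the memory accounting involved (the real heuristics in the module also profile the model; this only shows the arithmetic):

```python
import torch

def cuda_memory_fraction_free(device=0):
    """Return the fraction of total CUDA memory not yet reserved by PyTorch."""
    total = torch.cuda.get_device_properties(device).total_memory
    reserved = torch.cuda.memory_reserved(device)
    return (total - reserved) / total

if torch.cuda.is_available():
    print(f'{cuda_memory_fraction_free():.2%} of GPU memory is free for batching')
```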
ultralytics/utils/benchmarks.py CHANGED
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 """
-Benchmark a YOLO model formats for speed and accuracy
+Benchmark a YOLO model formats for speed and accuracy.

 Usage:
     from ultralytics.utils.benchmarks import ProfileModels, benchmark
@@ -194,6 +194,7 @@ class ProfileModels:
         self.device = device or torch.device(0 if torch.cuda.is_available() else 'cpu')

     def profile(self):
+        """Logs the benchmarking results of a model, checks metrics against floor and returns the results."""
         files = self.get_files()

         if not files:
@@ -235,6 +236,7 @@ class ProfileModels:
         return output

     def get_files(self):
+        """Returns a list of paths for all relevant model files given by the user."""
         files = []
         for path in self.paths:
             path = Path(path)
@@ -250,10 +252,14 @@ class ProfileModels:
         return [Path(file) for file in sorted(files)]

     def get_onnx_model_info(self, onnx_file: str):
+        """Retrieves the information including number of layers, parameters, gradients and FLOPs for an ONNX model
+        file.
+        """
         # return (num_layers, num_params, num_gradients, num_flops)
         return 0.0, 0.0, 0.0, 0.0

     def iterative_sigma_clipping(self, data, sigma=2, max_iters=3):
+        """Applies an iterative sigma clipping algorithm to the given data times number of iterations."""
         data = np.array(data)
         for _ in range(max_iters):
             mean, std = np.mean(data), np.std(data)
@@ -264,6 +270,7 @@ class ProfileModels:
         return data

     def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3):
+        """Profiles the TensorRT model, measuring average run time and standard deviation among runs."""
         if not self.trt or not Path(engine_file).is_file():
             return 0.0, 0.0
@@ -292,6 +299,9 @@ class ProfileModels:
         return np.mean(run_times), np.std(run_times)

     def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
+        """Profiles an ONNX model by executing it multiple times and returns the mean and standard deviation of run
+        times.
+        """
         check_requirements('onnxruntime')
         import onnxruntime as ort
@@ -344,10 +354,12 @@ class ProfileModels:
         return np.mean(run_times), np.std(run_times)

     def generate_table_row(self, model_name, t_onnx, t_engine, model_info):
+        """Generates a formatted string for a table row that includes model performance and metric details."""
         layers, params, gradients, flops = model_info
         return f'| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |'

     def generate_results_dict(self, model_name, t_onnx, t_engine, model_info):
+        """Generates a dictionary of model details including name, parameters, GFLOPS and speed metrics."""
         layers, params, gradients, flops = model_info
         return {
             'model/name': model_name,
@@ -357,6 +369,7 @@ class ProfileModels:
             'model/speed_TensorRT(ms)': round(t_engine[0], 3)}

     def print_table(self, table_rows):
+        """Formats and prints a comparison table for different models with given statistics and performance data."""
         gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'GPU'
         header = f'| Model | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>{gpu} TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |'
         separator = '|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|'
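iterative_sigma_clipping, documented above, repeatedly drops benchmark timings outside mean ± sigma·std until the sample stabilizes. A standalone sketch consistent with the loop body shown in the diff:

```python
import numpy as np

def iterative_sigma_clipping(data, sigma=2, max_iters=3):
    """Iteratively discard outliers beyond mean +/- sigma * std."""
    data = np.array(data, dtype=float)
    for _ in range(max_iters):
        mean, std = np.mean(data), np.std(data)
        clipped = data[(data > mean - sigma * std) & (data < mean + sigma * std)]
        if len(clipped) == len(data):
            break  # nothing removed; the sample is stable
        data = clipped
    return data

times = [10.1, 9.9, 10.0, 10.2, 9.8, 10.0, 10.1, 9.9, 10.0, 50.0]
print(iterative_sigma_clipping(times))  # the 50.0 run is clipped as an outlier
```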
ultralytics/utils/callbacks/comet.py CHANGED
@@ -26,31 +26,38 @@ except (ImportError, AssertionError):


 def _get_comet_mode():
+    """Returns the mode of comet set in the environment variables, defaults to 'online' if not set."""
     return os.getenv('COMET_MODE', 'online')


 def _get_comet_model_name():
+    """Returns the model name for Comet from the environment variable 'COMET_MODEL_NAME' or defaults to 'YOLOv8'."""
     return os.getenv('COMET_MODEL_NAME', 'YOLOv8')


 def _get_eval_batch_logging_interval():
+    """Get the evaluation batch logging interval from environment variable or use default value 1."""
     return int(os.getenv('COMET_EVAL_BATCH_LOGGING_INTERVAL', 1))


 def _get_max_image_predictions_to_log():
+    """Get the maximum number of image predictions to log from the environment variables."""
     return int(os.getenv('COMET_MAX_IMAGE_PREDICTIONS', 100))


 def _scale_confidence_score(score):
+    """Scales the given confidence score by a factor specified in an environment variable."""
     scale = float(os.getenv('COMET_MAX_CONFIDENCE_SCORE', 100.0))
     return score * scale


 def _should_log_confusion_matrix():
+    """Determines if the confusion matrix should be logged based on the environment variable settings."""
     return os.getenv('COMET_EVAL_LOG_CONFUSION_MATRIX', 'false').lower() == 'true'


 def _should_log_image_predictions():
+    """Determines whether to log image predictions based on a specified environment variable."""
     return os.getenv('COMET_EVAL_LOG_IMAGE_PREDICTIONS', 'true').lower() == 'true'
@@ -104,9 +111,10 @@ def _fetch_trainer_metadata(trainer):


 def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad):
-    """
-    are normalized based on this resized shape.
-
+    """
+    YOLOv8 resizes images during training and the label values are normalized based on this resized shape.
+
+    This function rescales the bounding box labels to the original image shape.
     """

     resized_image_height, resized_image_width = resized_image_shape
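All of the Comet helpers above read plain environment variables, so the integration is configured before training starts; for example (values illustrative, variable names taken from the diff):

```python
import os

os.environ['COMET_MODE'] = 'offline'                    # or 'online' (default)
os.environ['COMET_MODEL_NAME'] = 'my-yolov8-run'        # defaults to 'YOLOv8'
os.environ['COMET_EVAL_BATCH_LOGGING_INTERVAL'] = '2'   # log every 2nd eval batch
os.environ['COMET_MAX_IMAGE_PREDICTIONS'] = '50'        # cap logged prediction images
os.environ['COMET_EVAL_LOG_CONFUSION_MATRIX'] = 'true'  # opt in to confusion matrices
```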
ultralytics/utils/callbacks/dvc.py CHANGED
@@ -25,6 +25,7 @@ except (ImportError, AssertionError, TypeError):


 def _log_images(path, prefix=''):
+    """Logs images at specified path with an optional prefix using DVCLive."""
     if live:
         name = path.name
@@ -38,6 +39,7 @@ def _log_images(path, prefix=''):


 def _log_plots(plots, prefix=''):
+    """Logs plot images for training progress if they have not been previously processed."""
     for name, params in plots.items():
         timestamp = params['timestamp']
         if _processed_plots.get(name) != timestamp:
@@ -46,6 +48,7 @@ def _log_plots(plots, prefix=''):


 def _log_confusion_matrix(validator):
+    """Logs the confusion matrix for the given validator using DVCLive."""
     targets = []
     preds = []
     matrix = validator.confusion_matrix.matrix
@@ -62,6 +65,7 @@ def _log_confusion_matrix(validator):


 def on_pretrain_routine_start(trainer):
+    """Initializes DVCLive logger for training metadata during pre-training routine."""
     try:
         global live
         live = dvclive.Live(save_dvc_exp=True, cache_images=True)
@@ -71,20 +75,24 @@ def on_pretrain_routine_start(trainer):


 def on_pretrain_routine_end(trainer):
+    """Logs plots related to the training process at the end of the pretraining routine."""
     _log_plots(trainer.plots, 'train')


 def on_train_start(trainer):
+    """Logs the training parameters if DVCLive logging is active."""
     if live:
         live.log_params(trainer.args)


 def on_train_epoch_start(trainer):
+    """Sets the global variable _training_epoch value to True at the start of training each epoch."""
     global _training_epoch
     _training_epoch = True


 def on_fit_epoch_end(trainer):
+    """Logs training metrics and model info, and advances to next step on the end of each fit epoch."""
     global _training_epoch
     if live and _training_epoch:
         all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
@@ -104,6 +112,7 @@ def on_fit_epoch_end(trainer):


 def on_train_end(trainer):
+    """Logs the best metrics, plots, and confusion matrix at the end of training if DVCLive is active."""
     if live:
         # At the end log the best metrics. It runs validator on the best model internally.
         all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics, **trainer.lr}
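The _processed_plots timestamp check that appears in both the DVCLive and W&B callbacks is a small idempotency guard; a standalone sketch of the pattern (log_fn is a placeholder for the backend-specific call):

```python
_processed_plots = {}

def log_plots_once(plots, log_fn):
    """Log each named plot only when its timestamp changes."""
    for name, params in plots.items():
        timestamp = params['timestamp']
        if _processed_plots.get(name) != timestamp:
            log_fn(name)                        # e.g. the backend's image-logging call
            _processed_plots[name] = timestamp  # remember what was already logged
```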
ultralytics/utils/callbacks/neptune.py CHANGED
@@ -31,14 +31,13 @@ def _log_images(imgs_dict, group=''):


 def _log_plot(title, plot_path):
-    """Log plots to the NeptuneAI experiment logger."""
     """
-
+    Log plots to the NeptuneAI experiment logger.

-
-    title (str) Title of the plot
-    plot_path (PosixPath
-
+    Args:
+        title (str): Title of the plot.
+        plot_path (PosixPath | str): Path to the saved image file.
+    """
     import matplotlib.image as mpimg
     import matplotlib.pyplot as plt
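A hedged sketch of what a _log_plot like the one above typically does with those arguments: read the saved image, wrap it in a bare matplotlib figure, and upload it under the given title (the run object and field naming are assumptions, not taken from the diff):

```python
import matplotlib.image as mpimg
import matplotlib.pyplot as plt

def log_plot(run, title, plot_path):
    """Upload a saved plot image to a NeptuneAI run."""
    img = mpimg.imread(plot_path)
    fig = plt.figure()
    ax = fig.add_axes([0, 0, 1, 1], frameon=False, xticks=[], yticks=[])  # borderless canvas
    ax.imshow(img)
    run[f'Plots/{title}'].upload(fig)  # assumed Neptune field layout
```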
ultralytics/utils/callbacks/wb.py CHANGED
@@ -17,6 +17,7 @@ except (ImportError, AssertionError):


 def _log_plots(plots, step):
+    """Logs plots from the input dictionary if they haven't been logged already at the specified step."""
     for name, params in plots.items():
         timestamp = params['timestamp']
         if _processed_plots.get(name) != timestamp:
ultralytics/utils/checks.py CHANGED
@@ -64,8 +64,8 @@ def parse_requirements(file_path=ROOT.parent / 'requirements.txt', package=''):

 def parse_version(version='0.0.0') -> tuple:
     """
-    Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version.
-    This function replaces deprecated 'pkg_resources.parse_version(v)'.
+    Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. This
+    function replaces deprecated 'pkg_resources.parse_version(v)'.

     Args:
         version (str): Version string, i.e. '2.0.1+cpu'
@@ -372,8 +372,10 @@ def check_torchvision():
     Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.

     This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
-    to the provided compatibility table based on
-
+    to the provided compatibility table based on:
+    https://github.com/pytorch/vision#installation.
+
+    The compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
     Torchvision versions.
     """
@@ -527,9 +529,9 @@ def collect_system_info():

 def check_amp(model):
     """
-    This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model.
-    If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
-    results, so AMP will be disabled during training.
+    This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. If the checks
+    fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will
+    be disabled during training.

     Args:
         model (nn.Module): A YOLOv8 model instance.
@@ -606,7 +608,8 @@ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):


 def cuda_device_count() -> int:
-    """
+    """
+    Get the number of NVIDIA GPUs available in the environment.

     Returns:
         (int): The number of NVIDIA GPUs available.
@@ -626,7 +629,8 @@ def cuda_device_count() -> int:


 def cuda_is_available() -> bool:
-    """
+    """
+    Check if CUDA is available in the environment.

     Returns:
         (bool): True if one or more NVIDIA GPUs are available, False otherwise.