ultralytics 8.2.76__py3-none-any.whl → 8.2.77__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic. Click here for more details.
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +1 -1
- ultralytics/engine/results.py +19 -5
- ultralytics/engine/trainer.py +3 -1
- ultralytics/models/yolo/detect/train.py +1 -1
- ultralytics/trackers/basetrack.py +31 -12
- ultralytics/trackers/bot_sort.py +58 -24
- ultralytics/trackers/byte_tracker.py +75 -42
- ultralytics/trackers/track.py +17 -2
- ultralytics/trackers/utils/gmc.py +52 -38
- ultralytics/trackers/utils/kalman_filter.py +162 -31
- ultralytics/trackers/utils/matching.py +38 -14
- ultralytics/utils/__init__.py +1 -1
- ultralytics/utils/files.py +69 -34
- ultralytics/utils/plotting.py +11 -3
- {ultralytics-8.2.76.dist-info → ultralytics-8.2.77.dist-info}/METADATA +2 -2
- {ultralytics-8.2.76.dist-info → ultralytics-8.2.77.dist-info}/RECORD +21 -21
- {ultralytics-8.2.76.dist-info → ultralytics-8.2.77.dist-info}/WHEEL +1 -1
- {ultralytics-8.2.76.dist-info → ultralytics-8.2.77.dist-info}/LICENSE +0 -0
- {ultralytics-8.2.76.dist-info → ultralytics-8.2.77.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.2.76.dist-info → ultralytics-8.2.77.dist-info}/top_level.txt +0 -0
|
@@ -25,7 +25,7 @@ class STrack(BaseTrack):
|
|
|
25
25
|
is_activated (bool): Boolean flag indicating if the track has been activated.
|
|
26
26
|
score (float): Confidence score of the track.
|
|
27
27
|
tracklet_len (int): Length of the tracklet.
|
|
28
|
-
cls (
|
|
28
|
+
cls (Any): Class label for the object.
|
|
29
29
|
idx (int): Index or identifier for the object.
|
|
30
30
|
frame_id (int): Current frame ID.
|
|
31
31
|
start_frame (int): Frame where the object was first detected.
|
|
@@ -39,12 +39,31 @@ class STrack(BaseTrack):
|
|
|
39
39
|
update(new_track, frame_id): Update the state of a matched track.
|
|
40
40
|
convert_coords(tlwh): Convert bounding box to x-y-aspect-height format.
|
|
41
41
|
tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
|
|
42
|
+
|
|
43
|
+
Examples:
|
|
44
|
+
Initialize and activate a new track
|
|
45
|
+
>>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls='person')
|
|
46
|
+
>>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1)
|
|
42
47
|
"""
|
|
43
48
|
|
|
44
49
|
shared_kalman = KalmanFilterXYAH()
|
|
45
50
|
|
|
46
51
|
def __init__(self, xywh, score, cls):
|
|
47
|
-
"""
|
|
52
|
+
"""
|
|
53
|
+
Initialize a new STrack instance.
|
|
54
|
+
|
|
55
|
+
Args:
|
|
56
|
+
xywh (List[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where
|
|
57
|
+
(x, y) is the center, (w, h) are width and height, [a] is optional aspect ratio, and idx is the id.
|
|
58
|
+
score (float): Confidence score of the detection.
|
|
59
|
+
cls (Any): Class label for the detected object.
|
|
60
|
+
|
|
61
|
+
Examples:
|
|
62
|
+
>>> xywh = [100.0, 150.0, 50.0, 75.0, 1]
|
|
63
|
+
>>> score = 0.9
|
|
64
|
+
>>> cls = 'person'
|
|
65
|
+
>>> track = STrack(xywh, score, cls)
|
|
66
|
+
"""
|
|
48
67
|
super().__init__()
|
|
49
68
|
# xywh+idx or xywha+idx
|
|
50
69
|
assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}"
|
|
@@ -60,7 +79,7 @@ class STrack(BaseTrack):
|
|
|
60
79
|
self.angle = xywh[4] if len(xywh) == 6 else None
|
|
61
80
|
|
|
62
81
|
def predict(self):
|
|
63
|
-
"""Predicts mean and covariance using Kalman filter."""
|
|
82
|
+
"""Predicts the next state (mean and covariance) of the object using the Kalman filter."""
|
|
64
83
|
mean_state = self.mean.copy()
|
|
65
84
|
if self.state != TrackState.Tracked:
|
|
66
85
|
mean_state[7] = 0
|
|
@@ -68,7 +87,7 @@ class STrack(BaseTrack):
|
|
|
68
87
|
|
|
69
88
|
@staticmethod
|
|
70
89
|
def multi_predict(stracks):
|
|
71
|
-
"""Perform multi-object predictive tracking using Kalman filter for
|
|
90
|
+
"""Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances."""
|
|
72
91
|
if len(stracks) <= 0:
|
|
73
92
|
return
|
|
74
93
|
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
|
@@ -83,7 +102,7 @@ class STrack(BaseTrack):
|
|
|
83
102
|
|
|
84
103
|
@staticmethod
|
|
85
104
|
def multi_gmc(stracks, H=np.eye(2, 3)):
|
|
86
|
-
"""Update state tracks positions and covariances using a homography matrix."""
|
|
105
|
+
"""Update state tracks positions and covariances using a homography matrix for multiple tracks."""
|
|
87
106
|
if len(stracks) > 0:
|
|
88
107
|
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
|
89
108
|
multi_covariance = np.asarray([st.covariance for st in stracks])
|
|
@@ -101,7 +120,7 @@ class STrack(BaseTrack):
|
|
|
101
120
|
stracks[i].covariance = cov
|
|
102
121
|
|
|
103
122
|
def activate(self, kalman_filter, frame_id):
|
|
104
|
-
"""
|
|
123
|
+
"""Activate a new tracklet using the provided Kalman filter and initialize its state and covariance."""
|
|
105
124
|
self.kalman_filter = kalman_filter
|
|
106
125
|
self.track_id = self.next_id()
|
|
107
126
|
self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh))
|
|
@@ -114,7 +133,7 @@ class STrack(BaseTrack):
|
|
|
114
133
|
self.start_frame = frame_id
|
|
115
134
|
|
|
116
135
|
def re_activate(self, new_track, frame_id, new_id=False):
|
|
117
|
-
"""Reactivates a previously lost track
|
|
136
|
+
"""Reactivates a previously lost track using new detection data and updates its state and attributes."""
|
|
118
137
|
self.mean, self.covariance = self.kalman_filter.update(
|
|
119
138
|
self.mean, self.covariance, self.convert_coords(new_track.tlwh)
|
|
120
139
|
)
|
|
@@ -136,6 +155,12 @@ class STrack(BaseTrack):
|
|
|
136
155
|
Args:
|
|
137
156
|
new_track (STrack): The new track containing updated information.
|
|
138
157
|
frame_id (int): The ID of the current frame.
|
|
158
|
+
|
|
159
|
+
Examples:
|
|
160
|
+
Update the state of a track with new detection information
|
|
161
|
+
>>> track = STrack([100, 200, 50, 80, 1], score=0.9, cls='person')
|
|
162
|
+
>>> new_track = STrack([105, 205, 55, 85, 1], score=0.95, cls='person')
|
|
163
|
+
>>> track.update(new_track, 2)
|
|
139
164
|
"""
|
|
140
165
|
self.frame_id = frame_id
|
|
141
166
|
self.tracklet_len += 1
|
|
@@ -158,7 +183,7 @@ class STrack(BaseTrack):
|
|
|
158
183
|
|
|
159
184
|
@property
|
|
160
185
|
def tlwh(self):
|
|
161
|
-
"""
|
|
186
|
+
"""Returns the bounding box in top-left-width-height format from the current state estimate."""
|
|
162
187
|
if self.mean is None:
|
|
163
188
|
return self._tlwh.copy()
|
|
164
189
|
ret = self.mean[:4].copy()
|
|
@@ -168,16 +193,14 @@ class STrack(BaseTrack):
|
|
|
168
193
|
|
|
169
194
|
@property
|
|
170
195
|
def xyxy(self):
|
|
171
|
-
"""
|
|
196
|
+
"""Converts bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format."""
|
|
172
197
|
ret = self.tlwh.copy()
|
|
173
198
|
ret[2:] += ret[:2]
|
|
174
199
|
return ret
|
|
175
200
|
|
|
176
201
|
@staticmethod
|
|
177
202
|
def tlwh_to_xyah(tlwh):
|
|
178
|
-
"""Convert bounding box
|
|
179
|
-
height.
|
|
180
|
-
"""
|
|
203
|
+
"""Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format."""
|
|
181
204
|
ret = np.asarray(tlwh).copy()
|
|
182
205
|
ret[:2] += ret[2:] / 2
|
|
183
206
|
ret[2] /= ret[3]
|
|
@@ -185,14 +208,14 @@ class STrack(BaseTrack):
|
|
|
185
208
|
|
|
186
209
|
@property
|
|
187
210
|
def xywh(self):
|
|
188
|
-
"""
|
|
211
|
+
"""Returns the current position of the bounding box in (center x, center y, width, height) format."""
|
|
189
212
|
ret = np.asarray(self.tlwh).copy()
|
|
190
213
|
ret[:2] += ret[2:] / 2
|
|
191
214
|
return ret
|
|
192
215
|
|
|
193
216
|
@property
|
|
194
217
|
def xywha(self):
|
|
195
|
-
"""
|
|
218
|
+
"""Returns position in (center x, center y, width, height, angle) format, warning if angle is missing."""
|
|
196
219
|
if self.angle is None:
|
|
197
220
|
LOGGER.warning("WARNING ⚠️ `angle` attr not found, returning `xywh` instead.")
|
|
198
221
|
return self.xywh
|
|
@@ -200,12 +223,12 @@ class STrack(BaseTrack):
|
|
|
200
223
|
|
|
201
224
|
@property
|
|
202
225
|
def result(self):
|
|
203
|
-
"""
|
|
226
|
+
"""Returns the current tracking results in the appropriate bounding box format."""
|
|
204
227
|
coords = self.xyxy if self.angle is None else self.xywha
|
|
205
228
|
return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]
|
|
206
229
|
|
|
207
230
|
def __repr__(self):
|
|
208
|
-
"""
|
|
231
|
+
"""Returns a string representation of the STrack object including start frame, end frame, and track ID."""
|
|
209
232
|
return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
|
|
210
233
|
|
|
211
234
|
|
|
@@ -213,18 +236,18 @@ class BYTETracker:
|
|
|
213
236
|
"""
|
|
214
237
|
BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
|
|
215
238
|
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
239
|
+
Responsible for initializing, updating, and managing the tracks for detected objects in a video sequence.
|
|
240
|
+
It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for predicting
|
|
241
|
+
the new object locations, and performs data association.
|
|
219
242
|
|
|
220
243
|
Attributes:
|
|
221
|
-
tracked_stracks (
|
|
222
|
-
lost_stracks (
|
|
223
|
-
removed_stracks (
|
|
244
|
+
tracked_stracks (List[STrack]): List of successfully activated tracks.
|
|
245
|
+
lost_stracks (List[STrack]): List of lost tracks.
|
|
246
|
+
removed_stracks (List[STrack]): List of removed tracks.
|
|
224
247
|
frame_id (int): The current frame ID.
|
|
225
|
-
args (
|
|
248
|
+
args (Namespace): Command-line arguments.
|
|
226
249
|
max_time_lost (int): The maximum frames for a track to be considered as 'lost'.
|
|
227
|
-
kalman_filter (
|
|
250
|
+
kalman_filter (KalmanFilterXYAH): Kalman Filter object.
|
|
228
251
|
|
|
229
252
|
Methods:
|
|
230
253
|
update(results, img=None): Updates object tracker with new detections.
|
|
@@ -236,10 +259,27 @@ class BYTETracker:
|
|
|
236
259
|
joint_stracks(tlista, tlistb): Combines two lists of stracks.
|
|
237
260
|
sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list.
|
|
238
261
|
remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IoU.
|
|
262
|
+
|
|
263
|
+
Examples:
|
|
264
|
+
Initialize BYTETracker and update with detection results
|
|
265
|
+
>>> tracker = BYTETracker(args, frame_rate=30)
|
|
266
|
+
>>> results = yolo_model.detect(image)
|
|
267
|
+
>>> tracked_objects = tracker.update(results)
|
|
239
268
|
"""
|
|
240
269
|
|
|
241
270
|
def __init__(self, args, frame_rate=30):
|
|
242
|
-
"""
|
|
271
|
+
"""
|
|
272
|
+
Initialize a BYTETracker instance for object tracking.
|
|
273
|
+
|
|
274
|
+
Args:
|
|
275
|
+
args (Namespace): Command-line arguments containing tracking parameters.
|
|
276
|
+
frame_rate (int): Frame rate of the video sequence.
|
|
277
|
+
|
|
278
|
+
Examples:
|
|
279
|
+
Initialize BYTETracker with command-line arguments and a frame rate of 30
|
|
280
|
+
>>> args = Namespace(track_buffer=30)
|
|
281
|
+
>>> tracker = BYTETracker(args, frame_rate=30)
|
|
282
|
+
"""
|
|
243
283
|
self.tracked_stracks = [] # type: list[STrack]
|
|
244
284
|
self.lost_stracks = [] # type: list[STrack]
|
|
245
285
|
self.removed_stracks = [] # type: list[STrack]
|
|
@@ -251,7 +291,7 @@ class BYTETracker:
|
|
|
251
291
|
self.reset_id()
|
|
252
292
|
|
|
253
293
|
def update(self, results, img=None):
|
|
254
|
-
"""Updates
|
|
294
|
+
"""Updates the tracker with new detections and returns the current list of tracked objects."""
|
|
255
295
|
self.frame_id += 1
|
|
256
296
|
activated_stracks = []
|
|
257
297
|
refind_stracks = []
|
|
@@ -365,31 +405,31 @@ class BYTETracker:
|
|
|
365
405
|
return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
|
|
366
406
|
|
|
367
407
|
def get_kalmanfilter(self):
|
|
368
|
-
"""Returns a Kalman filter object for tracking bounding boxes."""
|
|
408
|
+
"""Returns a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH."""
|
|
369
409
|
return KalmanFilterXYAH()
|
|
370
410
|
|
|
371
411
|
def init_track(self, dets, scores, cls, img=None):
|
|
372
|
-
"""
|
|
412
|
+
"""Initializes object tracking with given detections, scores, and class labels using the STrack algorithm."""
|
|
373
413
|
return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections
|
|
374
414
|
|
|
375
415
|
def get_dists(self, tracks, detections):
|
|
376
|
-
"""Calculates the distance between tracks and detections using IoU and fuses scores."""
|
|
416
|
+
"""Calculates the distance between tracks and detections using IoU and optionally fuses scores."""
|
|
377
417
|
dists = matching.iou_distance(tracks, detections)
|
|
378
418
|
if self.args.fuse_score:
|
|
379
419
|
dists = matching.fuse_score(dists, detections)
|
|
380
420
|
return dists
|
|
381
421
|
|
|
382
422
|
def multi_predict(self, tracks):
|
|
383
|
-
"""
|
|
423
|
+
"""Predict the next states for multiple tracks using Kalman filter."""
|
|
384
424
|
STrack.multi_predict(tracks)
|
|
385
425
|
|
|
386
426
|
@staticmethod
|
|
387
427
|
def reset_id():
|
|
388
|
-
"""Resets the ID counter
|
|
428
|
+
"""Resets the ID counter for STrack instances to ensure unique track IDs across tracking sessions."""
|
|
389
429
|
STrack.reset_id()
|
|
390
430
|
|
|
391
431
|
def reset(self):
|
|
392
|
-
"""
|
|
432
|
+
"""Resets the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter."""
|
|
393
433
|
self.tracked_stracks = [] # type: list[STrack]
|
|
394
434
|
self.lost_stracks = [] # type: list[STrack]
|
|
395
435
|
self.removed_stracks = [] # type: list[STrack]
|
|
@@ -399,7 +439,7 @@ class BYTETracker:
|
|
|
399
439
|
|
|
400
440
|
@staticmethod
|
|
401
441
|
def joint_stracks(tlista, tlistb):
|
|
402
|
-
"""
|
|
442
|
+
"""Combines two lists of STrack objects into a single list, ensuring no duplicates based on track IDs."""
|
|
403
443
|
exists = {}
|
|
404
444
|
res = []
|
|
405
445
|
for t in tlista:
|
|
@@ -414,20 +454,13 @@ class BYTETracker:
|
|
|
414
454
|
|
|
415
455
|
@staticmethod
|
|
416
456
|
def sub_stracks(tlista, tlistb):
|
|
417
|
-
"""
|
|
418
|
-
stracks = {t.track_id: t for t in tlista}
|
|
419
|
-
for t in tlistb:
|
|
420
|
-
tid = t.track_id
|
|
421
|
-
if stracks.get(tid, 0):
|
|
422
|
-
del stracks[tid]
|
|
423
|
-
return list(stracks.values())
|
|
424
|
-
"""
|
|
457
|
+
"""Filters out the stracks present in the second list from the first list."""
|
|
425
458
|
track_ids_b = {t.track_id for t in tlistb}
|
|
426
459
|
return [t for t in tlista if t.track_id not in track_ids_b]
|
|
427
460
|
|
|
428
461
|
@staticmethod
|
|
429
462
|
def remove_duplicate_stracks(stracksa, stracksb):
|
|
430
|
-
"""
|
|
463
|
+
"""Removes duplicate stracks from two lists based on Intersection over Union (IoU) distance."""
|
|
431
464
|
pdist = matching.iou_distance(stracksa, stracksb)
|
|
432
465
|
pairs = np.where(pdist < 0.15)
|
|
433
466
|
dupa, dupb = [], []
|
ultralytics/trackers/track.py
CHANGED
|
@@ -21,10 +21,15 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
|
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
23
|
predictor (object): The predictor object to initialize trackers for.
|
|
24
|
-
persist (bool
|
|
24
|
+
persist (bool): Whether to persist the trackers if they already exist.
|
|
25
25
|
|
|
26
26
|
Raises:
|
|
27
27
|
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
|
|
28
|
+
|
|
29
|
+
Examples:
|
|
30
|
+
Initialize trackers for a predictor object:
|
|
31
|
+
>>> predictor = SomePredictorClass()
|
|
32
|
+
>>> on_predict_start(predictor, persist=True)
|
|
28
33
|
"""
|
|
29
34
|
if hasattr(predictor, "trackers") and persist:
|
|
30
35
|
return
|
|
@@ -51,7 +56,12 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
|
|
|
51
56
|
|
|
52
57
|
Args:
|
|
53
58
|
predictor (object): The predictor object containing the predictions.
|
|
54
|
-
persist (bool
|
|
59
|
+
persist (bool): Whether to persist the trackers if they already exist.
|
|
60
|
+
|
|
61
|
+
Examples:
|
|
62
|
+
Postprocess predictions and update with tracking
|
|
63
|
+
>>> predictor = YourPredictorClass()
|
|
64
|
+
>>> on_predict_postprocess_end(predictor, persist=True)
|
|
55
65
|
"""
|
|
56
66
|
path, im0s = predictor.batch[:2]
|
|
57
67
|
|
|
@@ -84,6 +94,11 @@ def register_tracker(model: object, persist: bool) -> None:
|
|
|
84
94
|
Args:
|
|
85
95
|
model (object): The model object to register tracking callbacks for.
|
|
86
96
|
persist (bool): Whether to persist the trackers if they already exist.
|
|
97
|
+
|
|
98
|
+
Examples:
|
|
99
|
+
Register tracking callbacks to a YOLO model
|
|
100
|
+
>>> model = YOLOModel()
|
|
101
|
+
>>> register_tracker(model, persist=True)
|
|
87
102
|
"""
|
|
88
103
|
model.add_callback("on_predict_start", partial(on_predict_start, persist=persist))
|
|
89
104
|
model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist))
|
|
@@ -19,27 +19,39 @@ class GMC:
|
|
|
19
19
|
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
|
|
20
20
|
downscale (int): Factor by which to downscale the frames for processing.
|
|
21
21
|
prevFrame (np.ndarray): Stores the previous frame for tracking.
|
|
22
|
-
prevKeyPoints (
|
|
22
|
+
prevKeyPoints (List): Stores the keypoints from the previous frame.
|
|
23
23
|
prevDescriptors (np.ndarray): Stores the descriptors from the previous frame.
|
|
24
24
|
initializedFirstFrame (bool): Flag to indicate if the first frame has been processed.
|
|
25
25
|
|
|
26
26
|
Methods:
|
|
27
|
-
__init__
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
27
|
+
__init__: Initializes a GMC object with the specified method and downscale factor.
|
|
28
|
+
apply: Applies the chosen method to a raw frame and optionally uses provided detections.
|
|
29
|
+
applyEcc: Applies the ECC algorithm to a raw frame.
|
|
30
|
+
applyFeatures: Applies feature-based methods like ORB or SIFT to a raw frame.
|
|
31
|
+
applySparseOptFlow: Applies the Sparse Optical Flow method to a raw frame.
|
|
32
|
+
reset_params: Resets the internal parameters of the GMC object.
|
|
33
|
+
|
|
34
|
+
Examples:
|
|
35
|
+
Create a GMC object and apply it to a frame
|
|
36
|
+
>>> gmc = GMC(method='sparseOptFlow', downscale=2)
|
|
37
|
+
>>> frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
|
|
38
|
+
>>> processed_frame = gmc.apply(frame)
|
|
39
|
+
>>> print(processed_frame)
|
|
40
|
+
[[1. 0. 0.]
|
|
41
|
+
[0. 1. 0.]]
|
|
34
42
|
"""
|
|
35
43
|
|
|
36
44
|
def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None:
|
|
37
45
|
"""
|
|
38
|
-
Initialize a
|
|
46
|
+
Initialize a Generalized Motion Compensation (GMC) object with tracking method and downscale factor.
|
|
39
47
|
|
|
40
48
|
Args:
|
|
41
49
|
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
|
|
42
50
|
downscale (int): Downscale factor for processing frames.
|
|
51
|
+
|
|
52
|
+
Examples:
|
|
53
|
+
Initialize a GMC object with the 'sparseOptFlow' method and a downscale factor of 2
|
|
54
|
+
>>> gmc = GMC(method='sparseOptFlow', downscale=2)
|
|
43
55
|
"""
|
|
44
56
|
super().__init__()
|
|
45
57
|
|
|
@@ -79,20 +91,21 @@ class GMC:
|
|
|
79
91
|
|
|
80
92
|
def apply(self, raw_frame: np.array, detections: list = None) -> np.array:
|
|
81
93
|
"""
|
|
82
|
-
Apply object detection on a raw frame using specified method.
|
|
94
|
+
Apply the selected motion compensation method to a raw frame, optionally using provided detections.
|
|
83
95
|
|
|
84
96
|
Args:
|
|
85
|
-
raw_frame (np.ndarray): The raw frame to be processed.
|
|
86
|
-
detections (
|
|
97
|
+
raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
|
|
98
|
+
detections (List | None): List of detections to be used in the processing.
|
|
87
99
|
|
|
88
100
|
Returns:
|
|
89
|
-
(np.ndarray): Processed frame.
|
|
101
|
+
(np.ndarray): Transformation matrix with shape (2, 3) produced by the selected method.
|
|
90
102
|
|
|
91
103
|
Examples:
|
|
92
|
-
>>> gmc = GMC()
|
|
93
|
-
>>>
|
|
94
|
-
|
|
95
|
-
|
|
104
|
+
>>> gmc = GMC(method='sparseOptFlow')
|
|
105
|
+
>>> raw_frame = np.random.rand(480, 640, 3)
|
|
106
|
+
>>> processed_frame = gmc.apply(raw_frame)
|
|
107
|
+
>>> print(processed_frame.shape)
|
|
108
|
+
(2, 3)
|
|
96
109
|
"""
|
|
97
110
|
if self.method in {"orb", "sift"}:
|
|
98
111
|
return self.applyFeatures(raw_frame, detections)
|
|
@@ -105,19 +118,20 @@ class GMC:
|
|
|
105
118
|
|
|
106
119
|
def applyEcc(self, raw_frame: np.array) -> np.array:
|
|
107
120
|
"""
|
|
108
|
-
Apply ECC algorithm to a raw frame.
|
|
121
|
+
Apply the ECC (Enhanced Correlation Coefficient) algorithm to a raw frame for motion compensation.
|
|
109
122
|
|
|
110
123
|
Args:
|
|
111
|
-
raw_frame (np.ndarray): The raw frame to be processed.
|
|
124
|
+
raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
|
|
112
125
|
|
|
113
126
|
Returns:
|
|
114
|
-
(np.ndarray):
|
|
127
|
+
(np.ndarray): The processed frame with the applied ECC transformation.
|
|
115
128
|
|
|
116
129
|
Examples:
|
|
117
|
-
>>> gmc = GMC()
|
|
118
|
-
>>> gmc.applyEcc(np.array([[1, 2, 3], [4, 5, 6]]))
|
|
119
|
-
|
|
120
|
-
|
|
130
|
+
>>> gmc = GMC(method='ecc')
|
|
131
|
+
>>> processed_frame = gmc.applyEcc(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]))
|
|
132
|
+
>>> print(processed_frame)
|
|
133
|
+
[[1. 0. 0.]
|
|
134
|
+
[0. 1. 0.]]
|
|
121
135
|
"""
|
|
122
136
|
height, width, _ = raw_frame.shape
|
|
123
137
|
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
|
@@ -127,8 +141,6 @@ class GMC:
|
|
|
127
141
|
if self.downscale > 1.0:
|
|
128
142
|
frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
|
|
129
143
|
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
|
|
130
|
-
width = width // self.downscale
|
|
131
|
-
height = height // self.downscale
|
|
132
144
|
|
|
133
145
|
# Handle first frame
|
|
134
146
|
if not self.initializedFirstFrame:
|
|
@@ -154,17 +166,18 @@ class GMC:
|
|
|
154
166
|
Apply feature-based methods like ORB or SIFT to a raw frame.
|
|
155
167
|
|
|
156
168
|
Args:
|
|
157
|
-
raw_frame (np.ndarray): The raw frame to be processed.
|
|
158
|
-
detections (
|
|
169
|
+
raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
|
|
170
|
+
detections (List | None): List of detections to be used in the processing.
|
|
159
171
|
|
|
160
172
|
Returns:
|
|
161
173
|
(np.ndarray): Processed frame.
|
|
162
174
|
|
|
163
175
|
Examples:
|
|
164
|
-
>>> gmc = GMC()
|
|
165
|
-
>>>
|
|
166
|
-
|
|
167
|
-
|
|
176
|
+
>>> gmc = GMC(method='orb')
|
|
177
|
+
>>> raw_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
|
|
178
|
+
>>> processed_frame = gmc.applyFeatures(raw_frame)
|
|
179
|
+
>>> print(processed_frame.shape)
|
|
180
|
+
(2, 3)
|
|
168
181
|
"""
|
|
169
182
|
height, width, _ = raw_frame.shape
|
|
170
183
|
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
|
@@ -296,16 +309,17 @@ class GMC:
|
|
|
296
309
|
Apply Sparse Optical Flow method to a raw frame.
|
|
297
310
|
|
|
298
311
|
Args:
|
|
299
|
-
raw_frame (np.ndarray): The raw frame to be processed.
|
|
312
|
+
raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
|
|
300
313
|
|
|
301
314
|
Returns:
|
|
302
|
-
(np.ndarray): Processed frame.
|
|
315
|
+
(np.ndarray): Processed frame with shape (2, 3).
|
|
303
316
|
|
|
304
317
|
Examples:
|
|
305
318
|
>>> gmc = GMC()
|
|
306
|
-
>>> gmc.applySparseOptFlow(np.array([[1, 2, 3], [4, 5, 6]]))
|
|
307
|
-
|
|
308
|
-
|
|
319
|
+
>>> result = gmc.applySparseOptFlow(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]))
|
|
320
|
+
>>> print(result)
|
|
321
|
+
[[1. 0. 0.]
|
|
322
|
+
[0. 1. 0.]]
|
|
309
323
|
"""
|
|
310
324
|
height, width, _ = raw_frame.shape
|
|
311
325
|
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
|
@@ -356,7 +370,7 @@ class GMC:
|
|
|
356
370
|
return H
|
|
357
371
|
|
|
358
372
|
def reset_params(self) -> None:
|
|
359
|
-
"""Reset parameters."""
|
|
373
|
+
"""Reset the internal parameters including previous frame, keypoints, and descriptors."""
|
|
360
374
|
self.prevFrame = None
|
|
361
375
|
self.prevKeyPoints = None
|
|
362
376
|
self.prevDescriptors = None
|