dgenerate-ultralytics-headless: 8.3.189 → 8.3.191 (py3-none-any.whl)
This diff shows the content changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the versions exactly as they appear in that registry.
- {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/METADATA +1 -1
- {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/RECORD +111 -109
- tests/test_cuda.py +6 -5
- tests/test_exports.py +1 -6
- tests/test_python.py +1 -4
- tests/test_solutions.py +1 -1
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -14
- ultralytics/cfg/datasets/VisDrone.yaml +4 -4
- ultralytics/data/annotator.py +6 -6
- ultralytics/data/augment.py +53 -51
- ultralytics/data/base.py +15 -13
- ultralytics/data/build.py +7 -4
- ultralytics/data/converter.py +9 -10
- ultralytics/data/dataset.py +24 -22
- ultralytics/data/loaders.py +13 -11
- ultralytics/data/split.py +4 -3
- ultralytics/data/split_dota.py +14 -12
- ultralytics/data/utils.py +31 -25
- ultralytics/engine/exporter.py +7 -4
- ultralytics/engine/model.py +16 -14
- ultralytics/engine/predictor.py +9 -7
- ultralytics/engine/results.py +59 -57
- ultralytics/engine/trainer.py +7 -0
- ultralytics/engine/tuner.py +4 -3
- ultralytics/engine/validator.py +3 -1
- ultralytics/hub/__init__.py +6 -2
- ultralytics/hub/auth.py +2 -2
- ultralytics/hub/google/__init__.py +9 -8
- ultralytics/hub/session.py +11 -11
- ultralytics/hub/utils.py +8 -9
- ultralytics/models/fastsam/model.py +8 -6
- ultralytics/models/nas/model.py +5 -3
- ultralytics/models/rtdetr/train.py +4 -3
- ultralytics/models/rtdetr/val.py +6 -4
- ultralytics/models/sam/amg.py +13 -10
- ultralytics/models/sam/model.py +3 -2
- ultralytics/models/sam/modules/blocks.py +21 -21
- ultralytics/models/sam/modules/decoders.py +11 -11
- ultralytics/models/sam/modules/encoders.py +25 -25
- ultralytics/models/sam/modules/memory_attention.py +9 -8
- ultralytics/models/sam/modules/sam.py +8 -10
- ultralytics/models/sam/modules/tiny_encoder.py +21 -20
- ultralytics/models/sam/modules/transformer.py +6 -5
- ultralytics/models/sam/modules/utils.py +7 -5
- ultralytics/models/sam/predict.py +32 -31
- ultralytics/models/utils/loss.py +29 -27
- ultralytics/models/utils/ops.py +10 -8
- ultralytics/models/yolo/classify/train.py +7 -5
- ultralytics/models/yolo/classify/val.py +10 -8
- ultralytics/models/yolo/detect/predict.py +3 -3
- ultralytics/models/yolo/detect/train.py +8 -6
- ultralytics/models/yolo/detect/val.py +23 -21
- ultralytics/models/yolo/model.py +14 -14
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +13 -10
- ultralytics/models/yolo/pose/train.py +7 -5
- ultralytics/models/yolo/pose/val.py +11 -9
- ultralytics/models/yolo/segment/train.py +4 -5
- ultralytics/models/yolo/segment/val.py +12 -10
- ultralytics/models/yolo/world/train.py +9 -7
- ultralytics/models/yolo/yoloe/train.py +7 -6
- ultralytics/models/yolo/yoloe/val.py +10 -8
- ultralytics/nn/autobackend.py +40 -52
- ultralytics/nn/modules/__init__.py +3 -3
- ultralytics/nn/modules/block.py +12 -12
- ultralytics/nn/modules/conv.py +4 -3
- ultralytics/nn/modules/head.py +46 -38
- ultralytics/nn/modules/transformer.py +22 -21
- ultralytics/nn/tasks.py +2 -2
- ultralytics/nn/text_model.py +6 -5
- ultralytics/solutions/analytics.py +7 -5
- ultralytics/solutions/config.py +12 -10
- ultralytics/solutions/distance_calculation.py +3 -3
- ultralytics/solutions/heatmap.py +4 -2
- ultralytics/solutions/object_counter.py +5 -3
- ultralytics/solutions/parking_management.py +4 -2
- ultralytics/solutions/region_counter.py +7 -5
- ultralytics/solutions/similarity_search.py +5 -3
- ultralytics/solutions/solutions.py +38 -36
- ultralytics/solutions/streamlit_inference.py +8 -7
- ultralytics/trackers/bot_sort.py +11 -9
- ultralytics/trackers/byte_tracker.py +17 -15
- ultralytics/trackers/utils/gmc.py +4 -3
- ultralytics/utils/__init__.py +27 -77
- ultralytics/utils/autobatch.py +3 -2
- ultralytics/utils/autodevice.py +10 -10
- ultralytics/utils/benchmarks.py +11 -10
- ultralytics/utils/callbacks/comet.py +9 -9
- ultralytics/utils/callbacks/platform.py +2 -1
- ultralytics/utils/checks.py +20 -29
- ultralytics/utils/downloads.py +2 -2
- ultralytics/utils/export.py +12 -11
- ultralytics/utils/files.py +8 -7
- ultralytics/utils/git.py +139 -0
- ultralytics/utils/instance.py +8 -7
- ultralytics/utils/logger.py +7 -6
- ultralytics/utils/loss.py +15 -13
- ultralytics/utils/metrics.py +62 -62
- ultralytics/utils/nms.py +346 -0
- ultralytics/utils/ops.py +83 -251
- ultralytics/utils/patches.py +6 -4
- ultralytics/utils/plotting.py +18 -16
- ultralytics/utils/tal.py +1 -1
- ultralytics/utils/torch_utils.py +4 -2
- ultralytics/utils/tqdm.py +47 -33
- ultralytics/utils/triton.py +3 -2
- {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/top_level.txt +0 -0
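Nearly every hunk below follows the same pattern: each module gains `from __future__ import annotations` (PEP 563 postponed evaluation) and its signatures are rewritten from `typing.List`/`Optional`/`Tuple` aliases to builtin generics (PEP 585, e.g. `list[str]`) and union syntax (PEP 604, e.g. `int | None`). A minimal sketch of the before/after, using hypothetical `add_region_old`/`add_region_new` functions (not the actual Ultralytics code):

from __future__ import annotations  # PEP 563: annotations stay unevaluated strings

from typing import List, Optional, Tuple  # 8.3.189-era aliases, now redundant


def add_region_old(points: List[Tuple[int, int]], color: Optional[Tuple[int, int, int]] = None) -> dict:
    """Old style: typing aliases, required on Python 3.8 without the future import."""
    return {"points": points, "color": color}


def add_region_new(points: list[tuple[int, int]], color: tuple[int, int, int] | None = None) -> dict[str, object]:
    """New style: identical runtime behavior; safe on 3.8+ because the future import keeps annotations unevaluated."""
    return {"points": points, "color": color}

Because the future import defers evaluation, the new syntax merely needs to parse on every supported Python version; only tools that resolve annotations at runtime (e.g. typing.get_type_hints) ever evaluate them.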
ultralytics/solutions/region_counter.py
CHANGED
@@ -1,6 +1,8 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

-from …
+from __future__ import annotations
+
+from typing import Any

 import numpy as np

@@ -53,10 +55,10 @@ class RegionCounter(BaseSolution):
     def add_region(
         self,
         name: str,
-        polygon_points: …
-        region_color: …
-        text_color: …
-    ) -> …
+        polygon_points: list[tuple],
+        region_color: tuple[int, int, int],
+        text_color: tuple[int, int, int],
+    ) -> dict[str, Any]:
         """
         Add a new region to the counting list based on the provided template with specific attributes.

ultralytics/solutions/similarity_search.py
CHANGED
@@ -1,8 +1,10 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+from __future__ import annotations
+
 import os
 from pathlib import Path
-from typing import Any…
+from typing import Any

 import numpy as np
 from PIL import Image
@@ -126,7 +128,7 @@ class VisualAISearch:

         LOGGER.info(f"Indexed {len(self.image_paths)} images.")

-    def search(self, query: str, k: int = 30, similarity_thresh: float = 0.1) -> …
+    def search(self, query: str, k: int = 30, similarity_thresh: float = 0.1) -> list[str]:
         """
         Return top-k semantically similar images to the given query.

@@ -158,7 +160,7 @@ class VisualAISearch:

         return [r[0] for r in results]

-    def __call__(self, query: str) -> …
+    def __call__(self, query: str) -> list[str]:
         """Direct call interface for the search function."""
         return self.search(query)

ultralytics/solutions/solutions.py
CHANGED
@@ -1,9 +1,11 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+from __future__ import annotations
+
 import math
 from collections import Counter, defaultdict
 from functools import lru_cache
-from typing import Any…
+from typing import Any

 import cv2
 import numpy as np
@@ -135,7 +137,7 @@ class BaseSolution:
             ops.Profile(device=self.device),  # solution
         )

-    def adjust_box_label(self, cls: int, conf: float, track_id: …
+    def adjust_box_label(self, cls: int, conf: float, track_id: int | None = None) -> str | None:
         """
         Generate a formatted label for a bounding box.

@@ -302,8 +304,8 @@ class SolutionAnnotator(Annotator):
     def __init__(
         self,
         im: np.ndarray,
-        line_width: …
-        font_size: …
+        line_width: int | None = None,
+        font_size: int | None = None,
         font: str = "Arial.ttf",
         pil: bool = False,
         example: str = "abc",
@@ -323,8 +325,8 @@ class SolutionAnnotator(Annotator):

     def draw_region(
         self,
-        reg_pts: …
-        color: …
+        reg_pts: list[tuple[int, int]] | None = None,
+        color: tuple[int, int, int] = (0, 255, 0),
         thickness: int = 5,
     ):
         """
@@ -344,9 +346,9 @@ class SolutionAnnotator(Annotator):
     def queue_counts_display(
         self,
         label: str,
-        points: …
-        region_color: …
-        txt_color: …
+        points: list[tuple[int, int]] | None = None,
+        region_color: tuple[int, int, int] = (255, 255, 255),
+        txt_color: tuple[int, int, int] = (0, 0, 0),
     ):
         """
         Display queue counts on an image centered at the points with customizable font size and colors.
@@ -390,9 +392,9 @@ class SolutionAnnotator(Annotator):
     def display_analytics(
         self,
         im0: np.ndarray,
-        text: …
-        txt_color: …
-        bg_color: …
+        text: dict[str, Any],
+        txt_color: tuple[int, int, int],
+        bg_color: tuple[int, int, int],
         margin: int,
     ):
         """
@@ -425,7 +427,7 @@ class SolutionAnnotator(Annotator):

     @staticmethod
     @lru_cache(maxsize=256)
-    def estimate_pose_angle(a: …
+    def estimate_pose_angle(a: list[float], b: list[float], c: list[float]) -> float:
         """
         Calculate the angle between three points for workout monitoring.

@@ -443,8 +445,8 @@ class SolutionAnnotator(Annotator):

     def draw_specific_kpts(
         self,
-        keypoints: …
-        indices: …
+        keypoints: list[list[float]],
+        indices: list[int] | None = None,
         radius: int = 2,
         conf_thresh: float = 0.25,
     ) -> np.ndarray:
@@ -480,9 +482,9 @@ class SolutionAnnotator(Annotator):
     def plot_workout_information(
         self,
         display_text: str,
-        position: …
-        color: …
-        txt_color: …
+        position: tuple[int, int],
+        color: tuple[int, int, int] = (104, 31, 17),
+        txt_color: tuple[int, int, int] = (255, 255, 255),
     ) -> int:
         """
         Draw workout text with a background on the image.
@@ -516,9 +518,9 @@ class SolutionAnnotator(Annotator):
         angle_text: str,
         count_text: str,
         stage_text: str,
-        center_kpt: …
-        color: …
-        txt_color: …
+        center_kpt: list[int],
+        color: tuple[int, int, int] = (104, 31, 17),
+        txt_color: tuple[int, int, int] = (255, 255, 255),
     ):
         """
         Plot the pose angle, count value, and step stage for workout monitoring.
@@ -548,9 +550,9 @@ class SolutionAnnotator(Annotator):
     def plot_distance_and_line(
         self,
         pixels_distance: float,
-        centroids: …
-        line_color: …
-        centroid_color: …
+        centroids: list[tuple[int, int]],
+        line_color: tuple[int, int, int] = (104, 31, 17),
+        centroid_color: tuple[int, int, int] = (255, 0, 255),
     ):
         """
         Plot the distance and line between two centroids on the frame.
@@ -589,8 +591,8 @@ class SolutionAnnotator(Annotator):
         self,
         im0: np.ndarray,
         text: str,
-        txt_color: …
-        bg_color: …
+        txt_color: tuple[int, int, int],
+        bg_color: tuple[int, int, int],
         x_center: float,
         y_center: float,
         margin: int,
@@ -638,9 +640,9 @@ class SolutionAnnotator(Annotator):
         self,
         line_x: int = 0,
         line_y: int = 0,
-        label: …
-        color: …
-        txt_color: …
+        label: str | None = None,
+        color: tuple[int, int, int] = (221, 0, 186),
+        txt_color: tuple[int, int, int] = (255, 255, 255),
     ):
         """
         Draw a sweep annotation line and an optional label.
@@ -677,10 +679,10 @@ class SolutionAnnotator(Annotator):

     def visioneye(
         self,
-        box: …
-        center_point: …
-        color: …
-        pin_color: …
+        box: list[float],
+        center_point: tuple[int, int],
+        color: tuple[int, int, int] = (235, 219, 11),
+        pin_color: tuple[int, int, int] = (255, 0, 255),
     ):
         """
         Perform pinpoint human-vision eye mapping and plotting.
@@ -698,10 +700,10 @@ class SolutionAnnotator(Annotator):

     def adaptive_label(
         self,
-        box: …
+        box: tuple[float, float, float, float],
         label: str = "",
-        color: …
-        txt_color: …
+        color: tuple[int, int, int] = (128, 128, 128),
+        txt_color: tuple[int, int, int] = (255, 255, 255),
         shape: str = "rect",
         margin: int = 5,
     ):
ultralytics/solutions/streamlit_inference.py
CHANGED
@@ -2,7 +2,7 @@

 import io
 import os
-from typing import Any…
+from typing import Any

 import cv2
 import torch
@@ -72,7 +72,7 @@ class Inference:
         self.org_frame = None  # Container for the original frame display
         self.ann_frame = None  # Container for the annotated frame display
         self.vid_file_name = None  # Video file name or webcam index
-        self.selected_ind: …
+        self.selected_ind: list[int] = []  # List of selected class indices for detection
         self.model = None  # YOLO model instance

         self.temp_dict = {"model": None, **kwargs}
@@ -91,8 +91,8 @@ class Inference:
             font-family: 'Archivo', sans-serif; margin-bottom:20px;">Ultralytics YOLO Streamlit Application</h1></div>"""

         # Subtitle of streamlit application
-        sub_title_cfg = """<div><h5 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif;
-        margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam, videos, and images
+        sub_title_cfg = """<div><h5 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif;
+        margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam, videos, and images
         with the power of Ultralytics YOLO! 🚀</h5></div>"""

         # Set html page configuration and append custom HTML
@@ -141,8 +141,9 @@ class Inference:
         elif self.source == "image":
             import tempfile  # scope import

-            imgfiles …
-            …
+            if imgfiles := self.st.sidebar.file_uploader(
+                "Upload Image Files", type=IMG_FORMATS, accept_multiple_files=True
+            ):
                 for imgfile in imgfiles:  # Save each uploaded image to a temporary file
                     with tempfile.NamedTemporaryFile(delete=False, suffix=f".{imgfile.name.split('.')[-1]}") as tf:
                         tf.write(imgfile.read())
@@ -185,7 +186,7 @@ class Inference:

     def image_inference(self) -> None:
         """Perform inference on uploaded images."""
-        for …
+        for img_info in self.img_file_names:
             img_path = img_info["path"]
             image = cv2.imread(img_path)  # Load and display the original image
             if image is not None:
ultralytics/trackers/bot_sort.py
CHANGED
@@ -1,7 +1,9 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+from __future__ import annotations
+
 from collections import deque
-from typing import Any…
+from typing import Any

 import numpy as np
 import torch
@@ -53,7 +55,7 @@ class BOTrack(STrack):
     shared_kalman = KalmanFilterXYWH()

     def __init__(
-        self, xywh: np.ndarray, score: float, cls: int, feat: …
+        self, xywh: np.ndarray, score: float, cls: int, feat: np.ndarray | None = None, feat_history: int = 50
     ):
         """
         Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features.
@@ -102,13 +104,13 @@ class BOTrack(STrack):

         self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

-    def re_activate(self, new_track: …
+    def re_activate(self, new_track: BOTrack, frame_id: int, new_id: bool = False) -> None:
         """Reactivate a track with updated features and optionally assign a new ID."""
         if new_track.curr_feat is not None:
             self.update_features(new_track.curr_feat)
         super().re_activate(new_track, frame_id, new_id)

-    def update(self, new_track: …
+    def update(self, new_track: BOTrack, frame_id: int) -> None:
         """Update the track with new detection information and the current frame ID."""
         if new_track.curr_feat is not None:
             self.update_features(new_track.curr_feat)
@@ -124,7 +126,7 @@ class BOTrack(STrack):
         return ret

     @staticmethod
-    def multi_predict(stracks: …
+    def multi_predict(stracks: list[BOTrack]) -> None:
         """Predict the mean and covariance for multiple object tracks using a shared Kalman filter."""
         if len(stracks) <= 0:
             return
@@ -210,7 +212,7 @@ class BOTSORT(BYTETracker):
         """Return an instance of KalmanFilterXYWH for predicting and updating object states in the tracking process."""
         return KalmanFilterXYWH()

-    def init_track(self, results, img: …
+    def init_track(self, results, img: np.ndarray | None = None) -> list[BOTrack]:
         """Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features."""
         if len(results) == 0:
             return []
@@ -222,7 +224,7 @@ class BOTSORT(BYTETracker):
         else:
             return [BOTrack(xywh, s, c) for (xywh, s, c) in zip(bboxes, results.conf, results.cls)]

-    def get_dists(self, tracks: …
+    def get_dists(self, tracks: list[BOTrack], detections: list[BOTrack]) -> np.ndarray:
         """Calculate distances between tracks and detections using IoU and optionally ReID embeddings."""
         dists = matching.iou_distance(tracks, detections)
         dists_mask = dists > (1 - self.proximity_thresh)
@@ -237,7 +239,7 @@ class BOTSORT(BYTETracker):
             dists = np.minimum(dists, emb_dists)
         return dists

-    def multi_predict(self, tracks: …
+    def multi_predict(self, tracks: list[BOTrack]) -> None:
         """Predict the mean and covariance of multiple object tracks using a shared Kalman filter."""
         BOTrack.multi_predict(tracks)

@@ -262,7 +264,7 @@ class ReID:
         self.model = YOLO(model)
         self.model(embed=[len(self.model.model.model) - 2 if ".pt" in model else -1], verbose=False, save=False)  # init

-    def __call__(self, img: np.ndarray, dets: np.ndarray) -> …
+    def __call__(self, img: np.ndarray, dets: np.ndarray) -> list[np.ndarray]:
         """Extract embeddings for detected objects."""
         feats = self.model.predictor(
             [save_one_box(det, img, save=False) for det in xywh2xyxy(torch.from_numpy(dets[:, :4]))]
ultralytics/trackers/byte_tracker.py
CHANGED
@@ -1,6 +1,8 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

-from …
+from __future__ import annotations
+
+from typing import Any

 import numpy as np

@@ -51,7 +53,7 @@ class STrack(BaseTrack):

     shared_kalman = KalmanFilterXYAH()

-    def __init__(self, xywh: …
+    def __init__(self, xywh: list[float], score: float, cls: Any):
         """
         Initialize a new STrack instance.

@@ -89,7 +91,7 @@ class STrack(BaseTrack):
         self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)

     @staticmethod
-    def multi_predict(stracks: …
+    def multi_predict(stracks: list[STrack]):
         """Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances."""
         if len(stracks) <= 0:
             return
@@ -104,9 +106,9 @@ class STrack(BaseTrack):
             stracks[i].covariance = cov

     @staticmethod
-    def multi_gmc(stracks: …
+    def multi_gmc(stracks: list[STrack], H: np.ndarray = np.eye(2, 3)):
         """Update state tracks positions and covariances using a homography matrix for multiple tracks."""
-        if …
+        if stracks:
             multi_mean = np.asarray([st.mean.copy() for st in stracks])
             multi_covariance = np.asarray([st.covariance for st in stracks])

@@ -135,7 +137,7 @@ class STrack(BaseTrack):
         self.frame_id = frame_id
         self.start_frame = frame_id

-    def re_activate(self, new_track: …
+    def re_activate(self, new_track: STrack, frame_id: int, new_id: bool = False):
         """Reactivate a previously lost track using new detection data and update its state and attributes."""
         self.mean, self.covariance = self.kalman_filter.update(
             self.mean, self.covariance, self.convert_coords(new_track.tlwh)
@@ -151,7 +153,7 @@ class STrack(BaseTrack):
         self.angle = new_track.angle
         self.idx = new_track.idx

-    def update(self, new_track: …
+    def update(self, new_track: STrack, frame_id: int):
         """
         Update the state of a matched track.

@@ -225,7 +227,7 @@ class STrack(BaseTrack):
         return np.concatenate([self.xywh, self.angle[None]])

     @property
-    def result(self) -> …
+    def result(self) -> list[float]:
         """Get the current tracking results in the appropriate bounding box format."""
         coords = self.xyxy if self.angle is None else self.xywha
         return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]
@@ -294,7 +296,7 @@ class BYTETracker:
         self.kalman_filter = self.get_kalmanfilter()
         self.reset_id()

-    def update(self, results, img: …
+    def update(self, results, img: np.ndarray | None = None, feats: np.ndarray | None = None) -> np.ndarray:
         """Update the tracker with new detections and return the current list of tracked objects."""
         self.frame_id += 1
         activated_stracks = []
@@ -411,7 +413,7 @@ class BYTETracker:
         """Return a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH."""
         return KalmanFilterXYAH()

-    def init_track(self, results, img: …
+    def init_track(self, results, img: np.ndarray | None = None) -> list[STrack]:
         """Initialize object tracking with given detections, scores, and class labels using the STrack algorithm."""
         if len(results) == 0:
             return []
@@ -419,14 +421,14 @@ class BYTETracker:
         bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
         return [STrack(xywh, s, c) for (xywh, s, c) in zip(bboxes, results.conf, results.cls)]

-    def get_dists(self, tracks: …
+    def get_dists(self, tracks: list[STrack], detections: list[STrack]) -> np.ndarray:
         """Calculate the distance between tracks and detections using IoU and optionally fuse scores."""
         dists = matching.iou_distance(tracks, detections)
         if self.args.fuse_score:
             dists = matching.fuse_score(dists, detections)
         return dists

-    def multi_predict(self, tracks: …
+    def multi_predict(self, tracks: list[STrack]):
         """Predict the next states for multiple tracks using Kalman filter."""
         STrack.multi_predict(tracks)

@@ -445,7 +447,7 @@ class BYTETracker:
         self.reset_id()

     @staticmethod
-    def joint_stracks(tlista: …
+    def joint_stracks(tlista: list[STrack], tlistb: list[STrack]) -> list[STrack]:
         """Combine two lists of STrack objects into a single list, ensuring no duplicates based on track IDs."""
         exists = {}
         res = []
@@ -460,13 +462,13 @@ class BYTETracker:
         return res

     @staticmethod
-    def sub_stracks(tlista: …
+    def sub_stracks(tlista: list[STrack], tlistb: list[STrack]) -> list[STrack]:
         """Filter out the stracks present in the second list from the first list."""
         track_ids_b = {t.track_id for t in tlistb}
         return [t for t in tlista if t.track_id not in track_ids_b]

     @staticmethod
-    def remove_duplicate_stracks(stracksa: …
+    def remove_duplicate_stracks(stracksa: list[STrack], stracksb: list[STrack]) -> tuple[list[STrack], list[STrack]]:
         """Remove duplicate stracks from two lists based on Intersection over Union (IoU) distance."""
         pdist = matching.iou_distance(stracksa, stracksb)
         pairs = np.where(pdist < 0.15)
ultralytics/trackers/utils/gmc.py
CHANGED
@@ -1,7 +1,8 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+from __future__ import annotations
+
 import copy
-from typing import List, Optional

 import cv2
 import numpy as np
@@ -89,7 +90,7 @@ class GMC:
         self.prevDescriptors = None
         self.initializedFirstFrame = False

-    def apply(self, raw_frame: np.ndarray, detections: …
+    def apply(self, raw_frame: np.ndarray, detections: list | None = None) -> np.ndarray:
         """
         Apply object detection on a raw frame using the specified method.

@@ -156,7 +157,7 @@ class GMC:

         return H

-    def apply_features(self, raw_frame: np.ndarray, detections: …
+    def apply_features(self, raw_frame: np.ndarray, detections: list | None = None) -> np.ndarray:
         """
         Apply feature-based methods like ORB or SIFT to a raw frame.
