ultralytics 8.1.29__py3-none-any.whl → 8.3.63__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +122 -0
- tests/test_cuda.py +155 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +216 -0
- tests/test_integrations.py +150 -0
- tests/test_python.py +615 -0
- tests/test_solutions.py +94 -0
- ultralytics/__init__.py +11 -8
- ultralytics/cfg/__init__.py +569 -131
- ultralytics/cfg/datasets/Argoverse.yaml +2 -1
- ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
- ultralytics/cfg/datasets/ImageNet.yaml +2 -1
- ultralytics/cfg/datasets/Objects365.yaml +5 -4
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
- ultralytics/cfg/datasets/VOC.yaml +3 -2
- ultralytics/cfg/datasets/VisDrone.yaml +6 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco-pose.yaml +7 -6
- ultralytics/cfg/datasets/coco.yaml +3 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
- ultralytics/cfg/datasets/coco128.yaml +4 -3
- ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco8.yaml +3 -2
- ultralytics/cfg/datasets/crack-seg.yaml +3 -2
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8.yaml +3 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1236 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
- ultralytics/cfg/datasets/package-seg.yaml +5 -4
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
- ultralytics/cfg/datasets/xView.yaml +2 -1
- ultralytics/cfg/default.yaml +14 -11
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5.yaml +5 -2
- ultralytics/cfg/models/v6/yolov6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8.yaml +5 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/solutions/default.yaml +24 -0
- ultralytics/cfg/trackers/botsort.yaml +8 -5
- ultralytics/cfg/trackers/bytetrack.yaml +8 -5
- ultralytics/data/__init__.py +14 -3
- ultralytics/data/annotator.py +37 -15
- ultralytics/data/augment.py +1783 -289
- ultralytics/data/base.py +62 -27
- ultralytics/data/build.py +37 -8
- ultralytics/data/converter.py +196 -36
- ultralytics/data/dataset.py +233 -94
- ultralytics/data/loaders.py +199 -96
- ultralytics/data/split_dota.py +39 -29
- ultralytics/data/utils.py +111 -41
- ultralytics/engine/__init__.py +1 -1
- ultralytics/engine/exporter.py +579 -244
- ultralytics/engine/model.py +604 -252
- ultralytics/engine/predictor.py +22 -11
- ultralytics/engine/results.py +1228 -218
- ultralytics/engine/trainer.py +191 -129
- ultralytics/engine/tuner.py +18 -18
- ultralytics/engine/validator.py +18 -15
- ultralytics/hub/__init__.py +31 -13
- ultralytics/hub/auth.py +11 -7
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/hub/session.py +128 -94
- ultralytics/hub/utils.py +20 -21
- ultralytics/models/__init__.py +4 -2
- ultralytics/models/fastsam/__init__.py +2 -3
- ultralytics/models/fastsam/model.py +26 -4
- ultralytics/models/fastsam/predict.py +127 -63
- ultralytics/models/fastsam/utils.py +1 -44
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +21 -10
- ultralytics/models/nas/predict.py +3 -6
- ultralytics/models/nas/val.py +4 -4
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +6 -8
- ultralytics/models/rtdetr/train.py +6 -2
- ultralytics/models/rtdetr/val.py +3 -3
- ultralytics/models/sam/__init__.py +3 -3
- ultralytics/models/sam/amg.py +29 -23
- ultralytics/models/sam/build.py +211 -13
- ultralytics/models/sam/model.py +91 -30
- ultralytics/models/sam/modules/__init__.py +1 -1
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +381 -53
- ultralytics/models/sam/modules/encoders.py +515 -324
- ultralytics/models/sam/modules/memory_attention.py +237 -0
- ultralytics/models/sam/modules/sam.py +969 -21
- ultralytics/models/sam/modules/tiny_encoder.py +425 -154
- ultralytics/models/sam/modules/transformer.py +159 -60
- ultralytics/models/sam/modules/utils.py +293 -0
- ultralytics/models/sam/predict.py +1263 -132
- ultralytics/models/utils/__init__.py +1 -1
- ultralytics/models/utils/loss.py +36 -24
- ultralytics/models/utils/ops.py +3 -7
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +7 -8
- ultralytics/models/yolo/classify/train.py +17 -22
- ultralytics/models/yolo/classify/val.py +8 -4
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +3 -5
- ultralytics/models/yolo/detect/train.py +11 -4
- ultralytics/models/yolo/detect/val.py +90 -52
- ultralytics/models/yolo/model.py +14 -9
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +41 -23
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +3 -5
- ultralytics/models/yolo/pose/train.py +2 -2
- ultralytics/models/yolo/pose/val.py +51 -17
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -5
- ultralytics/models/yolo/segment/train.py +2 -2
- ultralytics/models/yolo/segment/val.py +60 -19
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +109 -0
- ultralytics/nn/__init__.py +1 -1
- ultralytics/nn/autobackend.py +228 -93
- ultralytics/nn/modules/__init__.py +39 -14
- ultralytics/nn/modules/activation.py +21 -0
- ultralytics/nn/modules/block.py +526 -66
- ultralytics/nn/modules/conv.py +24 -7
- ultralytics/nn/modules/head.py +177 -34
- ultralytics/nn/modules/transformer.py +6 -5
- ultralytics/nn/modules/utils.py +1 -2
- ultralytics/nn/tasks.py +226 -82
- ultralytics/solutions/__init__.py +30 -1
- ultralytics/solutions/ai_gym.py +96 -143
- ultralytics/solutions/analytics.py +247 -0
- ultralytics/solutions/distance_calculation.py +78 -135
- ultralytics/solutions/heatmap.py +93 -247
- ultralytics/solutions/object_counter.py +184 -259
- ultralytics/solutions/parking_management.py +246 -0
- ultralytics/solutions/queue_management.py +112 -0
- ultralytics/solutions/region_counter.py +116 -0
- ultralytics/solutions/security_alarm.py +144 -0
- ultralytics/solutions/solutions.py +178 -0
- ultralytics/solutions/speed_estimation.py +86 -174
- ultralytics/solutions/streamlit_inference.py +190 -0
- ultralytics/solutions/trackzone.py +68 -0
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +32 -13
- ultralytics/trackers/bot_sort.py +61 -28
- ultralytics/trackers/byte_tracker.py +83 -51
- ultralytics/trackers/track.py +21 -6
- ultralytics/trackers/utils/__init__.py +1 -1
- ultralytics/trackers/utils/gmc.py +62 -48
- ultralytics/trackers/utils/kalman_filter.py +166 -35
- ultralytics/trackers/utils/matching.py +40 -21
- ultralytics/utils/__init__.py +511 -239
- ultralytics/utils/autobatch.py +40 -22
- ultralytics/utils/benchmarks.py +266 -85
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/clearml.py +7 -6
- ultralytics/utils/callbacks/comet.py +39 -17
- ultralytics/utils/callbacks/dvc.py +1 -1
- ultralytics/utils/callbacks/hub.py +16 -16
- ultralytics/utils/callbacks/mlflow.py +28 -24
- ultralytics/utils/callbacks/neptune.py +6 -2
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +18 -18
- ultralytics/utils/callbacks/wb.py +27 -20
- ultralytics/utils/checks.py +172 -100
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +40 -34
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +72 -38
- ultralytics/utils/instance.py +41 -19
- ultralytics/utils/loss.py +83 -55
- ultralytics/utils/metrics.py +61 -56
- ultralytics/utils/ops.py +94 -89
- ultralytics/utils/patches.py +30 -14
- ultralytics/utils/plotting.py +600 -269
- ultralytics/utils/tal.py +67 -26
- ultralytics/utils/torch_utils.py +305 -112
- ultralytics/utils/triton.py +2 -1
- ultralytics/utils/tuner.py +21 -12
- ultralytics-8.3.63.dist-info/METADATA +370 -0
- ultralytics-8.3.63.dist-info/RECORD +241 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/WHEEL +1 -1
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -472
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -268
- ultralytics/data/explorer/utils.py +0 -166
- ultralytics/models/fastsam/prompt.py +0 -357
- ultralytics-8.1.29.dist-info/METADATA +0 -373
- ultralytics-8.1.29.dist-info/RECORD +0 -197
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/top_level.txt +0 -0
```diff
--- /dev/null
+++ b/ultralytics/solutions/solutions.py
@@ -0,0 +1,178 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from collections import defaultdict
+
+import cv2
+
+from ultralytics import YOLO
+from ultralytics.utils import ASSETS_URL, DEFAULT_CFG_DICT, DEFAULT_SOL_DICT, LOGGER
+from ultralytics.utils.checks import check_imshow, check_requirements
+
+
+class BaseSolution:
+    """
+    A base class for managing Ultralytics Solutions.
+
+    This class provides core functionality for various Ultralytics Solutions, including model loading, object tracking,
+    and region initialization.
+
+    Attributes:
+        LineString (shapely.geometry.LineString): Class for creating line string geometries.
+        Polygon (shapely.geometry.Polygon): Class for creating polygon geometries.
+        Point (shapely.geometry.Point): Class for creating point geometries.
+        CFG (Dict): Configuration dictionary loaded from a YAML file and updated with kwargs.
+        region (List[Tuple[int, int]]): List of coordinate tuples defining a region of interest.
+        line_width (int): Width of lines used in visualizations.
+        model (ultralytics.YOLO): Loaded YOLO model instance.
+        names (Dict[int, str]): Dictionary mapping class indices to class names.
+        env_check (bool): Flag indicating whether the environment supports image display.
+        track_history (collections.defaultdict): Dictionary to store tracking history for each object.
+
+    Methods:
+        extract_tracks: Apply object tracking and extract tracks from an input image.
+        store_tracking_history: Store object tracking history for a given track ID and bounding box.
+        initialize_region: Initialize the counting region and line segment based on configuration.
+        display_output: Display the results of processing, including showing frames or saving results.
+
+    Examples:
+        >>> solution = BaseSolution(model="yolov8n.pt", region=[(0, 0), (100, 0), (100, 100), (0, 100)])
+        >>> solution.initialize_region()
+        >>> image = cv2.imread("image.jpg")
+        >>> solution.extract_tracks(image)
+        >>> solution.display_output(image)
+    """
+
+    def __init__(self, IS_CLI=False, **kwargs):
+        """
+        Initializes the `BaseSolution` class with configuration settings and the YOLO model for Ultralytics solutions.
+
+        IS_CLI (optional): Enables CLI mode if set.
+        """
+        check_requirements("shapely>=2.0.0")
+        from shapely.geometry import LineString, Point, Polygon
+        from shapely.prepared import prep
+
+        self.LineString = LineString
+        self.Polygon = Polygon
+        self.Point = Point
+        self.prep = prep
+        self.annotator = None  # Initialize annotator
+        self.tracks = None
+        self.track_data = None
+        self.boxes = []
+        self.clss = []
+        self.track_ids = []
+        self.track_line = None
+        self.r_s = None
+
+        # Load config and update with args
+        DEFAULT_SOL_DICT.update(kwargs)
+        DEFAULT_CFG_DICT.update(kwargs)
+        self.CFG = {**DEFAULT_SOL_DICT, **DEFAULT_CFG_DICT}
+        LOGGER.info(f"Ultralytics Solutions: ✅ {DEFAULT_SOL_DICT}")
+
+        self.region = self.CFG["region"]  # Store region data for other classes usage
+        self.line_width = (
+            self.CFG["line_width"] if self.CFG["line_width"] is not None else 2
+        )  # Store line_width for usage
+
+        # Load Model and store classes names
+        if self.CFG["model"] is None:
+            self.CFG["model"] = "yolo11n.pt"
+        self.model = YOLO(self.CFG["model"])
+        self.names = self.model.names
+
+        self.track_add_args = {  # Tracker additional arguments for advance configuration
+            k: self.CFG[k] for k in ["verbose", "iou", "conf", "device", "max_det", "half", "tracker"]
+        }
+
+        if IS_CLI and self.CFG["source"] is None:
+            d_s = "solutions_ci_demo.mp4" if "-pose" not in self.CFG["model"] else "solution_ci_pose_demo.mp4"
+            LOGGER.warning(f"⚠️ WARNING: source not provided. using default source {ASSETS_URL}/{d_s}")
+            from ultralytics.utils.downloads import safe_download
+
+            safe_download(f"{ASSETS_URL}/{d_s}")  # download source from ultralytics assets
+            self.CFG["source"] = d_s  # set default source
+
+        # Initialize environment and region setup
+        self.env_check = check_imshow(warn=True)
+        self.track_history = defaultdict(list)
+
+    def extract_tracks(self, im0):
+        """
+        Applies object tracking and extracts tracks from an input image or frame.
+
+        Args:
+            im0 (ndarray): The input image or frame.
+
+        Examples:
+            >>> solution = BaseSolution()
+            >>> frame = cv2.imread("path/to/image.jpg")
+            >>> solution.extract_tracks(frame)
+        """
+        self.tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"], **self.track_add_args)
+
+        # Extract tracks for OBB or object detection
+        self.track_data = self.tracks[0].obb or self.tracks[0].boxes
+
+        if self.track_data and self.track_data.id is not None:
+            self.boxes = self.track_data.xyxy.cpu()
+            self.clss = self.track_data.cls.cpu().tolist()
+            self.track_ids = self.track_data.id.int().cpu().tolist()
+        else:
+            LOGGER.warning("WARNING ⚠️ no tracks found!")
+            self.boxes, self.clss, self.track_ids = [], [], []
+
+    def store_tracking_history(self, track_id, box):
+        """
+        Stores the tracking history of an object.
+
+        This method updates the tracking history for a given object by appending the center point of its
+        bounding box to the track line. It maintains a maximum of 30 points in the tracking history.
+
+        Args:
+            track_id (int): The unique identifier for the tracked object.
+            box (List[float]): The bounding box coordinates of the object in the format [x1, y1, x2, y2].
+
+        Examples:
+            >>> solution = BaseSolution()
+            >>> solution.store_tracking_history(1, [100, 200, 300, 400])
+        """
+        # Store tracking history
+        self.track_line = self.track_history[track_id]
+        self.track_line.append(((box[0] + box[2]) / 2, (box[1] + box[3]) / 2))
+        if len(self.track_line) > 30:
+            self.track_line.pop(0)
+
+    def initialize_region(self):
+        """Initialize the counting region and line segment based on configuration settings."""
+        if self.region is None:
+            self.region = [(20, 400), (1080, 400), (1080, 360), (20, 360)]
+        self.r_s = (
+            self.Polygon(self.region) if len(self.region) >= 3 else self.LineString(self.region)
+        )  # region or line
+
+    def display_output(self, im0):
+        """
+        Display the results of the processing, which could involve showing frames, printing counts, or saving results.
+
+        This method is responsible for visualizing the output of the object detection and tracking process. It displays
+        the processed frame with annotations, and allows for user interaction to close the display.
+
+        Args:
+            im0 (numpy.ndarray): The input image or frame that has been processed and annotated.
+
+        Examples:
+            >>> solution = BaseSolution()
+            >>> frame = cv2.imread("path/to/image.jpg")
+            >>> solution.display_output(frame)
+
+        Notes:
+            - This method will only display output if the 'show' configuration is set to True and the environment
+              supports image display.
+            - The display can be closed by pressing the 'q' key.
+        """
+        if self.CFG.get("show") and self.env_check:
+            cv2.imshow("Ultralytics Solutions", im0)
+            if cv2.waitKey(1) & 0xFF == ord("q"):
+                return
```
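To make the new base class concrete, here is a minimal sketch of how a custom solution might subclass `BaseSolution`, mirroring the pattern the rewritten `SpeedEstimator` in the next hunk follows. The `ObjectLogger` name and file path are illustrative, not part of the package; everything it calls (`extract_tracks`, `store_tracking_history`, `display_output`, `Annotator`, `colors`) appears in the diffs above and below.

```python
import cv2

from ultralytics.solutions.solutions import BaseSolution
from ultralytics.utils.plotting import Annotator, colors


class ObjectLogger(BaseSolution):
    """Hypothetical solution: draw tracked boxes with class labels on a frame."""

    def process(self, im0):
        self.annotator = Annotator(im0, line_width=self.line_width)
        self.extract_tracks(im0)  # fills self.boxes, self.clss, self.track_ids
        for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
            self.store_tracking_history(track_id, box)  # keeps the last 30 center points
            self.annotator.box_label(box, label=self.names[int(cls)], color=colors(track_id, True))
        self.display_output(im0)  # no-op unless the "show" config flag is set
        return im0


frame = cv2.imread("image.jpg")  # any BGR frame; path is illustrative
if frame is not None:
    ObjectLogger(model="yolo11n.pt", show=False).process(frame)
```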
```diff
--- a/ultralytics/solutions/speed_estimation.py
+++ b/ultralytics/solutions/speed_estimation.py
@@ -1,198 +1,110 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
-from collections import defaultdict
 from time import time
 
-import cv2
 import numpy as np
 
-from ultralytics.
+from ultralytics.solutions.solutions import BaseSolution
 from ultralytics.utils.plotting import Annotator, colors
 
 
-class SpeedEstimator:
-    """
-    ):
+class SpeedEstimator(BaseSolution):
+    """
+    A class to estimate the speed of objects in a real-time video stream based on their tracks.
+
+    This class extends the BaseSolution class and provides functionality for estimating object speeds using
+    tracking data in video streams.
+
+    Attributes:
+        spd (Dict[int, float]): Dictionary storing speed data for tracked objects.
+        trkd_ids (List[int]): List of tracked object IDs that have already been speed-estimated.
+        trk_pt (Dict[int, float]): Dictionary storing previous timestamps for tracked objects.
+        trk_pp (Dict[int, Tuple[float, float]]): Dictionary storing previous positions for tracked objects.
+        annotator (Annotator): Annotator object for drawing on images.
+        region (List[Tuple[int, int]]): List of points defining the speed estimation region.
+        track_line (List[Tuple[float, float]]): List of points representing the object's track.
+        r_s (LineString): LineString object representing the speed estimation region.
+
+    Methods:
+        initialize_region: Initializes the speed estimation region.
+        estimate_speed: Estimates the speed of objects based on tracking data.
+        store_tracking_history: Stores the tracking history for an object.
+        extract_tracks: Extracts tracks from the current frame.
+        display_output: Displays the output with annotations.
+
+    Examples:
+        >>> estimator = SpeedEstimator()
+        >>> frame = cv2.imread("frame.jpg")
+        >>> processed_frame = estimator.estimate_speed(frame)
+        >>> cv2.imshow("Speed Estimation", processed_frame)
+    """
+
+    def __init__(self, **kwargs):
+        """Initializes the SpeedEstimator object with speed estimation parameters and data structures."""
+        super().__init__(**kwargs)
+
+        self.initialize_region()  # Initialize speed region
+
+        self.spd = {}  # set for speed data
+        self.trkd_ids = []  # list for already speed_estimated and tracked ID's
+        self.trk_pt = {}  # set for tracks previous time
+        self.trk_pp = {}  # set for tracks previous point
+
+    def estimate_speed(self, im0):
         """
+        Estimates the speed of objects based on tracking data.
 
         Args:
-            names (dict): object detection classes names
-            view_img (bool): Flag indicating frame display
-            line_thickness (int): Line thickness for bounding boxes.
-            region_thickness (int): Speed estimation region thickness
-            spdl_dist_thresh (int): Euclidean distance threshold for speed line
-        """
-        if reg_pts is None:
-            print("Region points not provided, using default values")
-        else:
-            self.reg_pts = reg_pts
-        self.names = names
-        self.view_img = view_img
-        self.line_thickness = line_thickness
-        self.region_thickness = region_thickness
-        self.spdl_dist_thresh = spdl_dist_thresh
-
-    def extract_tracks(self, tracks):
-        """
-        Extracts results from the provided data.
-
-        Args:
-            tracks (list): List of tracks obtained from the object tracking process.
-        """
-        self.boxes = tracks[0].boxes.xyxy.cpu()
-        self.clss = tracks[0].boxes.cls.cpu().tolist()
-        self.trk_ids = tracks[0].boxes.id.int().cpu().tolist()
-
-    def store_track_info(self, track_id, box):
-        """
-        Store track data.
-
-        Args:
-            track_id (int): object track id.
-            box (list): object bounding box data
-        """
-        track = self.trk_history[track_id]
-        bbox_center = (float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2))
-        track.append(bbox_center)
-
-        if len(track) > 30:
-            track.pop(0)
-
-        self.trk_pts = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
-        return track
+            im0 (np.ndarray): Input image for processing. Shape is typically (H, W, C) for RGB images.
 
-        Plot track and bounding box.
+        Returns:
+            (np.ndarray): Processed image with speed estimations and annotations.
 
-            track (list): tracking history for tracks path drawing
+        Examples:
+            >>> estimator = SpeedEstimator()
+            >>> image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
+            >>> processed_image = estimator.estimate_speed(image)
         """
-        self.annotator.box_label(box, speed_label, bbox_color)
-        cv2.polylines(self.im0, [self.trk_pts], isClosed=False, color=(0, 255, 0), thickness=1)
-        cv2.circle(self.im0, (int(track[-1][0]), int(track[-1][1])), 5, bbox_color, -1)
+        self.annotator = Annotator(im0, line_width=self.line_width)  # Initialize annotator
+        self.extract_tracks(im0)  # Extract tracks
 
-        Args:
-            trk_id (int): object track id.
-            track (list): tracking history for tracks path drawing
-        """
-        if not self.reg_pts[0][0] < track[-1][0] < self.reg_pts[1][0]:
-            return
-        if self.reg_pts[1][1] - self.spdl_dist_thresh < track[-1][1] < self.reg_pts[1][1] + self.spdl_dist_thresh:
-            direction = "known"
-
-        elif self.reg_pts[0][1] - self.spdl_dist_thresh < track[-1][1] < self.reg_pts[0][1] + self.spdl_dist_thresh:
-            direction = "known"
-
-        else:
-            direction = "unknown"
-
-        if self.trk_previous_times[trk_id] != 0 and direction != "unknown" and trk_id not in self.trk_idslist:
-            self.trk_idslist.append(trk_id)
-
-            time_difference = time() - self.trk_previous_times[trk_id]
-            if time_difference > 0:
-                dist_difference = np.abs(track[-1][1] - self.trk_previous_points[trk_id][1])
-                speed = dist_difference / time_difference
-                self.dist_data[trk_id] = speed
-
-        self.trk_previous_times[trk_id] = time()
-        self.trk_previous_points[trk_id] = track[-1]
-
-    def estimate_speed(self, im0, tracks, region_color=(255, 0, 0)):
-        """
-        Calculate object based on tracking data.
-
-        Args:
-            im0 (nd array): Image
-            tracks (list): List of tracks obtained from the object tracking process.
-            region_color (tuple): Color to use when drawing regions.
-        """
-        self.im0 = im0
-        if tracks[0].boxes.id is None:
-            if self.view_img and self.env_check:
-                self.display_frames()
-            return im0
-        self.extract_tracks(tracks)
+        self.annotator.draw_region(
+            reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2
+        )  # Draw region
 
+        for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
+            self.store_tracking_history(track_id, box)  # Store track history
 
+            # Check if track_id is already in self.trk_pp or trk_pt initialize if not
+            if track_id not in self.trk_pt:
+                self.trk_pt[track_id] = 0
+            if track_id not in self.trk_pp:
+                self.trk_pp[track_id] = self.track_line[-1]
 
+            speed_label = f"{int(self.spd[track_id])} km/h" if track_id in self.spd else self.names[int(cls)]
+            self.annotator.box_label(box, label=speed_label, color=colors(track_id, True))  # Draw bounding box
 
+            # Draw tracks of objects
+            self.annotator.draw_centroid_and_tracks(
+                self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width
+            )
 
+            # Calculate object speed and direction based on region intersection
+            if self.LineString([self.trk_pp[track_id], self.track_line[-1]]).intersects(self.r_s):
+                direction = "known"
+            else:
+                direction = "unknown"
 
+            # Perform speed calculation and tracking updates if direction is valid
+            if direction == "known" and track_id not in self.trkd_ids:
+                self.trkd_ids.append(track_id)
+                time_difference = time() - self.trk_pt[track_id]
+                if time_difference > 0:
+                    self.spd[track_id] = np.abs(self.track_line[-1][1] - self.trk_pp[track_id][1]) / time_difference
 
-        cv2.imshow("Ultralytics Speed Estimation", self.im0)
-        if cv2.waitKey(1) & 0xFF == ord("q"):
-            return
+            self.trk_pt[track_id] = time()
+            self.trk_pp[track_id] = self.track_line[-1]
 
+        self.display_output(im0)  # display output with base class function
 
-    SpeedEstimator()
+        return im0  # return output image for more usage
```
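For orientation, a minimal usage sketch of the rewritten API, assuming the module path from the file list above; the video path and region points are illustrative. Configuration now flows through `BaseSolution` as keyword arguments, and `estimate_speed` both annotates and returns each frame.

```python
import cv2

from ultralytics.solutions.speed_estimation import SpeedEstimator

# A 2-point region is treated as a LineString by initialize_region().
estimator = SpeedEstimator(model="yolo11n.pt", region=[(0, 360), (1280, 360)], show=False)

cap = cv2.VideoCapture("traffic.mp4")  # hypothetical input video
while cap.isOpened():
    success, frame = cap.read()
    if not success:
        break
    frame = estimator.estimate_speed(frame)  # draws boxes, tracks, and speed labels
cap.release()
```

Note that the stored value is vertical pixel displacement per second, recorded once a track segment intersects the region, so the km/h label is nominal unless the pixel-to-world scale happens to match.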
```diff
--- /dev/null
+++ b/ultralytics/solutions/streamlit_inference.py
@@ -0,0 +1,190 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import io
+from typing import Any
+
+import cv2
+
+from ultralytics import YOLO
+from ultralytics.utils import LOGGER
+from ultralytics.utils.checks import check_requirements
+from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS
+
+
+class Inference:
+    """
+    A class to perform object detection, image classification, image segmentation and pose estimation inference using
+    Streamlit and Ultralytics YOLO models. It provides the functionalities such as loading models, configuring settings,
+    uploading video files, and performing real-time inference.
+
+    Attributes:
+        st (module): Streamlit module for UI creation.
+        temp_dict (dict): Temporary dictionary to store the model path.
+        model_path (str): Path to the loaded model.
+        model (YOLO): The YOLO model instance.
+        source (str): Selected video source.
+        enable_trk (str): Enable tracking option.
+        conf (float): Confidence threshold.
+        iou (float): IoU threshold for non-max suppression.
+        vid_file_name (str): Name of the uploaded video file.
+        selected_ind (list): List of selected class indices.
+
+    Methods:
+        web_ui: Sets up the Streamlit web interface with custom HTML elements.
+        sidebar: Configures the Streamlit sidebar for model and inference settings.
+        source_upload: Handles video file uploads through the Streamlit interface.
+        configure: Configures the model and loads selected classes for inference.
+        inference: Performs real-time object detection inference.
+
+    Examples:
+        >>> inf = solutions.Inference(model="path/to/model.pt")  # Model is not necessary argument.
+        >>> inf.inference()
+    """
+
+    def __init__(self, **kwargs: Any):
+        """
+        Initializes the Inference class, checking Streamlit requirements and setting up the model path.
+
+        Args:
+            **kwargs (Any): Additional keyword arguments for model configuration.
+        """
+        check_requirements("streamlit>=1.29.0")  # scope imports for faster ultralytics package load speeds
+        import streamlit as st
+
+        self.st = st  # Reference to the Streamlit class instance
+        self.source = None  # Placeholder for video or webcam source details
+        self.enable_trk = False  # Flag to toggle object tracking
+        self.conf = 0.25  # Confidence threshold for detection
+        self.iou = 0.45  # Intersection-over-Union (IoU) threshold for non-maximum suppression
+        self.org_frame = None  # Container for the original frame to be displayed
+        self.ann_frame = None  # Container for the annotated frame to be displayed
+        self.vid_file_name = None  # Holds the name of the video file
+        self.selected_ind = []  # List of selected classes for detection or tracking
+        self.model = None  # Container for the loaded model instance
+
+        self.temp_dict = {"model": None, **kwargs}
+        self.model_path = None  # Store model file name with path
+        if self.temp_dict["model"] is not None:
+            self.model_path = self.temp_dict["model"]
+
+        LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")
+
+    def web_ui(self):
+        """Sets up the Streamlit web interface with custom HTML elements."""
+        menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>"""  # Hide main menu style
+
+        # Main title of streamlit application
+        main_title_cfg = """<div><h1 style="color:#FF64DA; text-align:center; font-size:40px; margin-top:-50px;
+        font-family: 'Archivo', sans-serif; margin-bottom:20px;">Ultralytics YOLO Streamlit Application</h1></div>"""
+
+        # Subtitle of streamlit application
+        sub_title_cfg = """<div><h4 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif;
+        margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam with the power
+        of Ultralytics YOLO! 🚀</h4></div>"""
+
+        # Set html page configuration and append custom HTML
+        self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide")
+        self.st.markdown(menu_style_cfg, unsafe_allow_html=True)
+        self.st.markdown(main_title_cfg, unsafe_allow_html=True)
+        self.st.markdown(sub_title_cfg, unsafe_allow_html=True)
+
+    def sidebar(self):
+        """Configures the Streamlit sidebar for model and inference settings."""
+        with self.st.sidebar:  # Add Ultralytics LOGO
+            logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
+            self.st.image(logo, width=250)
+
+        self.st.sidebar.title("User Configuration")  # Add elements to vertical setting menu
+        self.source = self.st.sidebar.selectbox(
+            "Video",
+            ("webcam", "video"),
+        )  # Add source selection dropdown
+        self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No"))  # Enable object tracking
+        self.conf = float(
+            self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01)
+        )  # Slider for confidence
+        self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, self.iou, 0.01))  # Slider for NMS threshold
+
+        col1, col2 = self.st.columns(2)
+        self.org_frame = col1.empty()
+        self.ann_frame = col2.empty()
+
+    def source_upload(self):
+        """Handles video file uploads through the Streamlit interface."""
+        self.vid_file_name = ""
+        if self.source == "video":
+            vid_file = self.st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"])
+            if vid_file is not None:
+                g = io.BytesIO(vid_file.read())  # BytesIO Object
+                with open("ultralytics.mp4", "wb") as out:  # Open temporary file as bytes
+                    out.write(g.read())  # Read bytes into file
+                self.vid_file_name = "ultralytics.mp4"
+        elif self.source == "webcam":
+            self.vid_file_name = 0
+
+    def configure(self):
+        """Configures the model and loads selected classes for inference."""
+        # Add dropdown menu for model selection
+        available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolo11")]
+        if self.model_path:  # If user provided the custom model, insert model without suffix as *.pt is added later
+            available_models.insert(0, self.model_path.split(".pt")[0])
+        selected_model = self.st.sidebar.selectbox("Model", available_models)
+
+        with self.st.spinner("Model is downloading..."):
+            self.model = YOLO(f"{selected_model.lower()}.pt")  # Load the YOLO model
+            class_names = list(self.model.names.values())  # Convert dictionary to list of class names
+        self.st.success("Model loaded successfully!")
+
+        # Multiselect box with class names and get indices of selected classes
+        selected_classes = self.st.sidebar.multiselect("Classes", class_names, default=class_names[:3])
+        self.selected_ind = [class_names.index(option) for option in selected_classes]
+
+        if not isinstance(self.selected_ind, list):  # Ensure selected_options is a list
+            self.selected_ind = list(self.selected_ind)
+
+    def inference(self):
+        """Performs real-time object detection inference."""
+        self.web_ui()  # Initialize the web interface
+        self.sidebar()  # Create the sidebar
+        self.source_upload()  # Upload the video source
+        self.configure()  # Configure the app
+
+        if self.st.sidebar.button("Start"):
+            stop_button = self.st.button("Stop")  # Button to stop the inference
+            cap = cv2.VideoCapture(self.vid_file_name)  # Capture the video
+            if not cap.isOpened():
+                self.st.error("Could not open webcam.")
+            while cap.isOpened():
+                success, frame = cap.read()
+                if not success:
+                    self.st.warning("Failed to read frame from webcam. Please verify the webcam is connected properly.")
+                    break
+
+                # Store model predictions
+                if self.enable_trk == "Yes":
+                    results = self.model.track(
+                        frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
+                    )
+                else:
+                    results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind)
+                annotated_frame = results[0].plot()  # Add annotations on frame
+
+                if stop_button:
+                    cap.release()  # Release the capture
+                    self.st.stop()  # Stop streamlit app
+
+                self.org_frame.image(frame, channels="BGR")  # Display original frame
+                self.ann_frame.image(annotated_frame, channels="BGR")  # Display processed frame
+
+            cap.release()  # Release the capture
+            cv2.destroyAllWindows()  # Destroy window
+
+
+if __name__ == "__main__":
+    import sys  # Import the sys module for accessing command-line arguments
+
+    # Check if a model name is provided as a command-line argument
+    args = len(sys.argv)
+    model = sys.argv[1] if args > 1 else None  # assign first argument as the model name
+    # Create an instance of the Inference class and run inference
+    Inference(model=model).inference()
```
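Since the new module guards its entry point with `__main__`, it is evidently meant to be launched as a script, e.g. `streamlit run path/to/ultralytics/solutions/streamlit_inference.py` (Streamlit forwards extra script arguments such as a model name after a `--` separator). A hedged sketch of the equivalent programmatic use, assuming the import path implied by the file list above:

```python
from ultralytics.solutions.streamlit_inference import Inference

# "model" is optional; without it the sidebar lists the YOLO11 release assets.
app = Inference(model="yolo11n.pt")
app.inference()  # builds the web UI and sidebar, then streams predictions
```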