ultralytics 8.3.53__py3-none-any.whl → 8.3.55__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. tests/__init__.py +0 -1
  2. tests/conftest.py +2 -2
  3. tests/test_cli.py +2 -1
  4. tests/test_python.py +2 -2
  5. tests/test_solutions.py +11 -9
  6. ultralytics/__init__.py +1 -1
  7. ultralytics/cfg/__init__.py +57 -56
  8. ultralytics/cfg/datasets/coco-pose.yaml +4 -4
  9. ultralytics/cfg/datasets/lvis.yaml +1 -1
  10. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  11. ultralytics/cfg/solutions/default.yaml +1 -1
  12. ultralytics/data/augment.py +6 -3
  13. ultralytics/data/dataset.py +2 -2
  14. ultralytics/engine/exporter.py +11 -11
  15. ultralytics/engine/model.py +22 -24
  16. ultralytics/engine/validator.py +1 -1
  17. ultralytics/models/sam/modules/tiny_encoder.py +2 -1
  18. ultralytics/models/sam/predict.py +1 -1
  19. ultralytics/nn/autobackend.py +7 -10
  20. ultralytics/solutions/__init__.py +2 -2
  21. ultralytics/solutions/analytics.py +1 -1
  22. ultralytics/solutions/distance_calculation.py +2 -0
  23. ultralytics/solutions/heatmap.py +1 -0
  24. ultralytics/solutions/parking_management.py +25 -14
  25. ultralytics/solutions/region_counter.py +4 -0
  26. ultralytics/solutions/security_alarm.py +9 -6
  27. ultralytics/solutions/solutions.py +8 -0
  28. ultralytics/solutions/streamlit_inference.py +180 -133
  29. ultralytics/utils/benchmarks.py +2 -1
  30. ultralytics/utils/downloads.py +1 -1
  31. ultralytics/utils/instance.py +1 -1
  32. ultralytics/utils/metrics.py +3 -4
  33. ultralytics/utils/plotting.py +2 -1
  34. {ultralytics-8.3.53.dist-info → ultralytics-8.3.55.dist-info}/METADATA +2 -2
  35. {ultralytics-8.3.53.dist-info → ultralytics-8.3.55.dist-info}/RECORD +39 -38
  36. {ultralytics-8.3.53.dist-info → ultralytics-8.3.55.dist-info}/LICENSE +0 -0
  37. {ultralytics-8.3.53.dist-info → ultralytics-8.3.55.dist-info}/WHEEL +0 -0
  38. {ultralytics-8.3.53.dist-info → ultralytics-8.3.55.dist-info}/entry_points.txt +0 -0
  39. {ultralytics-8.3.53.dist-info → ultralytics-8.3.55.dist-info}/top_level.txt +0 -0
@@ -1377,7 +1377,7 @@ class SAM2VideoPredictor(SAM2Predictor):
1377
1377
  if "maskmem_pos_enc" not in model_constants:
1378
1378
  assert isinstance(out_maskmem_pos_enc, list)
1379
1379
  # only take the slice for one object, since it's same across objects
1380
- maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
1380
+ maskmem_pos_enc = [x[:1].clone() for x in out_maskmem_pos_enc]
1381
1381
  model_constants["maskmem_pos_enc"] = maskmem_pos_enc
1382
1382
  else:
1383
1383
  maskmem_pos_enc = model_constants["maskmem_pos_enc"]
@@ -192,14 +192,14 @@ class AutoBackend(nn.Module):
192
192
  check_requirements("numpy==1.23.5")
193
193
  import onnxruntime
194
194
 
195
- providers = onnxruntime.get_available_providers()
196
- if not cuda and "CUDAExecutionProvider" in providers:
197
- providers.remove("CUDAExecutionProvider")
198
- elif cuda and "CUDAExecutionProvider" not in providers:
199
- LOGGER.warning("WARNING ⚠️ Failed to start ONNX Runtime session with CUDA. Falling back to CPU...")
195
+ providers = ["CPUExecutionProvider"]
196
+ if cuda and "CUDAExecutionProvider" in onnxruntime.get_available_providers():
197
+ providers.insert(0, "CUDAExecutionProvider")
198
+ elif cuda: # Only log warning if CUDA was requested but unavailable
199
+ LOGGER.warning("WARNING ⚠️ Failed to start ONNX Runtime with CUDA. Using CPU...")
200
200
  device = torch.device("cpu")
201
201
  cuda = False
202
- LOGGER.info(f"Preferring ONNX Runtime {providers[0]}")
202
+ LOGGER.info(f"Using ONNX Runtime {providers[0]}")
203
203
  if onnx:
204
204
  session = onnxruntime.InferenceSession(w, providers=providers)
205
205
  else:
@@ -429,10 +429,7 @@ class AutoBackend(nn.Module):
429
429
 
430
430
  import MNN
431
431
 
432
- config = {}
433
- config["precision"] = "low"
434
- config["backend"] = "CPU"
435
- config["numThread"] = (os.cpu_count() + 1) // 2
432
+ config = {"precision": "low", "backend": "CPU", "numThread": (os.cpu_count() + 1) // 2}
436
433
  rt = MNN.nn.create_runtime_manager((config,))
437
434
  net = MNN.nn.load_module_from_file(w, [], [], runtime_manager=rt, rearrange=True)
438
435
 
@@ -10,7 +10,7 @@ from .queue_management import QueueManager
10
10
  from .region_counter import RegionCounter
11
11
  from .security_alarm import SecurityAlarm
12
12
  from .speed_estimation import SpeedEstimator
13
- from .streamlit_inference import inference
13
+ from .streamlit_inference import Inference
14
14
  from .trackzone import TrackZone
15
15
 
16
16
  __all__ = (
@@ -23,7 +23,7 @@ __all__ = (
23
23
  "QueueManager",
24
24
  "SpeedEstimator",
25
25
  "Analytics",
26
- "inference",
26
+ "Inference",
27
27
  "RegionCounter",
28
28
  "TrackZone",
29
29
  "SecurityAlarm",
@@ -170,7 +170,7 @@ class Analytics(BaseSolution):
170
170
  for key in count_dict.keys():
171
171
  y_data_dict[key] = np.append(y_data_dict[key], float(count_dict[key]))
172
172
  if len(y_data_dict[key]) < max_length:
173
- y_data_dict[key] = np.pad(y_data_dict[key], (0, max_length - len(y_data_dict[key])), "constant")
173
+ y_data_dict[key] = np.pad(y_data_dict[key], (0, max_length - len(y_data_dict[key])))
174
174
  if len(x_data) > self.max_points:
175
175
  x_data = x_data[1:]
176
176
  for key in count_dict.keys():
@@ -45,6 +45,8 @@ class DistanceCalculation(BaseSolution):
45
45
  self.left_mouse_count = 0
46
46
  self.selected_boxes = {}
47
47
 
48
+ self.centroids = [] # Initialize empty list to store centroids
49
+
48
50
  def mouse_event_for_distance(self, event, x, y, flags, param):
49
51
  """
50
52
  Handles mouse events to select regions in a real-time video stream for distance calculation.
@@ -41,6 +41,7 @@ class Heatmap(ObjectCounter):
41
41
 
42
42
  # store colormap
43
43
  self.colormap = cv2.COLORMAP_PARULA if self.CFG["colormap"] is None else self.CFG["colormap"]
44
+ self.heatmap = None
44
45
 
45
46
  def heatmap_effect(self, box):
46
47
  """
@@ -5,7 +5,9 @@ import json
5
5
  import cv2
6
6
  import numpy as np
7
7
 
8
- from ultralytics.solutions.solutions import LOGGER, BaseSolution, check_requirements
8
+ from ultralytics.solutions.solutions import BaseSolution
9
+ from ultralytics.utils import LOGGER
10
+ from ultralytics.utils.checks import check_requirements
9
11
  from ultralytics.utils.plotting import Annotator
10
12
 
11
13
 
@@ -32,7 +34,6 @@ class ParkingPtsSelection:
32
34
  canvas_max_height (int): Maximum height of the canvas.
33
35
 
34
36
  Methods:
35
- setup_ui: Sets up the Tkinter UI components.
36
37
  initialize_properties: Initializes the necessary properties.
37
38
  upload_image: Uploads an image, resizes it to fit the canvas, and displays it.
38
39
  on_canvas_click: Handles mouse clicks to add points for bounding boxes.
@@ -53,20 +54,22 @@ class ParkingPtsSelection:
53
54
  from tkinter import filedialog, messagebox
54
55
 
55
56
  self.tk, self.filedialog, self.messagebox = tk, filedialog, messagebox
56
- self.setup_ui()
57
- self.initialize_properties()
58
- self.master.mainloop()
59
-
60
- def setup_ui(self):
61
- """Sets up the Tkinter UI components for the parking zone points selection interface."""
62
- self.master = self.tk.Tk()
57
+ self.master = self.tk.Tk() # Reference to the main application window or parent widget
63
58
  self.master.title("Ultralytics Parking Zones Points Selector")
64
59
  self.master.resizable(False, False)
65
60
 
66
- # Canvas for image display
67
- self.canvas = self.tk.Canvas(self.master, bg="white")
61
+ self.canvas = self.tk.Canvas(self.master, bg="white") # Canvas widget for displaying images or graphics
68
62
  self.canvas.pack(side=self.tk.BOTTOM)
69
63
 
64
+ self.image = None # Variable to store the loaded image
65
+ self.canvas_image = None # Reference to the image displayed on the canvas
66
+ self.canvas_max_width = None # Maximum allowed width for the canvas
67
+ self.canvas_max_height = None # Maximum allowed height for the canvas
68
+ self.rg_data = None # Data related to region or annotation management
69
+ self.current_box = None # Stores the currently selected or active bounding box
70
+ self.imgh = None # Height of the current image
71
+ self.imgw = None # Width of the current image
72
+
70
73
  # Button frame with buttons
71
74
  button_frame = self.tk.Frame(self.master)
72
75
  button_frame.pack(side=self.tk.TOP)
@@ -78,6 +81,9 @@ class ParkingPtsSelection:
78
81
  ]:
79
82
  self.tk.Button(button_frame, text=text, command=cmd).pack(side=self.tk.LEFT)
80
83
 
84
+ self.initialize_properties()
85
+ self.master.mainloop()
86
+
81
87
  def initialize_properties(self):
82
88
  """Initialize properties for image, canvas, bounding boxes, and dimensions."""
83
89
  self.image = self.canvas_image = None
@@ -103,7 +109,7 @@ class ParkingPtsSelection:
103
109
  )
104
110
 
105
111
  self.canvas.config(width=canvas_width, height=canvas_height)
106
- self.canvas_image = ImageTk.PhotoImage(self.image.resize((canvas_width, canvas_height), Image.LANCZOS))
112
+ self.canvas_image = ImageTk.PhotoImage(self.image.resize((canvas_width, canvas_height)))
107
113
  self.canvas.create_image(0, 0, anchor=self.tk.NW, image=self.canvas_image)
108
114
  self.canvas.bind("<Button-1>", self.on_canvas_click)
109
115
 
@@ -142,8 +148,13 @@ class ParkingPtsSelection:
142
148
  """Saves the selected parking zone points to a JSON file with scaled coordinates."""
143
149
  scale_w, scale_h = self.imgw / self.canvas.winfo_width(), self.imgh / self.canvas.winfo_height()
144
150
  data = [{"points": [(int(x * scale_w), int(y * scale_h)) for x, y in box]} for box in self.rg_data]
145
- with open("bounding_boxes.json", "w") as f:
146
- json.dump(data, f, indent=4)
151
+
152
+ from io import StringIO # Function level import, as it's only required to store coordinates, not every frame
153
+
154
+ write_buffer = StringIO()
155
+ json.dump(data, write_buffer, indent=4)
156
+ with open("bounding_boxes.json", "w", encoding="utf-8") as f:
157
+ f.write(write_buffer.getvalue())
147
158
  self.messagebox.showinfo("Success", "Bounding boxes saved to bounding_boxes.json")
148
159
 
149
160
 
@@ -1,6 +1,7 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
2
 
3
3
  from ultralytics.solutions.solutions import BaseSolution
4
+ from ultralytics.utils import LOGGER
4
5
  from ultralytics.utils.plotting import Annotator, colors
5
6
 
6
7
 
@@ -81,6 +82,9 @@ class RegionCounter(BaseSolution):
81
82
 
82
83
  # Draw regions and process counts for each defined area
83
84
  for idx, (region_name, reg_pts) in enumerate(regions.items(), start=1):
85
+ if not isinstance(reg_pts, list) or not all(isinstance(pt, tuple) for pt in reg_pts):
86
+ LOGGER.warning(f"Invalid region points for {region_name}: {reg_pts}")
87
+ continue # Skip invalid entries
84
88
  color = colors(idx, True)
85
89
  self.annotator.draw_region(reg_pts=reg_pts, color=color, thickness=self.line_width * 2)
86
90
  self.add_region(region_name, reg_pts, color, self.annotator.get_txt_color())
@@ -1,6 +1,7 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
2
 
3
- from ultralytics.solutions.solutions import LOGGER, BaseSolution
3
+ from ultralytics.solutions.solutions import BaseSolution
4
+ from ultralytics.utils import LOGGER
4
5
  from ultralytics.utils.plotting import Annotator, colors
5
6
 
6
7
 
@@ -33,6 +34,9 @@ class SecurityAlarm(BaseSolution):
33
34
  super().__init__(**kwargs)
34
35
  self.email_sent = False
35
36
  self.records = self.CFG["records"]
37
+ self.server = None
38
+ self.to_email = ""
39
+ self.from_email = ""
36
40
 
37
41
  def authenticate(self, from_email, password, to_email):
38
42
  """
@@ -90,7 +94,7 @@ class SecurityAlarm(BaseSolution):
90
94
 
91
95
  # Add the text message body
92
96
  message_body = f"Ultralytics ALERT!!! " f"{records} objects have been detected!!"
93
- message.attach(MIMEText(message_body, "plain"))
97
+ message.attach(MIMEText(message_body))
94
98
 
95
99
  # Attach the image
96
100
  image_attachment = MIMEImage(img_bytes, name="ultralytics.jpg")
@@ -131,10 +135,9 @@ class SecurityAlarm(BaseSolution):
131
135
  self.annotator.box_label(box, label=self.names[cls], color=colors(cls, True))
132
136
 
133
137
  total_det = len(self.clss)
134
- if total_det > self.records: # Only send email If not sent before
135
- if not self.email_sent:
136
- self.send_email(im0, total_det)
137
- self.email_sent = True
138
+ if total_det > self.records and not self.email_sent: # Only send email If not sent before
139
+ self.send_email(im0, total_det)
140
+ self.email_sent = True
138
141
 
139
142
  self.display_output(im0) # display output with base class function
140
143
 
@@ -56,6 +56,14 @@ class BaseSolution:
56
56
  self.Polygon = Polygon
57
57
  self.Point = Point
58
58
  self.prep = prep
59
+ self.annotator = None # Initialize annotator
60
+ self.tracks = None
61
+ self.track_data = None
62
+ self.boxes = []
63
+ self.clss = []
64
+ self.track_ids = []
65
+ self.track_line = None
66
+ self.r_s = None
59
67
 
60
68
  # Load config and update with args
61
69
  DEFAULT_SOL_DICT.update(kwargs)
@@ -1,148 +1,195 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
2
 
3
3
  import io
4
- import time
4
+ from typing import Any
5
5
 
6
6
  import cv2
7
- import torch
8
7
 
8
+ from ultralytics import YOLO
9
+ from ultralytics.utils import LOGGER
9
10
  from ultralytics.utils.checks import check_requirements
10
11
  from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS
11
12
 
12
13
 
13
- def inference(model=None):
14
- """Performs real-time object detection on video input using YOLO in a Streamlit web application."""
15
- check_requirements("streamlit>=1.29.0") # scope imports for faster ultralytics package load speeds
16
- import streamlit as st
17
-
18
- from ultralytics import YOLO
19
-
20
- # Hide main menu style
21
- menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>"""
22
-
23
- # Main title of streamlit application
24
- main_title_cfg = """<div><h1 style="color:#FF64DA; text-align:center; font-size:40px;
25
- font-family: 'Archivo', sans-serif; margin-top:-50px;margin-bottom:20px;">
26
- Ultralytics YOLO Streamlit Application
27
- </h1></div>"""
28
-
29
- # Subtitle of streamlit application
30
- sub_title_cfg = """<div><h4 style="color:#042AFF; text-align:center;
31
- font-family: 'Archivo', sans-serif; margin-top:-15px; margin-bottom:50px;">
32
- Experience real-time object detection on your webcam with the power of Ultralytics YOLO! 🚀</h4>
33
- </div>"""
34
-
35
- # Set html page configuration
36
- st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide", initial_sidebar_state="auto")
37
-
38
- # Append the custom HTML
39
- st.markdown(menu_style_cfg, unsafe_allow_html=True)
40
- st.markdown(main_title_cfg, unsafe_allow_html=True)
41
- st.markdown(sub_title_cfg, unsafe_allow_html=True)
42
-
43
- # Add ultralytics logo in sidebar
44
- with st.sidebar:
45
- logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
46
- st.image(logo, width=250)
14
+ class Inference:
15
+ """
16
+ A class to perform object detection, image classification, image segmentation and pose estimation inference using
17
+ Streamlit and Ultralytics YOLO models. It provides the functionalities such as loading models, configuring settings,
18
+ uploading video files, and performing real-time inference.
19
+
20
+ Attributes:
21
+ st (module): Streamlit module for UI creation.
22
+ temp_dict (dict): Temporary dictionary to store the model path.
23
+ model_path (str): Path to the loaded model.
24
+ model (YOLO): The YOLO model instance.
25
+ source (str): Selected video source.
26
+ enable_trk (str): Enable tracking option.
27
+ conf (float): Confidence threshold.
28
+ iou (float): IoU threshold for non-max suppression.
29
+ vid_file_name (str): Name of the uploaded video file.
30
+ selected_ind (list): List of selected class indices.
31
+
32
+ Methods:
33
+ web_ui: Sets up the Streamlit web interface with custom HTML elements.
34
+ sidebar: Configures the Streamlit sidebar for model and inference settings.
35
+ source_upload: Handles video file uploads through the Streamlit interface.
36
+ configure: Configures the model and loads selected classes for inference.
37
+ inference: Performs real-time object detection inference.
38
+
39
+ Examples:
40
+ >>> inf = solutions.Inference(model="path/to/model.pt") # Model is not necessary argument.
41
+ >>> inf.inference()
42
+ """
43
+
44
+ def __init__(self, **kwargs: Any):
45
+ """
46
+ Initializes the Inference class, checking Streamlit requirements and setting up the model path.
47
+
48
+ Args:
49
+ **kwargs (Any): Additional keyword arguments for model configuration.
50
+ """
51
+ check_requirements("streamlit>=1.29.0") # scope imports for faster ultralytics package load speeds
52
+ import streamlit as st
53
+
54
+ self.st = st # Reference to the Streamlit class instance
55
+ self.source = None # Placeholder for video or webcam source details
56
+ self.enable_trk = False # Flag to toggle object tracking
57
+ self.conf = 0.25 # Confidence threshold for detection
58
+ self.iou = 0.45 # Intersection-over-Union (IoU) threshold for non-maximum suppression
59
+ self.org_frame = None # Container for the original frame to be displayed
60
+ self.ann_frame = None # Container for the annotated frame to be displayed
61
+ self.vid_file_name = None # Holds the name of the video file
62
+ self.selected_ind = [] # List of selected classes for detection or tracking
63
+ self.model = None # Container for the loaded model instance
64
+
65
+ self.temp_dict = {"model": None} # Temporary dict to store the model path
66
+ self.temp_dict.update(kwargs)
67
+ self.model_path = None # Store model file name with path
68
+ if self.temp_dict["model"] is not None:
69
+ self.model_path = self.temp_dict["model"]
70
+
71
+ LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")
72
+
73
+ def web_ui(self):
74
+ """Sets up the Streamlit web interface with custom HTML elements."""
75
+ menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>""" # Hide main menu style
76
+
77
+ # Main title of streamlit application
78
+ main_title_cfg = """<div><h1 style="color:#FF64DA; text-align:center; font-size:40px; margin-top:-50px;
79
+ font-family: 'Archivo', sans-serif; margin-bottom:20px;">Ultralytics YOLO Streamlit Application</h1></div>"""
80
+
81
+ # Subtitle of streamlit application
82
+ sub_title_cfg = """<div><h4 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif;
83
+ margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam with the power
84
+ of Ultralytics YOLO! 🚀</h4></div>"""
85
+
86
+ # Set html page configuration and append custom HTML
87
+ self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide")
88
+ self.st.markdown(menu_style_cfg, unsafe_allow_html=True)
89
+ self.st.markdown(main_title_cfg, unsafe_allow_html=True)
90
+ self.st.markdown(sub_title_cfg, unsafe_allow_html=True)
91
+
92
+ def sidebar(self):
93
+ """Configures the Streamlit sidebar for model and inference settings."""
94
+ with self.st.sidebar: # Add Ultralytics LOGO
95
+ logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
96
+ self.st.image(logo, width=250)
97
+
98
+ self.st.sidebar.title("User Configuration") # Add elements to vertical setting menu
99
+ self.source = self.st.sidebar.selectbox(
100
+ "Video",
101
+ ("webcam", "video"),
102
+ ) # Add source selection dropdown
103
+ self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) # Enable object tracking
104
+ self.conf = float(
105
+ self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01)
106
+ ) # Slider for confidence
107
+ self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, self.iou, 0.01)) # Slider for NMS threshold
108
+
109
+ col1, col2 = self.st.columns(2)
110
+ self.org_frame = col1.empty()
111
+ self.ann_frame = col2.empty()
112
+
113
+ def source_upload(self):
114
+ """Handles video file uploads through the Streamlit interface."""
115
+ self.vid_file_name = ""
116
+ if self.source == "video":
117
+ vid_file = self.st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"])
118
+ if vid_file is not None:
119
+ g = io.BytesIO(vid_file.read()) # BytesIO Object
120
+ with open("ultralytics.mp4", "wb") as out: # Open temporary file as bytes
121
+ out.write(g.read()) # Read bytes into file
122
+ self.vid_file_name = "ultralytics.mp4"
123
+ elif self.source == "webcam":
124
+ self.vid_file_name = 0
125
+
126
+ def configure(self):
127
+ """Configures the model and loads selected classes for inference."""
128
+ # Add dropdown menu for model selection
129
+ available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolo11")]
130
+ if self.model_path: # If user provided the custom model, insert model without suffix as *.pt is added later
131
+ available_models.insert(0, self.model_path.split(".pt")[0])
132
+ selected_model = self.st.sidebar.selectbox("Model", available_models)
133
+
134
+ with self.st.spinner("Model is downloading..."):
135
+ self.model = YOLO(f"{selected_model.lower()}.pt") # Load the YOLO model
136
+ class_names = list(self.model.names.values()) # Convert dictionary to list of class names
137
+ self.st.success("Model loaded successfully!")
138
+
139
+ # Multiselect box with class names and get indices of selected classes
140
+ selected_classes = self.st.sidebar.multiselect("Classes", class_names, default=class_names[:3])
141
+ self.selected_ind = [class_names.index(option) for option in selected_classes]
142
+
143
+ if not isinstance(self.selected_ind, list): # Ensure selected_options is a list
144
+ self.selected_ind = list(self.selected_ind)
145
+
146
+ def inference(self):
147
+ """Performs real-time object detection inference."""
148
+ self.web_ui() # Initialize the web interface
149
+ self.sidebar() # Create the sidebar
150
+ self.source_upload() # Upload the video source
151
+ self.configure() # Configure the app
152
+
153
+ if self.st.sidebar.button("Start"):
154
+ stop_button = self.st.button("Stop") # Button to stop the inference
155
+ cap = cv2.VideoCapture(self.vid_file_name) # Capture the video
156
+ if not cap.isOpened():
157
+ self.st.error("Could not open webcam.")
158
+ while cap.isOpened():
159
+ success, frame = cap.read()
160
+ if not success:
161
+ self.st.warning("Failed to read frame from webcam. Please verify the webcam is connected properly.")
162
+ break
163
+
164
+ # Store model predictions
165
+ if self.enable_trk == "Yes":
166
+ results = self.model.track(
167
+ frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
168
+ )
169
+ else:
170
+ results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind)
171
+ annotated_frame = results[0].plot() # Add annotations on frame
172
+
173
+ if stop_button:
174
+ cap.release() # Release the capture
175
+ self.st.stop() # Stop streamlit app
176
+
177
+ self.org_frame.image(frame, channels="BGR") # Display original frame
178
+ self.ann_frame.image(annotated_frame, channels="BGR") # Display processed frame
179
+
180
+ cap.release() # Release the capture
181
+ cv2.destroyAllWindows() # Destroy window
47
182
 
48
- # Add elements to vertical setting menu
49
- st.sidebar.title("User Configuration")
50
183
 
51
- # Add video source selection dropdown
52
- source = st.sidebar.selectbox(
53
- "Video",
54
- ("webcam", "video"),
55
- )
56
-
57
- vid_file_name = ""
58
- if source == "video":
59
- vid_file = st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"])
60
- if vid_file is not None:
61
- g = io.BytesIO(vid_file.read()) # BytesIO Object
62
- vid_location = "ultralytics.mp4"
63
- with open(vid_location, "wb") as out: # Open temporary file as bytes
64
- out.write(g.read()) # Read bytes into file
65
- vid_file_name = "ultralytics.mp4"
66
- elif source == "webcam":
67
- vid_file_name = 0
68
-
69
- # Add dropdown menu for model selection
70
- available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolo11")]
71
- if model:
72
- available_models.insert(0, model.split(".pt")[0]) # insert model without suffix as *.pt is added later
73
-
74
- selected_model = st.sidebar.selectbox("Model", available_models)
75
- with st.spinner("Model is downloading..."):
76
- model = YOLO(f"{selected_model.lower()}.pt") # Load the YOLO model
77
- class_names = list(model.names.values()) # Convert dictionary to list of class names
78
- st.success("Model loaded successfully!")
79
-
80
- # Multiselect box with class names and get indices of selected classes
81
- selected_classes = st.sidebar.multiselect("Classes", class_names, default=class_names[:3])
82
- selected_ind = [class_names.index(option) for option in selected_classes]
83
-
84
- if not isinstance(selected_ind, list): # Ensure selected_options is a list
85
- selected_ind = list(selected_ind)
86
-
87
- enable_trk = st.sidebar.radio("Enable Tracking", ("Yes", "No"))
88
- conf = float(st.sidebar.slider("Confidence Threshold", 0.0, 1.0, 0.25, 0.01))
89
- iou = float(st.sidebar.slider("IoU Threshold", 0.0, 1.0, 0.45, 0.01))
90
-
91
- col1, col2 = st.columns(2)
92
- org_frame = col1.empty()
93
- ann_frame = col2.empty()
94
-
95
- fps_display = st.sidebar.empty() # Placeholder for FPS display
96
-
97
- if st.sidebar.button("Start"):
98
- videocapture = cv2.VideoCapture(vid_file_name) # Capture the video
99
-
100
- if not videocapture.isOpened():
101
- st.error("Could not open webcam.")
102
-
103
- stop_button = st.button("Stop") # Button to stop the inference
104
-
105
- while videocapture.isOpened():
106
- success, frame = videocapture.read()
107
- if not success:
108
- st.warning("Failed to read frame from webcam. Please make sure the webcam is connected properly.")
109
- break
110
-
111
- prev_time = time.time() # Store initial time for FPS calculation
112
-
113
- # Store model predictions
114
- if enable_trk == "Yes":
115
- results = model.track(frame, conf=conf, iou=iou, classes=selected_ind, persist=True)
116
- else:
117
- results = model(frame, conf=conf, iou=iou, classes=selected_ind)
118
- annotated_frame = results[0].plot() # Add annotations on frame
119
-
120
- # Calculate model FPS
121
- curr_time = time.time()
122
- fps = 1 / (curr_time - prev_time)
123
-
124
- # display frame
125
- org_frame.image(frame, channels="BGR")
126
- ann_frame.image(annotated_frame, channels="BGR")
127
-
128
- if stop_button:
129
- videocapture.release() # Release the capture
130
- torch.cuda.empty_cache() # Clear CUDA memory
131
- st.stop() # Stop streamlit app
132
-
133
- # Display FPS in sidebar
134
- fps_display.metric("FPS", f"{fps:.2f}")
135
-
136
- # Release the capture
137
- videocapture.release()
138
-
139
- # Clear CUDA memory
140
- torch.cuda.empty_cache()
184
+ if __name__ == "__main__":
185
+ import sys # Import the sys module for accessing command-line arguments
141
186
 
142
- # Destroy window
143
- cv2.destroyAllWindows()
187
+ model = None # Initialize the model variable as None
144
188
 
189
+ # Check if a model name is provided as a command-line argument
190
+ args = len(sys.argv)
191
+ if args > 1:
192
+ model = sys.argv[1] # Assign the first argument as the model name
145
193
 
146
- # Main function call
147
- if __name__ == "__main__":
148
- inference()
194
+ # Create an instance of the Inference class and run inference
195
+ Inference(model=model).inference()
@@ -440,7 +440,8 @@ class ProfileModels:
440
440
  print(f"Profiling: {sorted(files)}")
441
441
  return [Path(file) for file in sorted(files)]
442
442
 
443
- def get_onnx_model_info(self, onnx_file: str):
443
+ @staticmethod
444
+ def get_onnx_model_info(onnx_file: str):
444
445
  """Extracts metadata from an ONNX model file including parameters, GFLOPs, and input shape."""
445
446
  return 0.0, 0.0, 0.0, 0.0 # return (num_layers, num_params, num_gradients, num_flops)
446
447
 
@@ -138,7 +138,7 @@ def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=Fals
138
138
  If a path is not provided, the function will use the parent directory of the zipfile as the default path.
139
139
 
140
140
  Args:
141
- file (str): The path to the zipfile to be extracted.
141
+ file (str | Path): The path to the zipfile to be extracted.
142
142
  path (str, optional): The path to extract the zipfile to. Defaults to None.
143
143
  exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX').
144
144
  exist_ok (bool, optional): Whether to overwrite existing contents if they exist. Defaults to False.
@@ -28,7 +28,7 @@ to_4tuple = _ntuple(4)
28
28
  # `ltwh` means left top and width, height(COCO format)
29
29
  _formats = ["xyxy", "xywh", "ltwh"]
30
30
 
31
- __all__ = ("Bboxes",) # tuple or list
31
+ __all__ = ("Bboxes", "Instances") # tuple or list
32
32
 
33
33
 
34
34
  class Bboxes:
@@ -372,10 +372,9 @@ class ConfusionMatrix:
372
372
  else:
373
373
  self.matrix[self.nc, gc] += 1 # true background
374
374
 
375
- if n:
376
- for i, dc in enumerate(detection_classes):
377
- if not any(m1 == i):
378
- self.matrix[dc, self.nc] += 1 # predicted background
375
+ for i, dc in enumerate(detection_classes):
376
+ if not any(m1 == i):
377
+ self.matrix[dc, self.nc] += 1 # predicted background
379
378
 
380
379
  def matrix(self):
381
380
  """Returns the confusion matrix."""
@@ -545,7 +545,8 @@ class Annotator:
545
545
  """Save the annotated image to 'filename'."""
546
546
  cv2.imwrite(filename, np.asarray(self.im))
547
547
 
548
- def get_bbox_dimension(self, bbox=None):
548
+ @staticmethod
549
+ def get_bbox_dimension(bbox=None):
549
550
  """
550
551
  Calculate the area of a bounding box.
551
552