ultralytics 8.3.15__py3-none-any.whl → 8.3.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tests/test_solutions.py CHANGED
@@ -17,10 +17,15 @@ def test_major_solutions():
17
17
  cap = cv2.VideoCapture("solutions_ci_demo.mp4")
18
18
  assert cap.isOpened(), "Error reading video file"
19
19
  region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
20
- counter = solutions.ObjectCounter(region=region_points, model="yolo11n.pt", show=False)
21
- heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False)
22
- speed = solutions.SpeedEstimator(region=region_points, model="yolo11n.pt", show=False)
23
- queue = solutions.QueueManager(region=region_points, model="yolo11n.pt", show=False)
20
+ counter = solutions.ObjectCounter(region=region_points, model="yolo11n.pt", show=False) # Test object counter
21
+ heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False) # Test heatmaps
22
+ speed = solutions.SpeedEstimator(region=region_points, model="yolo11n.pt", show=False) # Test speed estimation
23
+ queue = solutions.QueueManager(region=region_points, model="yolo11n.pt", show=False) # Test queue manager
24
+ line_analytics = solutions.Analytics(analytics_type="line", model="yolo11n.pt", show=False) # line analytics
25
+ pie_analytics = solutions.Analytics(analytics_type="pie", model="yolo11n.pt", show=False) # pie analytics
26
+ bar_analytics = solutions.Analytics(analytics_type="bar", model="yolo11n.pt", show=False) # bar analytics
27
+ area_analytics = solutions.Analytics(analytics_type="area", model="yolo11n.pt", show=False) # area analytics
28
+ frame_count = 0 # Required for analytics
24
29
  while cap.isOpened():
25
30
  success, im0 = cap.read()
26
31
  if not success:
@@ -30,24 +35,23 @@ def test_major_solutions():
30
35
  _ = heatmap.generate_heatmap(original_im0.copy())
31
36
  _ = speed.estimate_speed(original_im0.copy())
32
37
  _ = queue.process_queue(original_im0.copy())
38
+ _ = line_analytics.process_data(original_im0.copy(), frame_count)
39
+ _ = pie_analytics.process_data(original_im0.copy(), frame_count)
40
+ _ = bar_analytics.process_data(original_im0.copy(), frame_count)
41
+ _ = area_analytics.process_data(original_im0.copy(), frame_count)
33
42
  cap.release()
34
- cv2.destroyAllWindows()
35
-
36
43
 
37
- @pytest.mark.slow
38
- def test_aigym():
39
- """Test the workouts monitoring solution."""
44
+ # Test workouts monitoring
40
45
  safe_download(url=WORKOUTS_SOLUTION_DEMO)
41
- cap = cv2.VideoCapture("solution_ci_pose_demo.mp4")
42
- assert cap.isOpened(), "Error reading video file"
43
- gym = solutions.AIGym(line_width=2, kpts=[5, 11, 13])
44
- while cap.isOpened():
45
- success, im0 = cap.read()
46
+ cap1 = cv2.VideoCapture("solution_ci_pose_demo.mp4")
47
+ assert cap1.isOpened(), "Error reading video file"
48
+ gym = solutions.AIGym(line_width=2, kpts=[5, 11, 13], show=False)
49
+ while cap1.isOpened():
50
+ success, im0 = cap1.read()
46
51
  if not success:
47
52
  break
48
53
  _ = gym.monitor(im0)
49
- cap.release()
50
- cv2.destroyAllWindows()
54
+ cap1.release()
51
55
 
52
56
 
53
57
  @pytest.mark.slow
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
2
 
3
- __version__ = "8.3.15"
3
+ __version__ = "8.3.16"
4
4
 
5
5
  import os
6
6
 
@@ -15,3 +15,4 @@ down_angle: 90 # Workouts down_angle for counts, 90 is default value. You can ch
15
15
  kpts: [6, 8, 10] # Keypoints for workouts monitoring, i.e. If you want to consider keypoints for pushups that have mostly values of [6, 8, 10].
16
16
  colormap: # Colormap for heatmap, Only OPENCV supported colormaps can be used. By default COLORMAP_PARULA will be used for visualization.
17
17
  analytics_type: "line" # Analytics type i.e "line", "pie", "bar" or "area" charts. By default, "line" analytics will be used for processing.
18
+ json_file: # Parking system regions file path.
@@ -13,9 +13,6 @@ from tqdm import tqdm
13
13
  from ultralytics.data.utils import exif_size, img2label_paths
14
14
  from ultralytics.utils.checks import check_requirements
15
15
 
16
- check_requirements("shapely")
17
- from shapely.geometry import Polygon
18
-
19
16
 
20
17
  def bbox_iof(polygon1, bbox2, eps=1e-6):
21
18
  """
@@ -33,6 +30,9 @@ def bbox_iof(polygon1, bbox2, eps=1e-6):
33
30
  Polygon format: [x1, y1, x2, y2, x3, y3, x4, y4].
34
31
  Bounding box format: [x_min, y_min, x_max, y_max].
35
32
  """
33
+ check_requirements("shapely")
34
+ from shapely.geometry import Polygon
35
+
36
36
  polygon1 = polygon1.reshape(-1, 4, 2)
37
37
  lt_point = np.min(polygon1, axis=-2) # left-top
38
38
  rb_point = np.max(polygon1, axis=-2) # right-bottom
@@ -1,16 +1,40 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
2
 
3
- from ultralytics.solutions.solutions import BaseSolution # Import a parent class
3
+ from ultralytics.solutions.solutions import BaseSolution
4
4
  from ultralytics.utils.plotting import Annotator
5
5
 
6
6
 
7
7
  class AIGym(BaseSolution):
8
- """A class to manage the gym steps of people in a real-time video stream based on their poses."""
8
+ """
9
+ A class to manage gym steps of people in a real-time video stream based on their poses.
10
+
11
+ This class extends BaseSolution to monitor workouts using YOLO pose estimation models. It tracks and counts
12
+ repetitions of exercises based on predefined angle thresholds for up and down positions.
13
+
14
+ Attributes:
15
+ count (List[int]): Repetition counts for each detected person.
16
+ angle (List[float]): Current angle of the tracked body part for each person.
17
+ stage (List[str]): Current exercise stage ('up', 'down', or '-') for each person.
18
+ initial_stage (str | None): Initial stage of the exercise.
19
+ up_angle (float): Angle threshold for considering the 'up' position of an exercise.
20
+ down_angle (float): Angle threshold for considering the 'down' position of an exercise.
21
+ kpts (List[int]): Indices of keypoints used for angle calculation.
22
+ lw (int): Line width for drawing annotations.
23
+ annotator (Annotator): Object for drawing annotations on the image.
24
+
25
+ Methods:
26
+ monitor: Processes a frame to detect poses, calculate angles, and count repetitions.
27
+
28
+ Examples:
29
+ >>> gym = AIGym(model="yolov8n-pose.pt")
30
+ >>> image = cv2.imread("gym_scene.jpg")
31
+ >>> processed_image = gym.monitor(image)
32
+ >>> cv2.imshow("Processed Image", processed_image)
33
+ >>> cv2.waitKey(0)
34
+ """
9
35
 
10
36
  def __init__(self, **kwargs):
11
- """Initialization function for AiGYM class, a child class of BaseSolution class, can be used for workouts
12
- monitoring.
13
- """
37
+ """Initializes AIGym for workout monitoring using pose estimation and predefined angles."""
14
38
  # Check if the model name ends with '-pose'
15
39
  if "model" in kwargs and "-pose" not in kwargs["model"]:
16
40
  kwargs["model"] = "yolo11n-pose.pt"
@@ -31,12 +55,22 @@ class AIGym(BaseSolution):
31
55
 
32
56
  def monitor(self, im0):
33
57
  """
34
- Monitor the workouts using Ultralytics YOLO Pose Model: https://docs.ultralytics.com/tasks/pose/.
58
+ Monitors workouts using Ultralytics YOLO Pose Model.
59
+
60
+ This function processes an input image to track and analyze human poses for workout monitoring. It uses
61
+ the YOLO Pose model to detect keypoints, estimate angles, and count repetitions based on predefined
62
+ angle thresholds.
35
63
 
36
64
  Args:
37
- im0 (ndarray): The input image that will be used for processing
38
- Returns
39
- im0 (ndarray): The processed image for more usage
65
+ im0 (ndarray): Input image for processing.
66
+
67
+ Returns:
68
+ (ndarray): Processed image with annotations for workout monitoring.
69
+
70
+ Examples:
71
+ >>> gym = AIGym()
72
+ >>> image = cv2.imread("workout.jpg")
73
+ >>> processed_image = gym.monitor(image)
40
74
  """
41
75
  # Extract tracks
42
76
  tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"])[0]
@@ -12,10 +12,41 @@ from ultralytics.solutions.solutions import BaseSolution # Import a parent clas
12
12
 
13
13
 
14
14
  class Analytics(BaseSolution):
15
- """A class to create and update various types of charts (line, bar, pie, area) for visual analytics."""
15
+ """
16
+ A class for creating and updating various types of charts for visual analytics.
17
+
18
+ This class extends BaseSolution to provide functionality for generating line, bar, pie, and area charts
19
+ based on object detection and tracking data.
20
+
21
+ Attributes:
22
+ type (str): The type of analytics chart to generate ('line', 'bar', 'pie', or 'area').
23
+ x_label (str): Label for the x-axis.
24
+ y_label (str): Label for the y-axis.
25
+ bg_color (str): Background color of the chart frame.
26
+ fg_color (str): Foreground color of the chart frame.
27
+ title (str): Title of the chart window.
28
+ max_points (int): Maximum number of data points to display on the chart.
29
+ fontsize (int): Font size for text display.
30
+ color_cycle (cycle): Cyclic iterator for chart colors.
31
+ total_counts (int): Total count of detected objects (used for line charts).
32
+ clswise_count (Dict[str, int]): Dictionary for class-wise object counts.
33
+ fig (Figure): Matplotlib figure object for the chart.
34
+ ax (Axes): Matplotlib axes object for the chart.
35
+ canvas (FigureCanvas): Canvas for rendering the chart.
36
+
37
+ Methods:
38
+ process_data: Processes image data and updates the chart.
39
+ update_graph: Updates the chart with new data points.
40
+
41
+ Examples:
42
+ >>> analytics = Analytics(analytics_type="line")
43
+ >>> frame = cv2.imread("image.jpg")
44
+ >>> processed_frame = analytics.process_data(frame, frame_number=1)
45
+ >>> cv2.imshow("Analytics", processed_frame)
46
+ """
16
47
 
17
48
  def __init__(self, **kwargs):
18
- """Initialize the Analytics class with various chart types."""
49
+ """Initialize Analytics class with various chart types for visual data representation."""
19
50
  super().__init__(**kwargs)
20
51
 
21
52
  self.type = self.CFG["analytics_type"] # extract type of analytics
@@ -31,8 +62,8 @@ class Analytics(BaseSolution):
31
62
  figsize = (19.2, 10.8) # Set output image size 1920 * 1080
32
63
  self.color_cycle = cycle(["#DD00BA", "#042AFF", "#FF4447", "#7D24FF", "#BD00FF"])
33
64
 
34
- self.total_counts = 0 # count variable for storing total counts i.e for line
35
- self.clswise_count = {} # dictionary for classwise counts
65
+ self.total_counts = 0 # count variable for storing total counts i.e. for line
66
+ self.clswise_count = {} # dictionary for class-wise counts
36
67
 
37
68
  # Ensure line and area chart
38
69
  if self.type in {"line", "area"}:
@@ -48,15 +79,28 @@ class Analytics(BaseSolution):
48
79
  self.canvas = FigureCanvas(self.fig) # Set common axis properties
49
80
  self.ax.set_facecolor(self.bg_color)
50
81
  self.color_mapping = {}
51
- self.ax.axis("equal") if self.type == "pie" else None # Ensure pie chart is circular
82
+
83
+ if self.type == "pie": # Ensure pie chart is circular
84
+ self.ax.axis("equal")
52
85
 
53
86
  def process_data(self, im0, frame_number):
54
87
  """
55
- Process the image data, run object tracking.
88
+ Processes image data and runs object tracking to update analytics charts.
56
89
 
57
90
  Args:
58
- im0 (ndarray): Input image for processing.
59
- frame_number (int): Video frame # for plotting the data.
91
+ im0 (np.ndarray): Input image for processing.
92
+ frame_number (int): Video frame number for plotting the data.
93
+
94
+ Returns:
95
+ (np.ndarray): Processed image with updated analytics chart.
96
+
97
+ Raises:
98
+ ModuleNotFoundError: If an unsupported chart type is specified.
99
+
100
+ Examples:
101
+ >>> analytics = Analytics(analytics_type="line")
102
+ >>> frame = np.zeros((480, 640, 3), dtype=np.uint8)
103
+ >>> processed_frame = analytics.process_data(frame, frame_number=1)
60
104
  """
61
105
  self.extract_tracks(im0) # Extract tracks
62
106
 
@@ -79,13 +123,22 @@ class Analytics(BaseSolution):
79
123
 
80
124
  def update_graph(self, frame_number, count_dict=None, plot="line"):
81
125
  """
82
- Update the graph (line or area) with new data for single or multiple classes.
126
+ Updates the graph with new data for single or multiple classes.
83
127
 
84
128
  Args:
85
129
  frame_number (int): The current frame number.
86
- count_dict (dict, optional): Dictionary with class names as keys and counts as values for multiple classes.
87
- If None, updates a single line graph.
88
- plot (str): Type of the plot i.e. line, bar or area.
130
+ count_dict (Dict[str, int] | None): Dictionary with class names as keys and counts as values for multiple
131
+ classes. If None, updates a single line graph.
132
+ plot (str): Type of the plot. Options are 'line', 'bar', 'pie', or 'area'.
133
+
134
+ Returns:
135
+ (np.ndarray): Updated image containing the graph.
136
+
137
+ Examples:
138
+ >>> analytics = Analytics()
139
+ >>> frame_number = 10
140
+ >>> count_dict = {"person": 5, "car": 3}
141
+ >>> updated_image = analytics.update_graph(frame_number, count_dict, plot="bar")
89
142
  """
90
143
  if count_dict is None:
91
144
  # Single line update
@@ -4,15 +4,41 @@ import math
4
4
 
5
5
  import cv2
6
6
 
7
- from ultralytics.solutions.solutions import BaseSolution # Import a parent class
7
+ from ultralytics.solutions.solutions import BaseSolution
8
8
  from ultralytics.utils.plotting import Annotator, colors
9
9
 
10
10
 
11
11
  class DistanceCalculation(BaseSolution):
12
- """A class to calculate distance between two objects in a real-time video stream based on their tracks."""
12
+ """
13
+ A class to calculate distance between two objects in a real-time video stream based on their tracks.
14
+
15
+ This class extends BaseSolution to provide functionality for selecting objects and calculating the distance
16
+ between them in a video stream using YOLO object detection and tracking.
17
+
18
+ Attributes:
19
+ left_mouse_count (int): Counter for left mouse button clicks.
20
+ selected_boxes (Dict[int, List[float]]): Dictionary to store selected bounding boxes and their track IDs.
21
+ annotator (Annotator): An instance of the Annotator class for drawing on the image.
22
+ boxes (List[List[float]]): List of bounding boxes for detected objects.
23
+ track_ids (List[int]): List of track IDs for detected objects.
24
+ clss (List[int]): List of class indices for detected objects.
25
+ names (List[str]): List of class names that the model can detect.
26
+ centroids (List[List[int]]): List to store centroids of selected bounding boxes.
27
+
28
+ Methods:
29
+ mouse_event_for_distance: Handles mouse events for selecting objects in the video stream.
30
+ calculate: Processes video frames and calculates the distance between selected objects.
31
+
32
+ Examples:
33
+ >>> distance_calc = DistanceCalculation()
34
+ >>> frame = cv2.imread("frame.jpg")
35
+ >>> processed_frame = distance_calc.calculate(frame)
36
+ >>> cv2.imshow("Distance Calculation", processed_frame)
37
+ >>> cv2.waitKey(0)
38
+ """
13
39
 
14
40
  def __init__(self, **kwargs):
15
- """Initializes the DistanceCalculation class with the given parameters."""
41
+ """Initializes the DistanceCalculation class for measuring object distances in video streams."""
16
42
  super().__init__(**kwargs)
17
43
 
18
44
  # Mouse event information
@@ -21,14 +47,18 @@ class DistanceCalculation(BaseSolution):
21
47
 
22
48
  def mouse_event_for_distance(self, event, x, y, flags, param):
23
49
  """
24
- Handles mouse events to select regions in a real-time video stream.
50
+ Handles mouse events to select regions in a real-time video stream for distance calculation.
25
51
 
26
52
  Args:
27
- event (int): Type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN, etc.).
53
+ event (int): Type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN).
28
54
  x (int): X-coordinate of the mouse pointer.
29
55
  y (int): Y-coordinate of the mouse pointer.
30
- flags (int): Flags associated with the event (e.g., cv2.EVENT_FLAG_CTRLKEY, cv2.EVENT_FLAG_SHIFTKEY, etc.).
31
- param (dict): Additional parameters passed to the function.
56
+ flags (int): Flags associated with the event (e.g., cv2.EVENT_FLAG_CTRLKEY, cv2.EVENT_FLAG_SHIFTKEY).
57
+ param (Dict): Additional parameters passed to the function.
58
+
59
+ Examples:
60
+ >>> # Assuming 'dc' is an instance of DistanceCalculation
61
+ >>> cv2.setMouseCallback("window_name", dc.mouse_event_for_distance)
32
62
  """
33
63
  if event == cv2.EVENT_LBUTTONDOWN:
34
64
  self.left_mouse_count += 1
@@ -43,13 +73,23 @@ class DistanceCalculation(BaseSolution):
43
73
 
44
74
  def calculate(self, im0):
45
75
  """
46
- Processes the video frame and calculates the distance between two bounding boxes.
76
+ Processes a video frame and calculates the distance between two selected bounding boxes.
77
+
78
+ This method extracts tracks from the input frame, annotates bounding boxes, and calculates the distance
79
+ between two user-selected objects if they have been chosen.
47
80
 
48
81
  Args:
49
- im0 (ndarray): The image frame.
82
+ im0 (numpy.ndarray): The input image frame to process.
50
83
 
51
84
  Returns:
52
- (ndarray): The processed image frame.
85
+ (numpy.ndarray): The processed image frame with annotations and distance calculations.
86
+
87
+ Examples:
88
+ >>> import numpy as np
89
+ >>> from ultralytics.solutions import DistanceCalculation
90
+ >>> dc = DistanceCalculation()
91
+ >>> frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
92
+ >>> processed_frame = dc.calculate(frame)
53
93
  """
54
94
  self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator
55
95
  self.extract_tracks(im0) # Extract tracks
@@ -3,15 +3,40 @@
3
3
  import cv2
4
4
  import numpy as np
5
5
 
6
- from ultralytics.solutions.object_counter import ObjectCounter # Import object counter class
6
+ from ultralytics.solutions.object_counter import ObjectCounter
7
7
  from ultralytics.utils.plotting import Annotator
8
8
 
9
9
 
10
10
  class Heatmap(ObjectCounter):
11
- """A class to draw heatmaps in real-time video stream based on their tracks."""
11
+ """
12
+ A class to draw heatmaps in real-time video streams based on object tracks.
13
+
14
+ This class extends the ObjectCounter class to generate and visualize heatmaps of object movements in video
15
+ streams. It uses tracked object positions to create a cumulative heatmap effect over time.
16
+
17
+ Attributes:
18
+ initialized (bool): Flag indicating whether the heatmap has been initialized.
19
+ colormap (int): OpenCV colormap used for heatmap visualization.
20
+ heatmap (np.ndarray): Array storing the cumulative heatmap data.
21
+ annotator (Annotator): Object for drawing annotations on the image.
22
+
23
+ Methods:
24
+ heatmap_effect: Calculates and updates the heatmap effect for a given bounding box.
25
+ generate_heatmap: Generates and applies the heatmap effect to each frame.
26
+
27
+ Examples:
28
+ >>> from ultralytics.solutions import Heatmap
29
+ >>> heatmap = Heatmap(model="yolov8n.pt", colormap=cv2.COLORMAP_JET)
30
+ >>> results = heatmap("path/to/video.mp4")
31
+ >>> for result in results:
32
+ ... print(result.speed) # Print inference speed
33
+ ... cv2.imshow("Heatmap", result.plot())
34
+ ... if cv2.waitKey(1) & 0xFF == ord("q"):
35
+ ... break
36
+ """
12
37
 
13
38
  def __init__(self, **kwargs):
14
- """Initializes function for heatmap class with default values."""
39
+ """Initializes the Heatmap class for real-time video stream heatmap generation based on object tracks."""
15
40
  super().__init__(**kwargs)
16
41
 
17
42
  self.initialized = False # bool variable for heatmap initialization
@@ -23,10 +48,15 @@ class Heatmap(ObjectCounter):
23
48
 
24
49
  def heatmap_effect(self, box):
25
50
  """
26
- Efficient calculation of heatmap area and effect location for applying colormap.
51
+ Efficiently calculates heatmap area and effect location for applying colormap.
27
52
 
28
53
  Args:
29
- box (list): Bounding Box coordinates data [x0, y0, x1, y1]
54
+ box (List[float]): Bounding box coordinates [x0, y0, x1, y1].
55
+
56
+ Examples:
57
+ >>> heatmap = Heatmap()
58
+ >>> box = [100, 100, 200, 200]
59
+ >>> heatmap.heatmap_effect(box)
30
60
  """
31
61
  x0, y0, x1, y1 = map(int, box)
32
62
  radius_squared = (min(x1 - x0, y1 - y0) // 2) ** 2
@@ -48,9 +78,15 @@ class Heatmap(ObjectCounter):
48
78
  Generate heatmap for each frame using Ultralytics.
49
79
 
50
80
  Args:
51
- im0 (ndarray): Input image array for processing
81
+ im0 (np.ndarray): Input image array for processing.
82
+
52
83
  Returns:
53
- im0 (ndarray): Processed image for further usage
84
+ (np.ndarray): Processed image with heatmap overlay and object counts (if region is specified).
85
+
86
+ Examples:
87
+ >>> heatmap = Heatmap()
88
+ >>> im0 = cv2.imread("image.jpg")
89
+ >>> result = heatmap.generate_heatmap(im0)
54
90
  """
55
91
  if not self.initialized:
56
92
  self.heatmap = np.zeros_like(im0, dtype=np.float32) * 0.99
@@ -70,16 +106,17 @@ class Heatmap(ObjectCounter):
70
106
  self.store_classwise_counts(cls) # store classwise counts in dict
71
107
 
72
108
  # Store tracking previous position and perform object counting
73
- prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
109
+ prev_position = None
110
+ if len(self.track_history[track_id]) > 1:
111
+ prev_position = self.track_history[track_id][-2]
74
112
  self.count_objects(self.track_line, box, track_id, prev_position, cls) # Perform object counting
75
113
 
76
- self.display_counts(im0) if self.region is not None else None # Display the counts on the frame
114
+ if self.region is not None:
115
+ self.display_counts(im0) # Display the counts on the frame
77
116
 
78
117
  # Normalize, apply colormap to heatmap and combine with original image
79
- im0 = (
80
- im0
81
- if self.track_data.id is None
82
- else cv2.addWeighted(
118
+ if self.track_data.id is not None:
119
+ im0 = cv2.addWeighted(
83
120
  im0,
84
121
  0.5,
85
122
  cv2.applyColorMap(
@@ -88,7 +125,6 @@ class Heatmap(ObjectCounter):
88
125
  0.5,
89
126
  0,
90
127
  )
91
- )
92
128
 
93
129
  self.display_output(im0) # display output with base class function
94
130
  return im0 # return output image for more usage
@@ -1,18 +1,40 @@
1
1
  # Ultralytics YOLO 🚀, AGPL-3.0 license
2
2
 
3
- from shapely.geometry import LineString, Point
4
-
5
- from ultralytics.solutions.solutions import BaseSolution # Import a parent class
3
+ from ultralytics.solutions.solutions import BaseSolution
6
4
  from ultralytics.utils.plotting import Annotator, colors
7
5
 
8
6
 
9
7
  class ObjectCounter(BaseSolution):
10
- """A class to manage the counting of objects in a real-time video stream based on their tracks."""
8
+ """
9
+ A class to manage the counting of objects in a real-time video stream based on their tracks.
10
+
11
+ This class extends the BaseSolution class and provides functionality for counting objects moving in and out of a
12
+ specified region in a video stream. It supports both polygonal and linear regions for counting.
13
+
14
+ Attributes:
15
+ in_count (int): Counter for objects moving inward.
16
+ out_count (int): Counter for objects moving outward.
17
+ counted_ids (List[int]): List of IDs of objects that have been counted.
18
+ classwise_counts (Dict[str, Dict[str, int]]): Dictionary for counts, categorized by object class.
19
+ region_initialized (bool): Flag indicating whether the counting region has been initialized.
20
+ show_in (bool): Flag to control display of inward count.
21
+ show_out (bool): Flag to control display of outward count.
22
+
23
+ Methods:
24
+ count_objects: Counts objects within a polygonal or linear region.
25
+ store_classwise_counts: Initializes class-wise counts if not already present.
26
+ display_counts: Displays object counts on the frame.
27
+ count: Processes input data (frames or object tracks) and updates counts.
28
+
29
+ Examples:
30
+ >>> counter = ObjectCounter()
31
+ >>> frame = cv2.imread("frame.jpg")
32
+ >>> processed_frame = counter.count(frame)
33
+ >>> print(f"Inward count: {counter.in_count}, Outward count: {counter.out_count}")
34
+ """
11
35
 
12
36
  def __init__(self, **kwargs):
13
- """Initialization function for Count class, a child class of BaseSolution class, can be used for counting the
14
- objects.
15
- """
37
+ """Initializes the ObjectCounter class for real-time object counting in video streams."""
16
38
  super().__init__(**kwargs)
17
39
 
18
40
  self.in_count = 0 # Counter for objects moving inward
@@ -26,14 +48,23 @@ class ObjectCounter(BaseSolution):
26
48
 
27
49
  def count_objects(self, track_line, box, track_id, prev_position, cls):
28
50
  """
29
- Helper function to count objects within a polygonal region.
51
+ Counts objects within a polygonal or linear region based on their tracks.
30
52
 
31
53
  Args:
32
- track_line (dict): last 30 frame track record
33
- box (list): Bounding box data for specific track in current frame
34
- track_id (int): track ID of the object
35
- prev_position (tuple): last frame position coordinates of the track
36
- cls (int): Class index for classwise count updates
54
+ track_line (Dict): Last 30 frame track record for the object.
55
+ box (List[float]): Bounding box coordinates [x1, y1, x2, y2] for the specific track in the current frame.
56
+ track_id (int): Unique identifier for the tracked object.
57
+ prev_position (Tuple[float, float]): Last frame position coordinates (x, y) of the track.
58
+ cls (int): Class index for classwise count updates.
59
+
60
+ Examples:
61
+ >>> counter = ObjectCounter()
62
+ >>> track_line = {1: [100, 200], 2: [110, 210], 3: [120, 220]}
63
+ >>> box = [130, 230, 150, 250]
64
+ >>> track_id = 1
65
+ >>> prev_position = (120, 220)
66
+ >>> cls = 0
67
+ >>> counter.count_objects(track_line, box, track_id, prev_position, cls)
37
68
  """
38
69
  if prev_position is None or track_id in self.counted_ids:
39
70
  return
@@ -42,7 +73,7 @@ class ObjectCounter(BaseSolution):
42
73
  dx = (box[0] - prev_position[0]) * (centroid.x - prev_position[0])
43
74
  dy = (box[1] - prev_position[1]) * (centroid.y - prev_position[1])
44
75
 
45
- if len(self.region) >= 3 and self.r_s.contains(Point(track_line[-1])):
76
+ if len(self.region) >= 3 and self.r_s.contains(self.Point(track_line[-1])):
46
77
  self.counted_ids.append(track_id)
47
78
  # For polygon region
48
79
  if dx > 0:
@@ -52,7 +83,7 @@ class ObjectCounter(BaseSolution):
52
83
  self.out_count += 1
53
84
  self.classwise_counts[self.names[cls]]["OUT"] += 1
54
85
 
55
- elif len(self.region) < 3 and LineString([prev_position, box[:2]]).intersects(self.l_s):
86
+ elif len(self.region) < 3 and self.LineString([prev_position, box[:2]]).intersects(self.r_s):
56
87
  self.counted_ids.append(track_id)
57
88
  # For linear region
58
89
  if dx > 0 and dy > 0:
@@ -64,20 +95,34 @@ class ObjectCounter(BaseSolution):
64
95
 
65
96
  def store_classwise_counts(self, cls):
66
97
  """
67
- Initialize class-wise counts if not already present.
98
+ Initialize class-wise counts for a specific object class if not already present.
68
99
 
69
100
  Args:
70
- cls (int): Class index for classwise count updates
101
+ cls (int): Class index for classwise count updates.
102
+
103
+ This method ensures that the 'classwise_counts' dictionary contains an entry for the specified class,
104
+ initializing 'IN' and 'OUT' counts to zero if the class is not already present.
105
+
106
+ Examples:
107
+ >>> counter = ObjectCounter()
108
+ >>> counter.store_classwise_counts(0) # Initialize counts for class index 0
109
+ >>> print(counter.classwise_counts)
110
+ {'person': {'IN': 0, 'OUT': 0}}
71
111
  """
72
112
  if self.names[cls] not in self.classwise_counts:
73
113
  self.classwise_counts[self.names[cls]] = {"IN": 0, "OUT": 0}
74
114
 
75
115
  def display_counts(self, im0):
76
116
  """
77
- Helper function to display object counts on the frame.
117
+ Displays object counts on the input image or frame.
78
118
 
79
119
  Args:
80
- im0 (ndarray): The input image or frame
120
+ im0 (numpy.ndarray): The input image or frame to display counts on.
121
+
122
+ Examples:
123
+ >>> counter = ObjectCounter()
124
+ >>> frame = cv2.imread("image.jpg")
125
+ >>> counter.display_counts(frame)
81
126
  """
82
127
  labels_dict = {
83
128
  str.capitalize(key): f"{'IN ' + str(value['IN']) if self.show_in else ''} "
@@ -91,12 +136,21 @@ class ObjectCounter(BaseSolution):
91
136
 
92
137
  def count(self, im0):
93
138
  """
94
- Processes input data (frames or object tracks) and updates counts.
139
+ Processes input data (frames or object tracks) and updates object counts.
140
+
141
+ This method initializes the counting region, extracts tracks, draws bounding boxes and regions, updates
142
+ object counts, and displays the results on the input image.
95
143
 
96
144
  Args:
97
- im0 (ndarray): The input image that will be used for processing
98
- Returns
99
- im0 (ndarray): The processed image for more usage
145
+ im0 (numpy.ndarray): The input image or frame to be processed.
146
+
147
+ Returns:
148
+ (numpy.ndarray): The processed image with annotations and count information.
149
+
150
+ Examples:
151
+ >>> counter = ObjectCounter()
152
+ >>> frame = cv2.imread("path/to/image.jpg")
153
+ >>> processed_frame = counter.count(frame)
100
154
  """
101
155
  if not self.region_initialized:
102
156
  self.initialize_region()
@@ -122,7 +176,9 @@ class ObjectCounter(BaseSolution):
122
176
  )
123
177
 
124
178
  # store previous position of track for object counting
125
- prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
179
+ prev_position = None
180
+ if len(self.track_history[track_id]) > 1:
181
+ prev_position = self.track_history[track_id][-2]
126
182
  self.count_objects(self.track_line, box, track_id, prev_position, cls) # Perform object counting
127
183
 
128
184
  self.display_counts(im0) # Display the counts on the frame