ultralytics 8.3.15__py3-none-any.whl → 8.3.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tests/test_solutions.py CHANGED
@@ -17,10 +17,15 @@ def test_major_solutions():
  cap = cv2.VideoCapture("solutions_ci_demo.mp4")
  assert cap.isOpened(), "Error reading video file"
  region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
- counter = solutions.ObjectCounter(region=region_points, model="yolo11n.pt", show=False)
- heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False)
- speed = solutions.SpeedEstimator(region=region_points, model="yolo11n.pt", show=False)
- queue = solutions.QueueManager(region=region_points, model="yolo11n.pt", show=False)
+ counter = solutions.ObjectCounter(region=region_points, model="yolo11n.pt", show=False) # Test object counter
+ heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False) # Test heatmaps
+ speed = solutions.SpeedEstimator(region=region_points, model="yolo11n.pt", show=False) # Test speed estimation
+ queue = solutions.QueueManager(region=region_points, model="yolo11n.pt", show=False) # Test queue manager
+ line_analytics = solutions.Analytics(analytics_type="line", model="yolo11n.pt", show=False) # Test line analytics
+ pie_analytics = solutions.Analytics(analytics_type="pie", model="yolo11n.pt", show=False) # Test pie analytics
+ bar_analytics = solutions.Analytics(analytics_type="bar", model="yolo11n.pt", show=False) # Test bar analytics
+ area_analytics = solutions.Analytics(analytics_type="area", model="yolo11n.pt", show=False) # Test area analytics
+ frame_count = 0 # Required for analytics
  while cap.isOpened():
  success, im0 = cap.read()
  if not success:
@@ -30,24 +35,23 @@ def test_major_solutions():
  _ = heatmap.generate_heatmap(original_im0.copy())
  _ = speed.estimate_speed(original_im0.copy())
  _ = queue.process_queue(original_im0.copy())
+ _ = line_analytics.process_data(original_im0.copy(), frame_count)
+ _ = pie_analytics.process_data(original_im0.copy(), frame_count)
+ _ = bar_analytics.process_data(original_im0.copy(), frame_count)
+ _ = area_analytics.process_data(original_im0.copy(), frame_count)
  cap.release()
- cv2.destroyAllWindows()
-
 
- @pytest.mark.slow
- def test_aigym():
- """Test the workouts monitoring solution."""
+ # Test workouts monitoring
  safe_download(url=WORKOUTS_SOLUTION_DEMO)
- cap = cv2.VideoCapture("solution_ci_pose_demo.mp4")
- assert cap.isOpened(), "Error reading video file"
- gym = solutions.AIGym(line_width=2, kpts=[5, 11, 13])
- while cap.isOpened():
- success, im0 = cap.read()
+ cap1 = cv2.VideoCapture("solution_ci_pose_demo.mp4")
+ assert cap1.isOpened(), "Error reading video file"
+ gym = solutions.AIGym(line_width=2, kpts=[5, 11, 13], show=False)
+ while cap1.isOpened():
+ success, im0 = cap1.read()
  if not success:
  break
  _ = gym.monitor(im0)
- cap.release()
- cv2.destroyAllWindows()
+ cap1.release()
 
 
 @pytest.mark.slow
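
For context on the pattern this test now exercises: the analytics charts plot counts against a frame index, so `process_data` takes a frame number alongside the image. A minimal standalone sketch of that usage (assumes local `video.mp4` and `yolo11n.pt` files; the increment inside the loop is illustrative, not part of the diff):

    import cv2
    from ultralytics import solutions

    analytics = solutions.Analytics(analytics_type="line", model="yolo11n.pt", show=False)
    cap = cv2.VideoCapture("video.mp4")
    frame_count = 0  # charts are plotted against this index
    while cap.isOpened():
        success, im0 = cap.read()
        if not success:
            break
        frame_count += 1
        _ = analytics.process_data(im0, frame_count)  # returns the rendered chart frame
    cap.release()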
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license
 
- __version__ = "8.3.15"
+ __version__ = "8.3.17"
 
  import os
 
ultralytics/cfg/__init__.py CHANGED
@@ -438,34 +438,60 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None):
 
 def merge_equals_args(args: List[str]) -> List[str]:
  """
- Merges arguments around isolated '=' in a list of strings, handling three cases:
- 1. ['arg', '=', 'val'] becomes ['arg=val'],
- 2. ['arg=', 'val'] becomes ['arg=val'],
- 3. ['arg', '=val'] becomes ['arg=val'].
+ Merges arguments around isolated '=' in a list of strings and joins fragments with brackets.
+
+ This function handles the following cases:
+ 1. ['arg', '=', 'val'] becomes ['arg=val']
+ 2. ['arg=', 'val'] becomes ['arg=val']
+ 3. ['arg', '=val'] becomes ['arg=val']
+ 4. Joins fragments with brackets, e.g., ['imgsz=[3,', '640,', '640]'] becomes ['imgsz=[3,640,640]']
 
  Args:
- args (List[str]): A list of strings where each element represents an argument.
+ args (List[str]): A list of strings where each element represents an argument or fragment.
 
  Returns:
- (List[str]): A list of strings where the arguments around isolated '=' are merged.
+ (List[str]): A list of strings where the arguments around isolated '=' are merged and fragments with brackets are joined.
 
  Examples:
- >>> args = ["arg1", "=", "value", "arg2=", "value2", "arg3", "=value3"]
- >>> merge_equals_args(args)
- ['arg1=value', 'arg2=value2', 'arg3=value3']
+ >>> args = ["arg1", "=", "value", "arg2=", "value2", "arg3", "=value3", "imgsz=[3,", "640,", "640]"]
+ >>> merge_equals_args(args)
+ ['arg1=value', 'arg2=value2', 'arg3=value3', 'imgsz=[3,640,640]']
  """
  new_args = []
- for i, arg in enumerate(args):
+ current = ""
+ depth = 0
+
+ i = 0
+ while i < len(args):
+ arg = args[i]
+
+ # Handle equals sign merging
  if arg == "=" and 0 < i < len(args) - 1: # merge ['arg', '=', 'val']
  new_args[-1] += f"={args[i + 1]}"
- del args[i + 1]
+ i += 2
+ continue
  elif arg.endswith("=") and i < len(args) - 1 and "=" not in args[i + 1]: # merge ['arg=', 'val']
  new_args.append(f"{arg}{args[i + 1]}")
- del args[i + 1]
+ i += 2
+ continue
  elif arg.startswith("=") and i > 0: # merge ['arg', '=val']
  new_args[-1] += arg
- else:
- new_args.append(arg)
+ i += 1
+ continue
+
+ # Handle bracket joining
+ depth += arg.count("[") - arg.count("]")
+ current += arg
+ if depth == 0:
+ new_args.append(current)
+ current = ""
+
+ i += 1
+
+ # Append any remaining current string
+ if current:
+ new_args.append(current)
+
  return new_args
 
 
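A quick illustration of the rewritten function's bracket joining, runnable against the code above (assuming the public import path `ultralytics.cfg`; the `argv` values are illustrative):

    from ultralytics.cfg import merge_equals_args

    # Shell splitting can fragment list-valued CLI args; the rewrite reassembles them
    argv = ["model=yolo11n.pt", "imgsz=[3,", "640,", "640]", "conf", "=", "0.25"]
    print(merge_equals_args(argv))
    # ['model=yolo11n.pt', 'imgsz=[3,640,640]', 'conf=0.25']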
ultralytics/cfg/solutions/default.yaml CHANGED
@@ -15,3 +15,4 @@ down_angle: 90 # Workouts down_angle for counts, 90 is default value. You can ch
  kpts: [6, 8, 10] # Keypoints for workouts monitoring, i.e. If you want to consider keypoints for pushups that have mostly values of [6, 8, 10].
  colormap: # Colormap for heatmap, Only OPENCV supported colormaps can be used. By default COLORMAP_PARULA will be used for visualization.
  analytics_type: "line" # Analytics type i.e "line", "pie", "bar" or "area" charts. By default, "line" analytics will be used for processing.
+ json_file: # parking system regions file path.
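
The new `json_file` key backs the parking-management solution, which loads its region polygons from a JSON file. A hedged usage sketch (file names are placeholders, and the `process_data` method name is assumed to mirror the other 8.3 solutions rather than confirmed by this diff):

    import cv2
    from ultralytics import solutions

    # json_file points at region polygons, e.g. as drawn with the parking annotator tool
    parking = solutions.ParkingManagement(model="yolo11n.pt", json_file="bounding_boxes.json")
    frame = cv2.imread("parking_lot.jpg")
    result = parking.process_data(frame)  # annotates occupied/free slots on the frame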
ultralytics/data/split_dota.py CHANGED
@@ -13,9 +13,6 @@ from tqdm import tqdm
  from ultralytics.data.utils import exif_size, img2label_paths
  from ultralytics.utils.checks import check_requirements
 
- check_requirements("shapely")
- from shapely.geometry import Polygon
-
 
 def bbox_iof(polygon1, bbox2, eps=1e-6):
  """
@@ -33,6 +30,9 @@ def bbox_iof(polygon1, bbox2, eps=1e-6):
  Polygon format: [x1, y1, x2, y2, x3, y3, x4, y4].
  Bounding box format: [x_min, y_min, x_max, y_max].
  """
+ check_requirements("shapely")
+ from shapely.geometry import Polygon
+
  polygon1 = polygon1.reshape(-1, 4, 2)
  lt_point = np.min(polygon1, axis=-2) # left-top
  rb_point = np.max(polygon1, axis=-2) # right-bottom
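
The effect of this change is a deferred import: `ultralytics.data.split_dota` can now be imported without shapely installed, and the requirement is only checked (and auto-installed) when `bbox_iof` is actually called. The same pattern in isolation, as a sketch:

    def polygon_area(points):
        """Return a polygon's area, importing the optional dependency only on first call."""
        from shapely.geometry import Polygon  # deferred: keeps module import dependency-free

        return Polygon(points).area

    print(polygon_area([(0, 0), (4, 0), (4, 3)]))  # 6.0 — shapely is needed only here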
ultralytics/nn/modules/head.py CHANGED
@@ -28,6 +28,7 @@ class Detect(nn.Module):
  shape = None
  anchors = torch.empty(0) # init
  strides = torch.empty(0) # init
+ legacy = False # backward compatibility for v3/v5/v8/v9 models
 
  def __init__(self, nc=80, ch=()):
  """Initializes the YOLO detection layer with specified number of classes and channels."""
@@ -41,13 +42,17 @@ class Detect(nn.Module):
  self.cv2 = nn.ModuleList(
  nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
  )
- self.cv3 = nn.ModuleList(
- nn.Sequential(
- nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
- nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
- nn.Conv2d(c3, self.nc, 1),
+ self.cv3 = (
+ nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
+ if self.legacy
+ else nn.ModuleList(
+ nn.Sequential(
+ nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
+ nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
+ nn.Conv2d(c3, self.nc, 1),
+ )
+ for x in ch
  )
- for x in ch
  )
  self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
 
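The key mechanic here is that `legacy` is a class attribute read at construction time: `parse_model` (next section) flips it on the class before instantiating the head, so v3/v5/v8/v9 graphs keep the original plain-conv classification branch while YOLO11 models get the lighter depthwise-separable one. A simplified standalone sketch of the pattern (`TinyHead` is hypothetical, not the actual `Detect` module):

    import torch.nn as nn

    class TinyHead(nn.Module):
        legacy = False  # class-level flag, set before instantiation (mirrors Detect.legacy)

        def __init__(self, c_in=64, nc=80):
            super().__init__()
            self.cls = (
                nn.Conv2d(c_in, nc, 1)  # legacy: plain conv branch
                if self.legacy
                else nn.Sequential(  # new: depthwise conv followed by pointwise conv
                    nn.Conv2d(c_in, c_in, 3, padding=1, groups=c_in), nn.Conv2d(c_in, nc, 1)
                )
            )

    TinyHead.legacy = True  # parse_model-style: flip the flag on the class...
    head = TinyHead()       # ...before building, so this instance gets the legacy branch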
ultralytics/nn/tasks.py CHANGED
@@ -936,6 +936,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
  import ast
 
  # Args
+ legacy = True # backward compatibility for v3/v5/v8/v9 models
  max_channels = float("inf")
  nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
  depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
@@ -1027,8 +1028,10 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
  }:
  args.insert(2, n) # number of repeats
  n = 1
- if m is C3k2 and scale in "mlx": # for M/L/X sizes
- args[3] = True
+ if m is C3k2: # for M/L/X sizes
+ legacy = False
+ if scale in "mlx":
+ args[3] = True
  elif m is AIFI:
  args = [ch[f], *args]
  elif m in {HGStem, HGBlock}:
@@ -1047,6 +1050,8 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
  args.append([ch[x] for x in f])
  if m is Segment:
  args[2] = make_divisible(min(args[2], max_channels) * width, 8)
+ if m in {Detect, Segment, Pose, OBB}:
+ m.legacy = legacy
  elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
  args.insert(1, [ch[x] for x in f])
  elif m is CBLinear:
ultralytics/solutions/ai_gym.py CHANGED
@@ -1,16 +1,40 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license
 
- from ultralytics.solutions.solutions import BaseSolution # Import a parent class
+ from ultralytics.solutions.solutions import BaseSolution
  from ultralytics.utils.plotting import Annotator
 
 
 class AIGym(BaseSolution):
- """A class to manage the gym steps of people in a real-time video stream based on their poses."""
+ """
+ A class to manage gym steps of people in a real-time video stream based on their poses.
+
+ This class extends BaseSolution to monitor workouts using YOLO pose estimation models. It tracks and counts
+ repetitions of exercises based on predefined angle thresholds for up and down positions.
+
+ Attributes:
+ count (List[int]): Repetition counts for each detected person.
+ angle (List[float]): Current angle of the tracked body part for each person.
+ stage (List[str]): Current exercise stage ('up', 'down', or '-') for each person.
+ initial_stage (str | None): Initial stage of the exercise.
+ up_angle (float): Angle threshold for considering the 'up' position of an exercise.
+ down_angle (float): Angle threshold for considering the 'down' position of an exercise.
+ kpts (List[int]): Indices of keypoints used for angle calculation.
+ lw (int): Line width for drawing annotations.
+ annotator (Annotator): Object for drawing annotations on the image.
+
+ Methods:
+ monitor: Processes a frame to detect poses, calculate angles, and count repetitions.
+
+ Examples:
+ >>> gym = AIGym(model="yolov8n-pose.pt")
+ >>> image = cv2.imread("gym_scene.jpg")
+ >>> processed_image = gym.monitor(image)
+ >>> cv2.imshow("Processed Image", processed_image)
+ >>> cv2.waitKey(0)
+ """
 
  def __init__(self, **kwargs):
- """Initialization function for AiGYM class, a child class of BaseSolution class, can be used for workouts
- monitoring.
- """
+ """Initializes AIGym for workout monitoring using pose estimation and predefined angles."""
  # Check if the model name ends with '-pose'
  if "model" in kwargs and "-pose" not in kwargs["model"]:
  kwargs["model"] = "yolo11n-pose.pt"
@@ -31,12 +55,22 @@ class AIGym(BaseSolution):
 
  def monitor(self, im0):
  """
- Monitor the workouts using Ultralytics YOLO Pose Model: https://docs.ultralytics.com/tasks/pose/.
+ Monitors workouts using Ultralytics YOLO Pose Model.
+
+ This function processes an input image to track and analyze human poses for workout monitoring. It uses
+ the YOLO Pose model to detect keypoints, estimate angles, and count repetitions based on predefined
+ angle thresholds.
 
  Args:
- im0 (ndarray): The input image that will be used for processing
- Returns
- im0 (ndarray): The processed image for more usage
+ im0 (ndarray): Input image for processing.
+
+ Returns:
+ (ndarray): Processed image with annotations for workout monitoring.
+
+ Examples:
+ >>> gym = AIGym()
+ >>> image = cv2.imread("workout.jpg")
+ >>> processed_image = gym.monitor(image)
  """
  # Extract tracks
  tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"])[0]
ultralytics/solutions/analytics.py CHANGED
@@ -12,10 +12,41 @@ from ultralytics.solutions.solutions import BaseSolution # Import a parent clas
 
 
 class Analytics(BaseSolution):
- """A class to create and update various types of charts (line, bar, pie, area) for visual analytics."""
+ """
+ A class for creating and updating various types of charts for visual analytics.
+
+ This class extends BaseSolution to provide functionality for generating line, bar, pie, and area charts
+ based on object detection and tracking data.
+
+ Attributes:
+ type (str): The type of analytics chart to generate ('line', 'bar', 'pie', or 'area').
+ x_label (str): Label for the x-axis.
+ y_label (str): Label for the y-axis.
+ bg_color (str): Background color of the chart frame.
+ fg_color (str): Foreground color of the chart frame.
+ title (str): Title of the chart window.
+ max_points (int): Maximum number of data points to display on the chart.
+ fontsize (int): Font size for text display.
+ color_cycle (cycle): Cyclic iterator for chart colors.
+ total_counts (int): Total count of detected objects (used for line charts).
+ clswise_count (Dict[str, int]): Dictionary for class-wise object counts.
+ fig (Figure): Matplotlib figure object for the chart.
+ ax (Axes): Matplotlib axes object for the chart.
+ canvas (FigureCanvas): Canvas for rendering the chart.
+
+ Methods:
+ process_data: Processes image data and updates the chart.
+ update_graph: Updates the chart with new data points.
+
+ Examples:
+ >>> analytics = Analytics(analytics_type="line")
+ >>> frame = cv2.imread("image.jpg")
+ >>> processed_frame = analytics.process_data(frame, frame_number=1)
+ >>> cv2.imshow("Analytics", processed_frame)
+ """
 
  def __init__(self, **kwargs):
- """Initialize the Analytics class with various chart types."""
+ """Initialize Analytics class with various chart types for visual data representation."""
  super().__init__(**kwargs)
 
  self.type = self.CFG["analytics_type"] # extract type of analytics
@@ -31,8 +62,8 @@ class Analytics(BaseSolution):
  figsize = (19.2, 10.8) # Set output image size 1920 * 1080
  self.color_cycle = cycle(["#DD00BA", "#042AFF", "#FF4447", "#7D24FF", "#BD00FF"])
 
- self.total_counts = 0 # count variable for storing total counts i.e for line
- self.clswise_count = {} # dictionary for classwise counts
+ self.total_counts = 0 # count variable for storing total counts i.e. for line
+ self.clswise_count = {} # dictionary for class-wise counts
 
  # Ensure line and area chart
  if self.type in {"line", "area"}:
@@ -48,15 +79,28 @@ class Analytics(BaseSolution):
  self.canvas = FigureCanvas(self.fig) # Set common axis properties
  self.ax.set_facecolor(self.bg_color)
  self.color_mapping = {}
- self.ax.axis("equal") if self.type == "pie" else None # Ensure pie chart is circular
+
+ if self.type == "pie": # Ensure pie chart is circular
+ self.ax.axis("equal")
 
  def process_data(self, im0, frame_number):
  """
- Process the image data, run object tracking.
+ Processes image data and runs object tracking to update analytics charts.
 
  Args:
- im0 (ndarray): Input image for processing.
- frame_number (int): Video frame # for plotting the data.
+ im0 (np.ndarray): Input image for processing.
+ frame_number (int): Video frame number for plotting the data.
+
+ Returns:
+ (np.ndarray): Processed image with updated analytics chart.
+
+ Raises:
+ ModuleNotFoundError: If an unsupported chart type is specified.
+
+ Examples:
+ >>> analytics = Analytics(analytics_type="line")
+ >>> frame = np.zeros((480, 640, 3), dtype=np.uint8)
+ >>> processed_frame = analytics.process_data(frame, frame_number=1)
  """
  self.extract_tracks(im0) # Extract tracks
 
@@ -79,13 +123,22 @@ class Analytics(BaseSolution):
 
  def update_graph(self, frame_number, count_dict=None, plot="line"):
  """
- Update the graph (line or area) with new data for single or multiple classes.
+ Updates the graph with new data for single or multiple classes.
 
  Args:
  frame_number (int): The current frame number.
- count_dict (dict, optional): Dictionary with class names as keys and counts as values for multiple classes.
- If None, updates a single line graph.
- plot (str): Type of the plot i.e. line, bar or area.
+ count_dict (Dict[str, int] | None): Dictionary with class names as keys and counts as values for multiple
+ classes. If None, updates a single line graph.
+ plot (str): Type of the plot. Options are 'line', 'bar', 'pie', or 'area'.
+
+ Returns:
+ (np.ndarray): Updated image containing the graph.
+
+ Examples:
+ >>> analytics = Analytics()
+ >>> frame_number = 10
+ >>> count_dict = {"person": 5, "car": 3}
+ >>> updated_image = analytics.update_graph(frame_number, count_dict, plot="bar")
  """
  if count_dict is None:
  # Single line update
ultralytics/solutions/distance_calculation.py CHANGED
@@ -4,15 +4,41 @@ import math
 
  import cv2
 
- from ultralytics.solutions.solutions import BaseSolution # Import a parent class
+ from ultralytics.solutions.solutions import BaseSolution
  from ultralytics.utils.plotting import Annotator, colors
 
 
 class DistanceCalculation(BaseSolution):
- """A class to calculate distance between two objects in a real-time video stream based on their tracks."""
+ """
+ A class to calculate distance between two objects in a real-time video stream based on their tracks.
+
+ This class extends BaseSolution to provide functionality for selecting objects and calculating the distance
+ between them in a video stream using YOLO object detection and tracking.
+
+ Attributes:
+ left_mouse_count (int): Counter for left mouse button clicks.
+ selected_boxes (Dict[int, List[float]]): Dictionary to store selected bounding boxes and their track IDs.
+ annotator (Annotator): An instance of the Annotator class for drawing on the image.
+ boxes (List[List[float]]): List of bounding boxes for detected objects.
+ track_ids (List[int]): List of track IDs for detected objects.
+ clss (List[int]): List of class indices for detected objects.
+ names (List[str]): List of class names that the model can detect.
+ centroids (List[List[int]]): List to store centroids of selected bounding boxes.
+
+ Methods:
+ mouse_event_for_distance: Handles mouse events for selecting objects in the video stream.
+ calculate: Processes video frames and calculates the distance between selected objects.
+
+ Examples:
+ >>> distance_calc = DistanceCalculation()
+ >>> frame = cv2.imread("frame.jpg")
+ >>> processed_frame = distance_calc.calculate(frame)
+ >>> cv2.imshow("Distance Calculation", processed_frame)
+ >>> cv2.waitKey(0)
+ """
 
  def __init__(self, **kwargs):
- """Initializes the DistanceCalculation class with the given parameters."""
+ """Initializes the DistanceCalculation class for measuring object distances in video streams."""
  super().__init__(**kwargs)
 
  # Mouse event information
@@ -21,14 +47,18 @@ class DistanceCalculation(BaseSolution):
 
  def mouse_event_for_distance(self, event, x, y, flags, param):
  """
- Handles mouse events to select regions in a real-time video stream.
+ Handles mouse events to select regions in a real-time video stream for distance calculation.
 
  Args:
- event (int): Type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN, etc.).
+ event (int): Type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN).
  x (int): X-coordinate of the mouse pointer.
  y (int): Y-coordinate of the mouse pointer.
- flags (int): Flags associated with the event (e.g., cv2.EVENT_FLAG_CTRLKEY, cv2.EVENT_FLAG_SHIFTKEY, etc.).
- param (dict): Additional parameters passed to the function.
+ flags (int): Flags associated with the event (e.g., cv2.EVENT_FLAG_CTRLKEY, cv2.EVENT_FLAG_SHIFTKEY).
+ param (Dict): Additional parameters passed to the function.
+
+ Examples:
+ >>> # Assuming 'dc' is an instance of DistanceCalculation
+ >>> cv2.setMouseCallback("window_name", dc.mouse_event_for_distance)
  """
  if event == cv2.EVENT_LBUTTONDOWN:
  self.left_mouse_count += 1
@@ -43,13 +73,23 @@ class DistanceCalculation(BaseSolution):
 
  def calculate(self, im0):
  """
- Processes the video frame and calculates the distance between two bounding boxes.
+ Processes a video frame and calculates the distance between two selected bounding boxes.
+
+ This method extracts tracks from the input frame, annotates bounding boxes, and calculates the distance
+ between two user-selected objects if they have been chosen.
 
  Args:
- im0 (ndarray): The image frame.
+ im0 (numpy.ndarray): The input image frame to process.
 
  Returns:
- (ndarray): The processed image frame.
+ (numpy.ndarray): The processed image frame with annotations and distance calculations.
+
+ Examples:
+ >>> import numpy as np
+ >>> from ultralytics.solutions import DistanceCalculation
+ >>> dc = DistanceCalculation()
+ >>> frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
+ >>> processed_frame = dc.calculate(frame)
  """
  self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator
  self.extract_tracks(im0) # Extract tracks
ultralytics/solutions/heatmap.py CHANGED
@@ -3,15 +3,40 @@
  import cv2
  import numpy as np
 
- from ultralytics.solutions.object_counter import ObjectCounter # Import object counter class
+ from ultralytics.solutions.object_counter import ObjectCounter
  from ultralytics.utils.plotting import Annotator
 
 
 class Heatmap(ObjectCounter):
- """A class to draw heatmaps in real-time video stream based on their tracks."""
+ """
+ A class to draw heatmaps in real-time video streams based on object tracks.
+
+ This class extends the ObjectCounter class to generate and visualize heatmaps of object movements in video
+ streams. It uses tracked object positions to create a cumulative heatmap effect over time.
+
+ Attributes:
+ initialized (bool): Flag indicating whether the heatmap has been initialized.
+ colormap (int): OpenCV colormap used for heatmap visualization.
+ heatmap (np.ndarray): Array storing the cumulative heatmap data.
+ annotator (Annotator): Object for drawing annotations on the image.
+
+ Methods:
+ heatmap_effect: Calculates and updates the heatmap effect for a given bounding box.
+ generate_heatmap: Generates and applies the heatmap effect to each frame.
+
+ Examples:
+ >>> from ultralytics.solutions import Heatmap
+ >>> heatmap = Heatmap(model="yolov8n.pt", colormap=cv2.COLORMAP_JET)
+ >>> results = heatmap("path/to/video.mp4")
+ >>> for result in results:
+ ... print(result.speed) # Print inference speed
+ ... cv2.imshow("Heatmap", result.plot())
+ ... if cv2.waitKey(1) & 0xFF == ord("q"):
+ ... break
+ """
 
  def __init__(self, **kwargs):
- """Initializes function for heatmap class with default values."""
+ """Initializes the Heatmap class for real-time video stream heatmap generation based on object tracks."""
  super().__init__(**kwargs)
 
  self.initialized = False # bool variable for heatmap initialization
@@ -23,10 +48,15 @@ class Heatmap(ObjectCounter):
 
  def heatmap_effect(self, box):
  """
- Efficient calculation of heatmap area and effect location for applying colormap.
+ Efficiently calculates heatmap area and effect location for applying colormap.
 
  Args:
- box (list): Bounding Box coordinates data [x0, y0, x1, y1]
+ box (List[float]): Bounding box coordinates [x0, y0, x1, y1].
+
+ Examples:
+ >>> heatmap = Heatmap()
+ >>> box = [100, 100, 200, 200]
+ >>> heatmap.heatmap_effect(box)
  """
  x0, y0, x1, y1 = map(int, box)
  radius_squared = (min(x1 - x0, y1 - y0) // 2) ** 2
@@ -48,9 +78,15 @@ class Heatmap(ObjectCounter):
  Generate heatmap for each frame using Ultralytics.
 
  Args:
- im0 (ndarray): Input image array for processing
+ im0 (np.ndarray): Input image array for processing.
+
  Returns:
- im0 (ndarray): Processed image for further usage
+ (np.ndarray): Processed image with heatmap overlay and object counts (if region is specified).
+
+ Examples:
+ >>> heatmap = Heatmap()
+ >>> im0 = cv2.imread("image.jpg")
+ >>> result = heatmap.generate_heatmap(im0)
  """
  if not self.initialized:
  self.heatmap = np.zeros_like(im0, dtype=np.float32) * 0.99
@@ -70,16 +106,17 @@ class Heatmap(ObjectCounter):
  self.store_classwise_counts(cls) # store classwise counts in dict
 
  # Store tracking previous position and perform object counting
- prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
+ prev_position = None
+ if len(self.track_history[track_id]) > 1:
+ prev_position = self.track_history[track_id][-2]
  self.count_objects(self.track_line, box, track_id, prev_position, cls) # Perform object counting
 
- self.display_counts(im0) if self.region is not None else None # Display the counts on the frame
+ if self.region is not None:
+ self.display_counts(im0) # Display the counts on the frame
 
  # Normalize, apply colormap to heatmap and combine with original image
- im0 = (
- im0
- if self.track_data.id is None
- else cv2.addWeighted(
+ if self.track_data.id is not None:
+ im0 = cv2.addWeighted(
  im0,
  0.5,
  cv2.applyColorMap(
@@ -88,7 +125,6 @@ class Heatmap(ObjectCounter):
  0.5,
  0,
  )
- )
 
  self.display_output(im0) # display output with base class function
  return im0 # return output image for more usage