ultralytics 8.3.14__py3-none-any.whl → 8.3.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tests/test_solutions.py CHANGED
@@ -17,10 +17,15 @@ def test_major_solutions():
  cap = cv2.VideoCapture("solutions_ci_demo.mp4")
  assert cap.isOpened(), "Error reading video file"
  region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
- counter = solutions.ObjectCounter(region=region_points, model="yolo11n.pt", show=False)
- heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False)
- speed = solutions.SpeedEstimator(region=region_points, model="yolo11n.pt", show=False)
- queue = solutions.QueueManager(region=region_points, model="yolo11n.pt", show=False)
+ counter = solutions.ObjectCounter(region=region_points, model="yolo11n.pt", show=False) # Test object counter
+ heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False) # Test heatmaps
+ speed = solutions.SpeedEstimator(region=region_points, model="yolo11n.pt", show=False) # Test speed estimation
+ queue = solutions.QueueManager(region=region_points, model="yolo11n.pt", show=False) # Test queue manager
+ line_analytics = solutions.Analytics(analytics_type="line", model="yolo11n.pt", show=False) # line analytics
+ pie_analytics = solutions.Analytics(analytics_type="pie", model="yolo11n.pt", show=False) # pie analytics
+ bar_analytics = solutions.Analytics(analytics_type="bar", model="yolo11n.pt", show=False) # bar analytics
+ area_analytics = solutions.Analytics(analytics_type="area", model="yolo11n.pt", show=False) # area analytics
+ frame_count = 0 # Required for analytics
  while cap.isOpened():
  success, im0 = cap.read()
  if not success:
@@ -30,24 +35,23 @@ def test_major_solutions():
  _ = heatmap.generate_heatmap(original_im0.copy())
  _ = speed.estimate_speed(original_im0.copy())
  _ = queue.process_queue(original_im0.copy())
+ _ = line_analytics.process_data(original_im0.copy(), frame_count)
+ _ = pie_analytics.process_data(original_im0.copy(), frame_count)
+ _ = bar_analytics.process_data(original_im0.copy(), frame_count)
+ _ = area_analytics.process_data(original_im0.copy(), frame_count)
  cap.release()
- cv2.destroyAllWindows()
-

- @pytest.mark.slow
- def test_aigym():
- """Test the workouts monitoring solution."""
+ # Test workouts monitoring
  safe_download(url=WORKOUTS_SOLUTION_DEMO)
- cap = cv2.VideoCapture("solution_ci_pose_demo.mp4")
- assert cap.isOpened(), "Error reading video file"
- gym = solutions.AIGym(line_width=2, kpts=[5, 11, 13])
- while cap.isOpened():
- success, im0 = cap.read()
+ cap1 = cv2.VideoCapture("solution_ci_pose_demo.mp4")
+ assert cap1.isOpened(), "Error reading video file"
+ gym = solutions.AIGym(line_width=2, kpts=[5, 11, 13], show=False)
+ while cap1.isOpened():
+ success, im0 = cap1.read()
  if not success:
  break
  _ = gym.monitor(im0)
- cap.release()
- cv2.destroyAllWindows()
+ cap1.release()


  @pytest.mark.slow
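For context, the new analytics assertions above drive each chart type with an explicit frame counter. A minimal standalone sketch of that per-frame loop ("demo.mp4" is a placeholder; real applications increment the counter every frame):

```python
# Minimal sketch of the per-frame analytics loop exercised by the test above.
# "demo.mp4" is a placeholder; yolo11n.pt is fetched automatically if missing.
import cv2

from ultralytics import solutions

analytics = solutions.Analytics(analytics_type="line", model="yolo11n.pt", show=False)

cap = cv2.VideoCapture("demo.mp4")
frame_count = 0
while cap.isOpened():
    success, im0 = cap.read()
    if not success:
        break
    frame_count += 1  # Analytics plots counts against this frame number
    _ = analytics.process_data(im0, frame_count)
cap.release()
```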
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license

- __version__ = "8.3.14"
+ __version__ = "8.3.16"

  import os

ultralytics/cfg/solutions/default.yaml CHANGED
@@ -15,3 +15,4 @@ down_angle: 90 # Workouts down_angle for counts, 90 is default value. You can ch
  kpts: [6, 8, 10] # Keypoints for workouts monitoring, i.e. If you want to consider keypoints for pushups that have mostly values of [6, 8, 10].
  colormap: # Colormap for heatmap, Only OPENCV supported colormaps can be used. By default COLORMAP_PARULA will be used for visualization.
  analytics_type: "line" # Analytics type i.e "line", "pie", "bar" or "area" charts. By default, "line" analytics will be used for processing.
+ json_file: # Parking system regions file path.
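The new `json_file` key supplies pre-annotated region definitions to the parking-management solution. A hedged sketch of how such a value is typically consumed (the `ParkingManagement` usage is assumed from this release series; the JSON and image paths are placeholders):

```python
# Sketch only: assumes ParkingManagement reads parking regions from a JSON
# file, matching the new json_file key above. All paths are placeholders.
import cv2

from ultralytics import solutions

parking = solutions.ParkingManagement(model="yolo11n.pt", json_file="parking_regions.json")

im0 = cv2.imread("parking_lot.jpg")
_ = parking.process_data(im0)  # annotates occupied/available slots on the frame
```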
ultralytics/data/split_dota.py CHANGED
@@ -13,9 +13,6 @@ from tqdm import tqdm
  from ultralytics.data.utils import exif_size, img2label_paths
  from ultralytics.utils.checks import check_requirements

- check_requirements("shapely")
- from shapely.geometry import Polygon
-

  def bbox_iof(polygon1, bbox2, eps=1e-6):
  """
@@ -33,6 +30,9 @@ def bbox_iof(polygon1, bbox2, eps=1e-6):
  Polygon format: [x1, y1, x2, y2, x3, y3, x4, y4].
  Bounding box format: [x_min, y_min, x_max, y_max].
  """
+ check_requirements("shapely")
+ from shapely.geometry import Polygon
+
  polygon1 = polygon1.reshape(-1, 4, 2)
  lt_point = np.min(polygon1, axis=-2) # left-top
  rb_point = np.max(polygon1, axis=-2) # right-bottom
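This change defers the shapely requirement check from import time to first use, so importing the module no longer demands shapely. The deferred-dependency pattern in isolation (`polygon_area` is a hypothetical helper for illustration):

```python
def polygon_area(points):
    """Hypothetical helper illustrating the deferred-dependency pattern above."""
    from ultralytics.utils.checks import check_requirements

    check_requirements("shapely")  # validated (and installed) only when called
    from shapely.geometry import Polygon  # import deferred to first use

    return Polygon(points).area
```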
ultralytics/engine/exporter.py CHANGED
@@ -398,7 +398,7 @@ class Exporter:
  """YOLO ONNX export."""
  requirements = ["onnx>=1.12.0"]
  if self.args.simplify:
- requirements += ["onnxslim==0.1.34", "onnxruntime" + ("-gpu" if torch.cuda.is_available() else "")]
+ requirements += ["onnxslim", "onnxruntime" + ("-gpu" if torch.cuda.is_available() else "")]
  check_requirements(requirements)
  import onnx # noqa

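Relaxing the `onnxslim==0.1.34` pin lets simplification pull the latest onnxslim release. This code path is exercised through the standard export call:

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")
# simplify=True triggers the requirements branch above, now with onnxslim unpinned
model.export(format="onnx", simplify=True)
```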
ultralytics/nn/autobackend.py CHANGED
@@ -126,7 +126,7 @@ class AutoBackend(nn.Module):
  fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16
  nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
  stride = 32 # default stride
- model, metadata = None, None
+ model, metadata, task = None, None, None

  # Set device
  cuda = torch.cuda.is_available() and device.type != "cpu" # use CUDA
@@ -336,11 +336,15 @@ class AutoBackend(nn.Module):

  Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
  if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
- LOGGER.info(f"Loading {w} for TensorFlow Lite Edge TPU inference...")
+ device = device[3:] if str(device).startswith("tpu") else ":0"
+ LOGGER.info(f"Loading {w} on device {device[1:]} for TensorFlow Lite Edge TPU inference...")
  delegate = {"Linux": "libedgetpu.so.1", "Darwin": "libedgetpu.1.dylib", "Windows": "edgetpu.dll"}[
  platform.system()
  ]
- interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
+ interpreter = Interpreter(
+ model_path=w,
+ experimental_delegates=[load_delegate(delegate, options={"device": device})],
+ )
  else: # TFLite
  LOGGER.info(f"Loading {w} for TensorFlow Lite inference...")
  interpreter = Interpreter(model_path=w) # load TFLite model
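The Edge TPU branch now forwards a device index to the TFLite delegate, so a specific Coral accelerator can be selected. A hedged usage sketch (the `tpu:0` form is inferred from the parsing above, which strips the leading "tpu" and passes the rest to the delegate's "device" option; the model path is a placeholder):

```python
from ultralytics import YOLO

# "tpu:0" targets the first Edge TPU; model filename is a placeholder for an
# Edge TPU export produced by model.export(format="edgetpu").
model = YOLO("yolo11n_full_integer_quant_edgetpu.tflite")
results = model.predict("bus.jpg", device="tpu:0")
```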
ultralytics/solutions/ai_gym.py CHANGED
@@ -1,16 +1,40 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license

- from ultralytics.solutions.solutions import BaseSolution # Import a parent class
+ from ultralytics.solutions.solutions import BaseSolution
  from ultralytics.utils.plotting import Annotator


  class AIGym(BaseSolution):
- """A class to manage the gym steps of people in a real-time video stream based on their poses."""
+ """
+ A class to manage gym steps of people in a real-time video stream based on their poses.
+
+ This class extends BaseSolution to monitor workouts using YOLO pose estimation models. It tracks and counts
+ repetitions of exercises based on predefined angle thresholds for up and down positions.
+
+ Attributes:
+ count (List[int]): Repetition counts for each detected person.
+ angle (List[float]): Current angle of the tracked body part for each person.
+ stage (List[str]): Current exercise stage ('up', 'down', or '-') for each person.
+ initial_stage (str | None): Initial stage of the exercise.
+ up_angle (float): Angle threshold for considering the 'up' position of an exercise.
+ down_angle (float): Angle threshold for considering the 'down' position of an exercise.
+ kpts (List[int]): Indices of keypoints used for angle calculation.
+ lw (int): Line width for drawing annotations.
+ annotator (Annotator): Object for drawing annotations on the image.
+
+ Methods:
+ monitor: Processes a frame to detect poses, calculate angles, and count repetitions.
+
+ Examples:
+ >>> gym = AIGym(model="yolov8n-pose.pt")
+ >>> image = cv2.imread("gym_scene.jpg")
+ >>> processed_image = gym.monitor(image)
+ >>> cv2.imshow("Processed Image", processed_image)
+ >>> cv2.waitKey(0)
+ """

  def __init__(self, **kwargs):
- """Initialization function for AiGYM class, a child class of BaseSolution class, can be used for workouts
- monitoring.
- """
+ """Initializes AIGym for workout monitoring using pose estimation and predefined angles."""
  # Check if the model name ends with '-pose'
  if "model" in kwargs and "-pose" not in kwargs["model"]:
  kwargs["model"] = "yolo11n-pose.pt"
@@ -31,12 +55,22 @@ class AIGym(BaseSolution):

  def monitor(self, im0):
  """
- Monitor the workouts using Ultralytics YOLOv8 Pose Model: https://docs.ultralytics.com/tasks/pose/.
+ Monitors workouts using Ultralytics YOLO Pose Model.
+
+ This function processes an input image to track and analyze human poses for workout monitoring. It uses
+ the YOLO Pose model to detect keypoints, estimate angles, and count repetitions based on predefined
+ angle thresholds.

  Args:
- im0 (ndarray): The input image that will be used for processing
- Returns
- im0 (ndarray): The processed image for more usage
+ im0 (ndarray): Input image for processing.
+
+ Returns:
+ (ndarray): Processed image with annotations for workout monitoring.
+
+ Examples:
+ >>> gym = AIGym()
+ >>> image = cv2.imread("workout.jpg")
+ >>> processed_image = gym.monitor(image)
  """
  # Extract tracks
  tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"])[0]
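The docstring's single-image example extends naturally to a stream; a sketch mirroring the CI test above (the video path is a placeholder, and `kpts=[5, 11, 13]` selects the left shoulder, hip, and knee in COCO keypoint order):

```python
import cv2

from ultralytics import solutions

gym = solutions.AIGym(line_width=2, kpts=[5, 11, 13], show=False)

cap = cv2.VideoCapture("pose_demo.mp4")  # placeholder path
while cap.isOpened():
    success, im0 = cap.read()
    if not success:
        break
    im0 = gym.monitor(im0)  # counts reps and draws angle/stage/count annotations
cap.release()
```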
ultralytics/solutions/analytics.py CHANGED
@@ -12,10 +12,41 @@ from ultralytics.solutions.solutions import BaseSolution # Import a parent clas


  class Analytics(BaseSolution):
- """A class to create and update various types of charts (line, bar, pie, area) for visual analytics."""
+ """
+ A class for creating and updating various types of charts for visual analytics.
+
+ This class extends BaseSolution to provide functionality for generating line, bar, pie, and area charts
+ based on object detection and tracking data.
+
+ Attributes:
+ type (str): The type of analytics chart to generate ('line', 'bar', 'pie', or 'area').
+ x_label (str): Label for the x-axis.
+ y_label (str): Label for the y-axis.
+ bg_color (str): Background color of the chart frame.
+ fg_color (str): Foreground color of the chart frame.
+ title (str): Title of the chart window.
+ max_points (int): Maximum number of data points to display on the chart.
+ fontsize (int): Font size for text display.
+ color_cycle (cycle): Cyclic iterator for chart colors.
+ total_counts (int): Total count of detected objects (used for line charts).
+ clswise_count (Dict[str, int]): Dictionary for class-wise object counts.
+ fig (Figure): Matplotlib figure object for the chart.
+ ax (Axes): Matplotlib axes object for the chart.
+ canvas (FigureCanvas): Canvas for rendering the chart.
+
+ Methods:
+ process_data: Processes image data and updates the chart.
+ update_graph: Updates the chart with new data points.
+
+ Examples:
+ >>> analytics = Analytics(analytics_type="line")
+ >>> frame = cv2.imread("image.jpg")
+ >>> processed_frame = analytics.process_data(frame, frame_number=1)
+ >>> cv2.imshow("Analytics", processed_frame)
+ """

  def __init__(self, **kwargs):
- """Initialize the Analytics class with various chart types."""
+ """Initialize Analytics class with various chart types for visual data representation."""
  super().__init__(**kwargs)

  self.type = self.CFG["analytics_type"] # extract type of analytics
@@ -31,8 +62,8 @@ class Analytics(BaseSolution):
  figsize = (19.2, 10.8) # Set output image size 1920 * 1080
  self.color_cycle = cycle(["#DD00BA", "#042AFF", "#FF4447", "#7D24FF", "#BD00FF"])

- self.total_counts = 0 # count variable for storing total counts i.e for line
- self.clswise_count = {} # dictionary for classwise counts
+ self.total_counts = 0 # count variable for storing total counts i.e. for line
+ self.clswise_count = {} # dictionary for class-wise counts

  # Ensure line and area chart
  if self.type in {"line", "area"}:
@@ -48,15 +79,28 @@ class Analytics(BaseSolution):
  self.canvas = FigureCanvas(self.fig) # Set common axis properties
  self.ax.set_facecolor(self.bg_color)
  self.color_mapping = {}
- self.ax.axis("equal") if self.type == "pie" else None # Ensure pie chart is circular
+
+ if self.type == "pie": # Ensure pie chart is circular
+ self.ax.axis("equal")

  def process_data(self, im0, frame_number):
  """
- Process the image data, run object tracking.
+ Processes image data and runs object tracking to update analytics charts.

  Args:
- im0 (ndarray): Input image for processing.
- frame_number (int): Video frame # for plotting the data.
+ im0 (np.ndarray): Input image for processing.
+ frame_number (int): Video frame number for plotting the data.
+
+ Returns:
+ (np.ndarray): Processed image with updated analytics chart.
+
+ Raises:
+ ModuleNotFoundError: If an unsupported chart type is specified.
+
+ Examples:
+ >>> analytics = Analytics(analytics_type="line")
+ >>> frame = np.zeros((480, 640, 3), dtype=np.uint8)
+ >>> processed_frame = analytics.process_data(frame, frame_number=1)
  """
  self.extract_tracks(im0) # Extract tracks

@@ -79,13 +123,22 @@ class Analytics(BaseSolution):

  def update_graph(self, frame_number, count_dict=None, plot="line"):
  """
- Update the graph (line or area) with new data for single or multiple classes.
+ Updates the graph with new data for single or multiple classes.

  Args:
  frame_number (int): The current frame number.
- count_dict (dict, optional): Dictionary with class names as keys and counts as values for multiple classes.
- If None, updates a single line graph.
- plot (str): Type of the plot i.e. line, bar or area.
+ count_dict (Dict[str, int] | None): Dictionary with class names as keys and counts as values for multiple
+ classes. If None, updates a single line graph.
+ plot (str): Type of the plot. Options are 'line', 'bar', 'pie', or 'area'.
+
+ Returns:
+ (np.ndarray): Updated image containing the graph.
+
+ Examples:
+ >>> analytics = Analytics()
+ >>> frame_number = 10
+ >>> count_dict = {"person": 5, "car": 3}
+ >>> updated_image = analytics.update_graph(frame_number, count_dict, plot="bar")
  """
  if count_dict is None:
  # Single line update
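Since `process_data` returns the rendered chart (the figure above is sized 19.2 x 10.8 in, i.e. 1920x1080 px at the default 100 dpi), the frames can be written straight to a video. A hedged sketch with placeholder paths:

```python
# Sketch: persist the rendered analytics charts to a video file. Writer size
# matches the 1920x1080 output implied by the figsize comment in the diff above.
import cv2

from ultralytics import solutions

analytics = solutions.Analytics(analytics_type="pie", model="yolo11n.pt", show=False)

cap = cv2.VideoCapture("traffic.mp4")  # placeholder input video
out = cv2.VideoWriter("analytics.avi", cv2.VideoWriter_fourcc(*"MJPG"), 30, (1920, 1080))

frame_count = 0
while cap.isOpened():
    success, im0 = cap.read()
    if not success:
        break
    frame_count += 1
    out.write(analytics.process_data(im0, frame_count))  # 1920x1080 chart frame

cap.release()
out.release()
```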
ultralytics/solutions/distance_calculation.py CHANGED
@@ -4,15 +4,41 @@ import math

  import cv2

- from ultralytics.solutions.solutions import BaseSolution # Import a parent class
+ from ultralytics.solutions.solutions import BaseSolution
  from ultralytics.utils.plotting import Annotator, colors


  class DistanceCalculation(BaseSolution):
- """A class to calculate distance between two objects in a real-time video stream based on their tracks."""
+ """
+ A class to calculate distance between two objects in a real-time video stream based on their tracks.
+
+ This class extends BaseSolution to provide functionality for selecting objects and calculating the distance
+ between them in a video stream using YOLO object detection and tracking.
+
+ Attributes:
+ left_mouse_count (int): Counter for left mouse button clicks.
+ selected_boxes (Dict[int, List[float]]): Dictionary to store selected bounding boxes and their track IDs.
+ annotator (Annotator): An instance of the Annotator class for drawing on the image.
+ boxes (List[List[float]]): List of bounding boxes for detected objects.
+ track_ids (List[int]): List of track IDs for detected objects.
+ clss (List[int]): List of class indices for detected objects.
+ names (List[str]): List of class names that the model can detect.
+ centroids (List[List[int]]): List to store centroids of selected bounding boxes.
+
+ Methods:
+ mouse_event_for_distance: Handles mouse events for selecting objects in the video stream.
+ calculate: Processes video frames and calculates the distance between selected objects.
+
+ Examples:
+ >>> distance_calc = DistanceCalculation()
+ >>> frame = cv2.imread("frame.jpg")
+ >>> processed_frame = distance_calc.calculate(frame)
+ >>> cv2.imshow("Distance Calculation", processed_frame)
+ >>> cv2.waitKey(0)
+ """

  def __init__(self, **kwargs):
- """Initializes the DistanceCalculation class with the given parameters."""
+ """Initializes the DistanceCalculation class for measuring object distances in video streams."""
  super().__init__(**kwargs)

  # Mouse event information
@@ -21,14 +47,18 @@ class DistanceCalculation(BaseSolution):

  def mouse_event_for_distance(self, event, x, y, flags, param):
  """
- Handles mouse events to select regions in a real-time video stream.
+ Handles mouse events to select regions in a real-time video stream for distance calculation.

  Args:
- event (int): Type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN, etc.).
+ event (int): Type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN).
  x (int): X-coordinate of the mouse pointer.
  y (int): Y-coordinate of the mouse pointer.
- flags (int): Flags associated with the event (e.g., cv2.EVENT_FLAG_CTRLKEY, cv2.EVENT_FLAG_SHIFTKEY, etc.).
- param (dict): Additional parameters passed to the function.
+ flags (int): Flags associated with the event (e.g., cv2.EVENT_FLAG_CTRLKEY, cv2.EVENT_FLAG_SHIFTKEY).
+ param (Dict): Additional parameters passed to the function.
+
+ Examples:
+ >>> # Assuming 'dc' is an instance of DistanceCalculation
+ >>> cv2.setMouseCallback("window_name", dc.mouse_event_for_distance)
  """
  if event == cv2.EVENT_LBUTTONDOWN:
  self.left_mouse_count += 1
@@ -43,13 +73,23 @@ class DistanceCalculation(BaseSolution):

  def calculate(self, im0):
  """
- Processes the video frame and calculates the distance between two bounding boxes.
+ Processes a video frame and calculates the distance between two selected bounding boxes.
+
+ This method extracts tracks from the input frame, annotates bounding boxes, and calculates the distance
+ between two user-selected objects if they have been chosen.

  Args:
- im0 (ndarray): The image frame.
+ im0 (numpy.ndarray): The input image frame to process.

  Returns:
- (ndarray): The processed image frame.
+ (numpy.ndarray): The processed image frame with annotations and distance calculations.
+
+ Examples:
+ >>> import numpy as np
+ >>> from ultralytics.solutions import DistanceCalculation
+ >>> dc = DistanceCalculation()
+ >>> frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
+ >>> processed_frame = dc.calculate(frame)
  """
  self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator
  self.extract_tracks(im0) # Extract tracks
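The two documented pieces, mouse selection and per-frame calculation, combine into one interactive loop. A hedged sketch with placeholder window and video names:

```python
import cv2

from ultralytics import solutions

dc = solutions.DistanceCalculation(model="yolo11n.pt")
cv2.namedWindow("distance")
cv2.setMouseCallback("distance", dc.mouse_event_for_distance)  # left-click two objects

cap = cv2.VideoCapture("street.mp4")  # placeholder path
while cap.isOpened():
    success, im0 = cap.read()
    if not success:
        break
    im0 = dc.calculate(im0)  # annotates tracks and the selected-pair distance
    cv2.imshow("distance", im0)
    if cv2.waitKey(1) & 0xFF == ord("q"):
        break
cap.release()
cv2.destroyAllWindows()
```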
ultralytics/solutions/heatmap.py CHANGED
@@ -3,15 +3,40 @@
  import cv2
  import numpy as np

- from ultralytics.solutions.object_counter import ObjectCounter # Import object counter class
+ from ultralytics.solutions.object_counter import ObjectCounter
  from ultralytics.utils.plotting import Annotator


  class Heatmap(ObjectCounter):
- """A class to draw heatmaps in real-time video stream based on their tracks."""
+ """
+ A class to draw heatmaps in real-time video streams based on object tracks.
+
+ This class extends the ObjectCounter class to generate and visualize heatmaps of object movements in video
+ streams. It uses tracked object positions to create a cumulative heatmap effect over time.
+
+ Attributes:
+ initialized (bool): Flag indicating whether the heatmap has been initialized.
+ colormap (int): OpenCV colormap used for heatmap visualization.
+ heatmap (np.ndarray): Array storing the cumulative heatmap data.
+ annotator (Annotator): Object for drawing annotations on the image.
+
+ Methods:
+ heatmap_effect: Calculates and updates the heatmap effect for a given bounding box.
+ generate_heatmap: Generates and applies the heatmap effect to each frame.
+
+ Examples:
+ >>> from ultralytics.solutions import Heatmap
+ >>> heatmap = Heatmap(model="yolov8n.pt", colormap=cv2.COLORMAP_JET)
+ >>> results = heatmap("path/to/video.mp4")
+ >>> for result in results:
+ ... print(result.speed) # Print inference speed
+ ... cv2.imshow("Heatmap", result.plot())
+ ... if cv2.waitKey(1) & 0xFF == ord("q"):
+ ... break
+ """

  def __init__(self, **kwargs):
- """Initializes function for heatmap class with default values."""
+ """Initializes the Heatmap class for real-time video stream heatmap generation based on object tracks."""
  super().__init__(**kwargs)

  self.initialized = False # bool variable for heatmap initialization
@@ -23,10 +48,15 @@ class Heatmap(ObjectCounter):

  def heatmap_effect(self, box):
  """
- Efficient calculation of heatmap area and effect location for applying colormap.
+ Efficiently calculates heatmap area and effect location for applying colormap.

  Args:
- box (list): Bounding Box coordinates data [x0, y0, x1, y1]
+ box (List[float]): Bounding box coordinates [x0, y0, x1, y1].
+
+ Examples:
+ >>> heatmap = Heatmap()
+ >>> box = [100, 100, 200, 200]
+ >>> heatmap.heatmap_effect(box)
  """
  x0, y0, x1, y1 = map(int, box)
  radius_squared = (min(x1 - x0, y1 - y0) // 2) ** 2
@@ -48,9 +78,15 @@ class Heatmap(ObjectCounter):
  Generate heatmap for each frame using Ultralytics.

  Args:
- im0 (ndarray): Input image array for processing
+ im0 (np.ndarray): Input image array for processing.
+
  Returns:
- im0 (ndarray): Processed image for further usage
+ (np.ndarray): Processed image with heatmap overlay and object counts (if region is specified).
+
+ Examples:
+ >>> heatmap = Heatmap()
+ >>> im0 = cv2.imread("image.jpg")
+ >>> result = heatmap.generate_heatmap(im0)
  """
  if not self.initialized:
  self.heatmap = np.zeros_like(im0, dtype=np.float32) * 0.99
@@ -70,16 +106,17 @@ class Heatmap(ObjectCounter):
  self.store_classwise_counts(cls) # store classwise counts in dict

  # Store tracking previous position and perform object counting
- prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
+ prev_position = None
+ if len(self.track_history[track_id]) > 1:
+ prev_position = self.track_history[track_id][-2]
  self.count_objects(self.track_line, box, track_id, prev_position, cls) # Perform object counting

- self.display_counts(im0) if self.region is not None else None # Display the counts on the frame
+ if self.region is not None:
+ self.display_counts(im0) # Display the counts on the frame

  # Normalize, apply colormap to heatmap and combine with original image
- im0 = (
- im0
- if self.track_data.id is None
- else cv2.addWeighted(
+ if self.track_data.id is not None:
+ im0 = cv2.addWeighted(
  im0,
  0.5,
  cv2.applyColorMap(
@@ -88,7 +125,6 @@ class Heatmap(ObjectCounter):
  0.5,
  0,
  )
- )

  self.display_output(im0) # display output with base class function
  return im0 # return output image for more usage
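End to end, `generate_heatmap` accumulates the effect across frames and blends it 50/50 with each input (per the `addWeighted` call above). A sketch matching the CI test usage, with a placeholder video path:

```python
import cv2

from ultralytics import solutions

heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False)

cap = cv2.VideoCapture("mall.mp4")  # placeholder path
while cap.isOpened():
    success, im0 = cap.read()
    if not success:
        break
    im0 = heatmap.generate_heatmap(im0)  # cumulative heatmap blended with the frame
cap.release()
```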