ultralytics 8.0.228__py3-none-any.whl → 8.0.229__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of ultralytics has been flagged as potentially problematic.

ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license
 
- __version__ = '8.0.228'
+ __version__ = '8.0.229'
 
  from ultralytics.models import RTDETR, SAM, YOLO
  from ultralytics.models.fastsam import FastSAM
ultralytics/cfg/default.yaml CHANGED
@@ -61,6 +61,7 @@ augment: False # (bool) apply image augmentation to prediction sources
  agnostic_nms: False # (bool) class-agnostic NMS
  classes: # (int | list[int], optional) filter results by class, i.e. classes=0, or classes=[0,2,3]
  retina_masks: False # (bool) use high-resolution segmentation masks
+ embed: # (list[int], optional) return feature vectors/embeddings from given layers
 
  # Visualize settings ---------------------------------------------------------------------------------------------------
  show: False # (bool) show predicted images and videos if environment allows
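The new `embed` key exposes the predictor argument added throughout this release. A minimal sketch of passing it through the Python API (the layer indices and image path are illustrative, not defaults):

    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')
    # request pooled feature vectors from two intermediate layers (example indices)
    embeddings = model.predict('bus.jpg', embed=[9, 21])
    for vec in embeddings:
        print(vec.shape)  # one 1-D tensor per image

When `embed` is set, the predictor yields embedding tensors instead of the usual Results objects, as the predictor change below shows.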
ultralytics/engine/model.py CHANGED
@@ -94,7 +94,7 @@ class Model(nn.Module):
          self._load(model, task)
 
      def __call__(self, source=None, stream=False, **kwargs):
-         """Calls the 'predict' function with given arguments to perform object detection."""
+         """Calls the predict() method with given arguments to perform object detection."""
          return self.predict(source, stream, **kwargs)
 
      @staticmethod
@@ -201,6 +201,24 @@ class Model(nn.Module):
          self._check_is_pytorch_model()
          self.model.fuse()
 
+     def embed(self, source=None, stream=False, **kwargs):
+         """
+         Calls the predict() method and returns image embeddings.
+
+         Args:
+             source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
+                 Accepts all source types accepted by the YOLO model.
+             stream (bool): Whether to stream the predictions or not. Defaults to False.
+             **kwargs : Additional keyword arguments passed to the predictor.
+                 Check the 'configuration' section in the documentation for all available options.
+
+         Returns:
+             (List[torch.Tensor]): A list of image embeddings.
+         """
+         if not kwargs.get('embed'):
+             kwargs['embed'] = [len(self.model.model) - 2]  # embed second-to-last layer if no indices passed
+         return self.predict(source, stream, **kwargs)
+
      def predict(self, source=None, stream=False, predictor=None, **kwargs):
          """
          Perform prediction using the YOLO model.
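For orientation, a hedged usage sketch of the new helper; when no indices are passed it defaults to the second-to-last module as shown above (checkpoint and image path illustrative):

    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')
    embeddings = model.embed('bus.jpg')  # defaults to the second-to-last layer
    print(embeddings[0].shape)  # a 1-D feature vector; length depends on the model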
ultralytics/engine/predictor.py CHANGED
@@ -134,7 +134,7 @@ class BasePredictor:
          """Runs inference on a given image using the specified model and arguments."""
          visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
                                     mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
-         return self.model(im, augment=self.args.augment, visualize=visualize)
+         return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
 
      def pre_transform(self, im):
          """
@@ -263,6 +263,9 @@ class BasePredictor:
              # Inference
              with profilers[1]:
                  preds = self.inference(im, *args, **kwargs)
+                 if self.args.embed:
+                     yield from [preds] if isinstance(preds, torch.Tensor) else preds  # yield embedding tensors
+                     continue
 
              # Postprocess
              with profilers[2]:
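Because embeddings are yielded from the same generator as normal predictions, streaming sources work unchanged. A sketch (video path illustrative):

    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')
    # stream=True keeps the generator lazy: one embedding tensor per frame
    for emb in model.embed('video.mp4', stream=True):
        print(emb.shape)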
ultralytics/nn/autobackend.py CHANGED
@@ -333,7 +333,7 @@ class AutoBackend(nn.Module):
 
          self.__dict__.update(locals())  # assign all variables to self
 
-     def forward(self, im, augment=False, visualize=False):
+     def forward(self, im, augment=False, visualize=False, embed=None):
          """
          Runs inference on the YOLOv8 MultiBackend model.
 
@@ -341,6 +341,7 @@ class AutoBackend(nn.Module):
              im (torch.Tensor): The image tensor to perform inference on.
              augment (bool): whether to perform data augmentation during inference, defaults to False
              visualize (bool): whether to visualize the output predictions, defaults to False
+             embed (list, optional): A list of feature vectors/embeddings to return.
 
          Returns:
              (tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True)
@@ -352,7 +353,7 @@ class AutoBackend(nn.Module):
              im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)
 
          if self.pt or self.nn_module:  # PyTorch
-             y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
+             y = self.model(im, augment=augment, visualize=visualize, embed=embed)
          elif self.jit:  # TorchScript
              y = self.model(im)
          elif self.dnn:  # ONNX OpenCV DNN
ultralytics/nn/tasks.py CHANGED
@@ -41,7 +41,7 @@ class BaseModel(nn.Module):
              return self.loss(x, *args, **kwargs)
          return self.predict(x, *args, **kwargs)
 
-     def predict(self, x, profile=False, visualize=False, augment=False):
+     def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
          """
          Perform a forward pass through the network.
 
@@ -50,15 +50,16 @@ class BaseModel(nn.Module):
              profile (bool): Print the computation time of each layer if True, defaults to False.
              visualize (bool): Save the feature maps of the model if True, defaults to False.
              augment (bool): Augment image during prediction, defaults to False.
+             embed (list, optional): A list of feature vectors/embeddings to return.
 
          Returns:
              (torch.Tensor): The last output of the model.
          """
          if augment:
              return self._predict_augment(x)
-         return self._predict_once(x, profile, visualize)
+         return self._predict_once(x, profile, visualize, embed)
 
-     def _predict_once(self, x, profile=False, visualize=False):
+     def _predict_once(self, x, profile=False, visualize=False, embed=None):
          """
          Perform a forward pass through the network.
 
@@ -66,11 +67,12 @@ class BaseModel(nn.Module):
              x (torch.Tensor): The input tensor to the model.
              profile (bool): Print the computation time of each layer if True, defaults to False.
              visualize (bool): Save the feature maps of the model if True, defaults to False.
+             embed (list, optional): A list of feature vectors/embeddings to return.
 
          Returns:
              (torch.Tensor): The last output of the model.
          """
-         y, dt = [], []  # outputs
+         y, dt, embeddings = [], [], []  # outputs
          for m in self.model:
              if m.f != -1:  # if not from previous layer
                  x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
@@ -80,6 +82,10 @@ class BaseModel(nn.Module):
              y.append(x if m.i in self.save else None)  # save output
              if visualize:
                  feature_visualization(x, m.type, m.i, save_dir=visualize)
+             if embed and m.i in embed:
+                 embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
+                 if m.i == max(embed):
+                     return torch.unbind(torch.cat(embeddings, 1), dim=0)
          return x
 
      def _predict_augment(self, x):
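The pooling logic above collapses each selected feature map to one vector per image. A self-contained sketch of the same tensor transformations (shapes illustrative):

    import torch
    import torch.nn as nn

    x = torch.randn(2, 256, 20, 20)  # (batch, channels, h, w) feature map
    pooled = nn.functional.adaptive_avg_pool2d(x, (1, 1))  # -> (2, 256, 1, 1)
    vec = pooled.squeeze(-1).squeeze(-1)                   # -> (2, 256)

    # concatenating vectors from several layers on dim 1, then unbinding on dim 0,
    # yields one 1-D embedding per image, as in _predict_once above
    per_image = torch.unbind(torch.cat([vec, vec], 1), dim=0)
    print(per_image[0].shape)  # torch.Size([512])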
@@ -454,7 +460,7 @@ class RTDETRDetectionModel(DetectionModel):
          return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']],
                                                     device=img.device)
 
-     def predict(self, x, profile=False, visualize=False, batch=None, augment=False):
+     def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
          """
          Perform a forward pass through the model.
 
@@ -464,11 +470,12 @@ class RTDETRDetectionModel(DetectionModel):
              visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
              batch (dict, optional): Ground truth data for evaluation. Defaults to None.
              augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
+             embed (list, optional): A list of feature vectors/embeddings to return.
 
          Returns:
              (torch.Tensor): Model's output tensor.
          """
-         y, dt = [], []  # outputs
+         y, dt, embeddings = [], [], []  # outputs
          for m in self.model[:-1]:  # except the head part
              if m.f != -1:  # if not from previous layer
                  x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
@@ -478,6 +485,10 @@ class RTDETRDetectionModel(DetectionModel):
              y.append(x if m.i in self.save else None)  # save output
              if visualize:
                  feature_visualization(x, m.type, m.i, save_dir=visualize)
+             if embed and m.i in embed:
+                 embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
+                 if m.i == max(embed):
+                     return torch.unbind(torch.cat(embeddings, 1), dim=0)
          head = self.model[-1]
          x = head([y[j] for j in head.f], batch)  # head inference
          return x
ultralytics/solutions/heatmap.py CHANGED
@@ -10,8 +10,7 @@ from ultralytics.utils.plotting import Annotator
 
  check_requirements('shapely>=2.0.0')
 
- from shapely.geometry import Polygon
- from shapely.geometry.point import Point
+ from shapely.geometry import LineString, Point, Polygon
 
 
  class Heatmap:
@@ -23,6 +22,7 @@ class Heatmap:
          # Visual information
          self.annotator = None
          self.view_img = False
+         self.shape = 'circle'
 
          # Image information
          self.imw = None
@@ -38,17 +38,22 @@ class Heatmap:
          self.boxes = None
          self.track_ids = None
          self.clss = None
-         self.track_history = None
+         self.track_history = defaultdict(list)
 
-         # Counting info
+         # Region & Line Information
          self.count_reg_pts = None
-         self.count_region = None
+         self.counting_region = None
+         self.line_dist_thresh = 15
+         self.region_thickness = 5
+         self.region_color = (255, 0, 255)
+
+         # Object Counting Information
          self.in_counts = 0
          self.out_counts = 0
-         self.count_list = []
+         self.counting_list = []
          self.count_txt_thickness = 0
-         self.count_reg_color = (0, 255, 0)
-         self.region_thickness = 5
+         self.count_txt_color = (0, 0, 0)
+         self.count_color = (255, 255, 255)
 
          # Decay factor
          self.decay_factor = 0.99
@@ -64,9 +69,13 @@ class Heatmap:
                   view_img=False,
                   count_reg_pts=None,
                   count_txt_thickness=2,
+                  count_txt_color=(0, 0, 0),
+                  count_color=(255, 255, 255),
                   count_reg_color=(255, 0, 255),
                   region_thickness=5,
-                  decay_factor=0.99):
+                  line_dist_thresh=15,
+                  decay_factor=0.99,
+                  shape='circle'):
          """
          Configures the heatmap colormap, width, height and display parameters.
 
@@ -78,27 +87,55 @@ class Heatmap:
              view_img (bool): Flag indicating frame display
              count_reg_pts (list): Object counting region points
              count_txt_thickness (int): Text thickness for object counting display
+             count_txt_color (RGB color): count text color value
+             count_color (RGB color): count text background color value
              count_reg_color (RGB color): Color of object counting region
              region_thickness (int): Object counting Region thickness
+             line_dist_thresh (int): Euclidean Distance threshold for line counter
              decay_factor (float): value for removing heatmap area after object passed
+             shape (str): Heatmap shape, rect or circle shape supported
          """
          self.imw = imw
          self.imh = imh
-         self.colormap = colormap
          self.heatmap_alpha = heatmap_alpha
          self.view_img = view_img
+         self.colormap = colormap
 
-         self.heatmap = np.zeros((int(self.imw), int(self.imh)), dtype=np.float32)  # Heatmap new frame
-
+         # Region and line selection
          if count_reg_pts is not None:
-             self.track_history = defaultdict(list)
-             self.count_reg_pts = count_reg_pts
-             self.count_region = Polygon(self.count_reg_pts)
 
-         self.count_txt_thickness = count_txt_thickness  # Counting text thickness
-         self.count_reg_color = count_reg_color
+             if len(count_reg_pts) == 2:
+                 print('Line Counter Initiated.')
+                 self.count_reg_pts = count_reg_pts
+                 self.counting_region = LineString(count_reg_pts)
+
+             elif len(count_reg_pts) == 4:
+                 print('Region Counter Initiated.')
+                 self.count_reg_pts = count_reg_pts
+                 self.counting_region = Polygon(self.count_reg_pts)
+
+             else:
+                 print('Region or line points Invalid, 2 or 4 points supported')
+                 print('Using Line Counter Now')
+                 self.counting_region = Polygon([(20, 400), (1260, 400)])  # dummy points
+
+         # Heatmap new frame
+         self.heatmap = np.zeros((int(self.imw), int(self.imh)), dtype=np.float32)
+
+         self.count_txt_thickness = count_txt_thickness
+         self.count_txt_color = count_txt_color
+         self.count_color = count_color
+         self.region_color = count_reg_color
          self.region_thickness = region_thickness
          self.decay_factor = decay_factor
+         self.line_dist_thresh = line_dist_thresh
+         self.shape = shape
+
+         # shape of heatmap, if not selected
+         if self.shape not in ['circle', 'rect']:
+             print("Unknown shape value provided, 'circle' & 'rect' supported")
+             print('Using Circular shape now')
+             self.shape = 'circle'
 
      def extract_results(self, tracks):
          """
@@ -128,13 +165,26 @@ class Heatmap:
          self.annotator = Annotator(self.im0, self.count_txt_thickness, None)
 
          if self.count_reg_pts is not None:
+
              # Draw counting region
              self.annotator.draw_region(reg_pts=self.count_reg_pts,
-                                        color=self.count_reg_color,
+                                        color=self.region_color,
                                         thickness=self.region_thickness)
 
              for box, cls, track_id in zip(self.boxes, self.clss, self.track_ids):
-                 self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += 1
+
+                 if self.shape == 'circle':
+                     center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2))
+                     radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2
+
+                     y, x = np.ogrid[0:self.heatmap.shape[0], 0:self.heatmap.shape[1]]
+                     mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius ** 2
+
+                     self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += \
+                         (2 * mask[int(box[1]):int(box[3]), int(box[0]):int(box[2])])
+
+                 else:
+                     self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += 2
 
              # Store tracking hist
              track_line = self.track_history[track_id]
@@ -143,16 +193,39 @@ class Heatmap:
                  track_line.pop(0)
 
              # Count objects
-             if self.count_region.contains(Point(track_line[-1])):
-                 if track_id not in self.count_list:
-                     self.count_list.append(track_id)
-                     if box[0] < self.count_region.centroid.x:
-                         self.out_counts += 1
-                     else:
-                         self.in_counts += 1
+             if len(self.count_reg_pts) == 4:
+                 if self.counting_region.contains(Point(track_line[-1])):
+                     if track_id not in self.counting_list:
+                         self.counting_list.append(track_id)
+                         if box[0] < self.counting_region.centroid.x:
+                             self.out_counts += 1
+                         else:
+                             self.in_counts += 1
+
+             elif len(self.count_reg_pts) == 2:
+                 distance = Point(track_line[-1]).distance(self.counting_region)
+                 if distance < self.line_dist_thresh:
+                     if track_id not in self.counting_list:
+                         self.counting_list.append(track_id)
+                         if box[0] < self.counting_region.centroid.x:
+                             self.out_counts += 1
+                         else:
+                             self.in_counts += 1
          else:
              for box, cls in zip(self.boxes, self.clss):
-                 self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += 1
+
+                 if self.shape == 'circle':
+                     center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2))
+                     radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2
+
+                     y, x = np.ogrid[0:self.heatmap.shape[0], 0:self.heatmap.shape[1]]
+                     mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius ** 2
+
+                     self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += \
+                         (2 * mask[int(box[1]):int(box[3]), int(box[0]):int(box[2])])
+
+                 else:
+                     self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += 2
 
          # Normalize, apply colormap to heatmap and combine with original image
          heatmap_normalized = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX)
@@ -161,7 +234,11 @@ class Heatmap:
          if self.count_reg_pts is not None:
              incount_label = 'InCount : ' + f'{self.in_counts}'
              outcount_label = 'OutCount : ' + f'{self.out_counts}'
-             self.annotator.count_labels(in_count=incount_label, out_count=outcount_label)
+             self.annotator.count_labels(in_count=incount_label,
+                                         out_count=outcount_label,
+                                         count_txt_size=self.count_txt_thickness,
+                                         txt_color=self.count_txt_color,
+                                         color=self.count_color)
 
          im0_with_heatmap = cv2.addWeighted(self.im0, 1 - self.heatmap_alpha, heatmap_colored, self.heatmap_alpha, 0)
ultralytics/solutions/object_counter.py CHANGED
@@ -9,8 +9,7 @@ from ultralytics.utils.plotting import Annotator, colors
 
  check_requirements('shapely>=2.0.0')
 
- from shapely.geometry import Polygon
- from shapely.geometry.point import Point
+ from shapely.geometry import LineString, Point, Polygon
 
 
  class ObjectCounter:
@@ -23,10 +22,12 @@ class ObjectCounter:
          self.is_drawing = False
          self.selected_point = None
 
-         # Region Information
-         self.reg_pts = None
+         # Region & Line Information
+         self.reg_pts = [(20, 400), (1260, 400)]
+         self.line_dist_thresh = 15
          self.counting_region = None
-         self.region_color = (255, 255, 255)
+         self.region_color = (255, 0, 255)
+         self.region_thickness = 5
 
          # Image and annotation Information
          self.im0 = None
@@ -40,11 +41,15 @@ class ObjectCounter:
          self.in_counts = 0
          self.out_counts = 0
          self.counting_list = []
+         self.count_txt_thickness = 0
+         self.count_txt_color = (0, 0, 0)
+         self.count_color = (255, 255, 255)
 
          # Tracks info
          self.track_history = defaultdict(list)
          self.track_thickness = 2
          self.draw_tracks = False
+         self.track_color = (0, 255, 0)
 
          # Check if environment support imshow
          self.env_check = check_imshow(warn=True)
@@ -52,11 +57,17 @@ class ObjectCounter:
      def set_args(self,
                   classes_names,
                   reg_pts,
-                  region_color=None,
+                  count_reg_color=(255, 0, 255),
                   line_thickness=2,
                   track_thickness=2,
                   view_img=False,
-                  draw_tracks=False):
+                  draw_tracks=False,
+                  count_txt_thickness=2,
+                  count_txt_color=(0, 0, 0),
+                  count_color=(255, 255, 255),
+                  track_color=(0, 255, 0),
+                  region_thickness=5,
+                  line_dist_thresh=15):
          """
          Configures the Counter's image, bounding box line thickness, and counting region points.
 
@@ -65,18 +76,43 @@ class ObjectCounter:
              view_img (bool): Flag to control whether to display the video stream.
              reg_pts (list): Initial list of points defining the counting region.
              classes_names (dict): Classes names
-             region_color (tuple): color for region line
              track_thickness (int): Track thickness
              draw_tracks (Bool): draw tracks
+             count_txt_thickness (int): Text thickness for object counting display
+             count_txt_color (RGB color): count text color value
+             count_color (RGB color): count text background color value
+             count_reg_color (RGB color): Color of object counting region
+             track_color (RGB color): color for tracks
+             region_thickness (int): Object counting Region thickness
+             line_dist_thresh (int): Euclidean Distance threshold for line counter
          """
          self.tf = line_thickness
          self.view_img = view_img
          self.track_thickness = track_thickness
          self.draw_tracks = draw_tracks
-         self.reg_pts = reg_pts
-         self.counting_region = Polygon(self.reg_pts)
+
+         # Region and line selection
+         if len(reg_pts) == 2:
+             print('Line Counter Initiated.')
+             self.reg_pts = reg_pts
+             self.counting_region = LineString(self.reg_pts)
+         elif len(reg_pts) == 4:
+             print('Region Counter Initiated.')
+             self.reg_pts = reg_pts
+             self.counting_region = Polygon(self.reg_pts)
+         else:
+             print('Invalid Region points provided, region_points can be 2 or 4')
+             print('Using Line Counter Now')
+             self.counting_region = LineString(self.reg_pts)
+
          self.names = classes_names
-         self.region_color = region_color if region_color else self.region_color
+         self.track_color = track_color
+         self.count_txt_thickness = count_txt_thickness
+         self.count_txt_color = count_txt_color
+         self.count_color = count_color
+         self.region_color = count_reg_color
+         self.region_thickness = region_thickness
+         self.line_dist_thresh = line_dist_thresh
 
      def mouse_event_for_region(self, event, x, y, flags, params):
          """
@@ -113,11 +149,14 @@ class ObjectCounter:
          clss = tracks[0].boxes.cls.cpu().tolist()
          track_ids = tracks[0].boxes.id.int().cpu().tolist()
 
+         # Annotator Init and region drawing
          self.annotator = Annotator(self.im0, self.tf, self.names)
-         self.annotator.draw_region(reg_pts=self.reg_pts, color=(0, 255, 0))
+         self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
 
+         # Extract tracks
          for box, track_id, cls in zip(boxes, track_ids, clss):
-             self.annotator.box_label(box, label=self.names[cls], color=colors(int(cls), True))  # Draw bounding box
+             self.annotator.box_label(box, label=str(track_id) + ':' + self.names[cls],
+                                      color=colors(int(cls), True))  # Draw bounding box
 
              # Draw Tracks
              track_line = self.track_history[track_id]
@@ -125,28 +164,45 @@ class ObjectCounter:
              if len(track_line) > 30:
                  track_line.pop(0)
 
+             # Draw track trails
              if self.draw_tracks:
                  self.annotator.draw_centroid_and_tracks(track_line,
-                                                         color=(0, 255, 0),
+                                                         color=self.track_color,
                                                          track_thickness=self.track_thickness)
 
              # Count objects
-             if self.counting_region.contains(Point(track_line[-1])):
-                 if track_id not in self.counting_list:
-                     self.counting_list.append(track_id)
-                     if box[0] < self.counting_region.centroid.x:
-                         self.out_counts += 1
-                     else:
-                         self.in_counts += 1
-
-         incount_label = 'InCount : ' + f'{self.in_counts}'
+             if len(self.reg_pts) == 4:
+                 if self.counting_region.contains(Point(track_line[-1])):
+                     if track_id not in self.counting_list:
+                         self.counting_list.append(track_id)
+                         if box[0] < self.counting_region.centroid.x:
+                             self.out_counts += 1
+                         else:
+                             self.in_counts += 1
+
+             elif len(self.reg_pts) == 2:
+                 distance = Point(track_line[-1]).distance(self.counting_region)
+                 if distance < self.line_dist_thresh:
+                     if track_id not in self.counting_list:
+                         self.counting_list.append(track_id)
+                         if box[0] < self.counting_region.centroid.x:
+                             self.out_counts += 1
+                         else:
+                             self.in_counts += 1
+
+         incount_label = 'In Count : ' + f'{self.in_counts}'
          outcount_label = 'OutCount : ' + f'{self.out_counts}'
-         self.annotator.count_labels(in_count=incount_label, out_count=outcount_label)
+         self.annotator.count_labels(in_count=incount_label,
+                                     out_count=outcount_label,
+                                     count_txt_size=self.count_txt_thickness,
+                                     txt_color=self.count_txt_color,
+                                     color=self.count_color)
 
          if self.env_check and self.view_img:
              cv2.namedWindow('Ultralytics YOLOv8 Object Counter')
-             cv2.setMouseCallback('Ultralytics YOLOv8 Object Counter', self.mouse_event_for_region,
-                                  {'region_points': self.reg_pts})
+             if len(self.reg_pts) == 4:  # only add mouse event If user drawn region
+                 cv2.setMouseCallback('Ultralytics YOLOv8 Object Counter', self.mouse_event_for_region,
+                                      {'region_points': self.reg_pts})
              cv2.imshow('Ultralytics YOLOv8 Object Counter', self.im0)
              # Break Window
              if cv2.waitKey(1) & 0xFF == ord('q'):
@@ -161,6 +217,7 @@ class ObjectCounter:
              tracks (list): List of tracks obtained from the object tracking process.
          """
          self.im0 = im0  # store image
+
          if tracks[0].boxes.id is None:
              return
          self.extract_and_process_tracks(tracks)
ultralytics/utils/plotting.py CHANGED
@@ -260,19 +260,41 @@ class Annotator:
 
      # Object Counting Annotator
      def draw_region(self, reg_pts=None, color=(0, 255, 0), thickness=5):
-         # Draw region line
+         """
+         Draw region line
+         Args:
+             reg_pts (list): Region Points (for line 2 points, for region 4 points)
+             color (tuple): Region Color value
+             thickness (int): Region area thickness value
+         """
          cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness)
 
      def draw_centroid_and_tracks(self, track, color=(255, 0, 255), track_thickness=2):
-         # Draw region line
+         """
+         Draw centroid point and track trails
+         Args:
+             track (list): object tracking points for trails display
+             color (tuple): tracks line color
+             track_thickness (int): track line thickness value
+         """
          points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
          cv2.polylines(self.im, [points], isClosed=False, color=color, thickness=track_thickness)
          cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1)
 
-     def count_labels(self, in_count=0, out_count=0, color=(255, 255, 255), txt_color=(0, 0, 0)):
+     def count_labels(self, in_count=0, out_count=0, count_txt_size=2, color=(255, 255, 255), txt_color=(0, 0, 0)):
+         """
+         Plot counts for object counter
+         Args:
+             in_count (int): in count value
+             out_count (int): out count value
+             count_txt_size (int): text size for counts display
+             color (tuple): background color of counts display
+             txt_color (tuple): text color of counts display
+         """
+         self.tf = count_txt_size
          tl = self.tf or round(0.002 * (self.im.shape[0] + self.im.shape[1]) / 2) + 1
          tf = max(tl - 1, 1)
-         gap = int(24 * tl)  # Calculate the gap between in_count and out_count based on line_thickness
+         gap = int(24 * tl)  # gap between in_count and out_count based on line_thickness
 
          # Get text size for in_count and out_count
          t_size_in = cv2.getTextSize(str(in_count), 0, fontScale=tl / 2, thickness=tf)[0]
@@ -306,14 +328,13 @@ class Annotator:
                      thickness=self.tf,
                      lineType=cv2.LINE_AA)
 
-     # AI GYM Annotator
-     def estimate_pose_angle(self, a, b, c):
+     @staticmethod
+     def estimate_pose_angle(a, b, c):
          """Calculate the pose angle for object
          Args:
              a (float) : The value of pose point a
              b (float): The value of pose point b
              c (float): The value o pose point c
-
          Returns:
              angle (degree): Degree value of angle between three points
          """
@@ -325,7 +346,15 @@ class Annotator:
          return angle
 
      def draw_specific_points(self, keypoints, indices=[2, 5, 7], shape=(640, 640), radius=2):
-         """Draw specific keypoints for gym steps counting."""
+         """
+         Draw specific keypoints for gym steps counting.
+
+         Args:
+             keypoints (list): list of keypoints data to be plotted
+             indices (list): keypoints ids list to be plotted
+             shape (tuple): imgsz for model inference
+             radius (int): Keypoint radius value
+         """
          nkpts, ndim = keypoints.shape
          nkpts == 17 and ndim == 3
          for i, k in enumerate(keypoints):
@@ -340,8 +369,17 @@ class Annotator:
          return self.im
 
      def plot_angle_and_count_and_stage(self, angle_text, count_text, stage_text, center_kpt, line_thickness=2):
-         """Plot the pose angle, count value and step stage."""
-         angle_text, count_text, stage_text = f' {angle_text:.2f}', 'Steps : ' + f'{count_text}', f' {stage_text}'
+         """
+         Plot the pose angle, count value and step stage.
+
+         Args:
+             angle_text (str): angle value for workout monitoring
+             count_text (str): counts value for workout monitoring
+             stage_text (str): stage decision for workout monitoring
+             center_kpt (int): centroid pose index for workout monitoring
+             line_thickness (int): thickness for text display
+         """
+         angle_text, count_text, stage_text = (f' {angle_text:.2f}', 'Steps : ' + f'{count_text}', f' {stage_text}')
          font_scale = 0.6 + (line_thickness / 10.0)
 
          # Draw angle
@@ -378,18 +416,38 @@ class Annotator:
          cv2.putText(self.im, stage_text, stage_text_position, 0, font_scale, (0, 0, 0), line_thickness)
 
      def seg_bbox(self, mask, mask_color=(255, 0, 255), det_label=None, track_label=None):
-         """Function for drawing segmented object in bounding box shape."""
+         """
+         Function for drawing segmented object in bounding box shape.
+
+         Args:
+             mask (list): masks data list for instance segmentation area plotting
+             mask_color (tuple): mask foreground color
+             det_label (str): Detection label text
+             track_label (str): Tracking label text
+         """
          cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)
 
          label = f'Track ID: {track_label}' if track_label else det_label
          text_size, _ = cv2.getTextSize(label, 0, 0.7, 1)
+
          cv2.rectangle(self.im, (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
                        (int(mask[0][0]) + text_size[0] // 2 + 5, int(mask[0][1] + 5)), mask_color, -1)
+
          cv2.putText(self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1]) - 5), 0, 0.7, (255, 255, 255),
                      2)
 
      def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255), thickness=2, pins_radius=10):
-         """Function for pinpoint human-vision eye mapping and plotting."""
+         """
+         Function for pinpoint human-vision eye mapping and plotting.
+
+         Args:
+             box (list): Bounding box coordinates
+             center_point (tuple): center point for vision eye view
+             color (tuple): object centroid and line color value
+             pin_color (tuple): visioneye point color value
+             thickness (int): int value for line thickness
+             pins_radius (int): visioneye point radius value
+         """
          center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
          cv2.circle(self.im, center_point, pins_radius, pin_color, -1)
          cv2.circle(self.im, center_bbox, pins_radius, color, -1)
ultralytics-8.0.228.dist-info/METADATA → ultralytics-8.0.229.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ultralytics
- Version: 8.0.228
+ Version: 8.0.229
  Summary: Ultralytics YOLOv8 for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification.
  Home-page: https://github.com/ultralytics/ultralytics
  Author: Ultralytics
ultralytics-8.0.228.dist-info/RECORD → ultralytics-8.0.229.dist-info/RECORD RENAMED
@@ -1,8 +1,8 @@
- ultralytics/__init__.py,sha256=0ZRVYelXCSx1Ikbi9p1qb3Fpj3B7aKu1bXGw-UUX6co,463
+ ultralytics/__init__.py,sha256=EhwYUqe_mS7jZksFDr-yyVHPSPorRPpkiAZ_w0K_MzU,463
  ultralytics/assets/bus.jpg,sha256=wCAZxJecGR63Od3ZRERe9Aja1Weayrb9Ug751DS_vGM,137419
  ultralytics/assets/zidane.jpg,sha256=Ftc4aeMmen1O0A3o6GCDO9FlfBslLpTAw0gnetx7bts,50427
  ultralytics/cfg/__init__.py,sha256=GszkldmONF8PI1J1o4TqNzq1Btzk8R7Y3_susNMHvpA,19859
- ultralytics/cfg/default.yaml,sha256=arCmsoPF88mHt0KAuDl3CxN4CbcjXqAXR_sOK9m06YQ,7738
+ ultralytics/cfg/default.yaml,sha256=-ejbKAG_xK9bke-Yr5w-1wcwHeA5vZ1FyOgBi27aShQ,7822
  ultralytics/cfg/datasets/Argoverse.yaml,sha256=TJhOiAm1QOsQnDkg1eEGYlaylgkvKLzBUdQ5gzyi_pY,2856
  ultralytics/cfg/datasets/DOTAv2.yaml,sha256=SmSpmbz_wRT8HMmPqsHpjep_b-nvckTutoEwVpGaUZM,1149
  ultralytics/cfg/datasets/GlobalWheat2020.yaml,sha256=Wd9spwO4HV48REvgqDUX-kM5a8rxceFalxkcvWDnJZI,1981
@@ -56,8 +56,8 @@ ultralytics/data/loaders.py,sha256=yDI0Xtb6IxpkU-fxdlPiBOY1FYDPEPDahre0rcgy2T8,2
  ultralytics/data/utils.py,sha256=1vKuCYOA_haro4tjzVSgOgwvdyEXw4UKmfYyPtfwXis,29699
  ultralytics/engine/__init__.py,sha256=mHtJuK4hwF8cuV-VHDc7tp6u6D1gHz2Z7JI8grmQDTs,42
  ultralytics/engine/exporter.py,sha256=8bttk0XZMo8tstnVdGBVPhGnPnirOor28xDLHQP4cz4,51226
- ultralytics/engine/model.py,sha256=1cmagS8BskMzOay9uDlFIvS5m58GB2kxZoYnUfxIAbU,19236
- ultralytics/engine/predictor.py,sha256=tgwQ58bziem5rZXucVyK0LP5fzvAgJtICBaN7kLUM9s,17548
+ ultralytics/engine/model.py,sha256=L1irDV83yBT2aWu053ukhsFGh5hZlsqKyJb_9-d5D0I,20107
+ ultralytics/engine/predictor.py,sha256=C6ZmZu5q8-6stuk1vE1P-E9LoPaO_85xo645SjemchA,17777
  ultralytics/engine/results.py,sha256=2GND_qGa8W8qJTyaSSt3qoPBqAS2JA5CDAB4y6wwdh8,23417
  ultralytics/engine/trainer.py,sha256=G7WN1rKd-rasJ0cFdZu8cj795z_5o1mI_7eX6p54X3k,33886
  ultralytics/engine/tuner.py,sha256=_9MAsXQwDtmDznqb6_cgk1DIo8FTwLgM3OTEifCxRp0,11715
@@ -115,8 +115,8 @@ ultralytics/models/yolo/segment/predict.py,sha256=yUs60HFBn7PZ3mErtUAnT69ijPBzFd
  ultralytics/models/yolo/segment/train.py,sha256=o1q4ZTmZlSwUbFIFaT_T7LvYaKOLq_QXxB-z61YwHx8,2276
  ultralytics/models/yolo/segment/val.py,sha256=DT-z-XnxP77nTIu2VfmGlpUyeBnDmIszT4vpP7mkGNA,11956
  ultralytics/nn/__init__.py,sha256=7T_GW3YsPg1kA-74UklF2UcabcRyttRZYrCOXiNnJqU,555
- ultralytics/nn/autobackend.py,sha256=eFn23VKky5qEwXpAcK3VXJwC_kiXsFKVQcbO2C38v60,26957
- ultralytics/nn/tasks.py,sha256=wiT7k194SU8Ckb3hICaOwyQ5tdSyT5qbug0VlnW5kyA,36609
+ ultralytics/nn/autobackend.py,sha256=BRiDYbLrsIOF9DHoVB-IbLUZ1NOtzBdSy9xb420c2FQ,27022
+ ultralytics/nn/tasks.py,sha256=djDmgi5PFpUwuNhuxbUpCqZvtHfADvaBycGqY110pIo,37466
  ultralytics/nn/modules/__init__.py,sha256=vrndehuJuLdA3UMHgByPUSR8rz32naUN0LIZoPzF7YQ,1698
  ultralytics/nn/modules/block.py,sha256=_A24bZ1xSWvrvqk5RODeobBZL6ReI6ICk-vwilERTZs,14475
  ultralytics/nn/modules/conv.py,sha256=z_OQka9s5h0p3k1yWrq7SHg1BsA6PfN5lDSQubW2I_k,12774
@@ -125,8 +125,8 @@ ultralytics/nn/modules/transformer.py,sha256=R7K_3r4aTlvghiTTRzh69NmNzlO_1Siiifb
  ultralytics/nn/modules/utils.py,sha256=q-qfebnMD2iqZyTslZTHsZYW7hyrX62VRgUmHX683-U,3436
  ultralytics/solutions/__init__.py,sha256=mHtJuK4hwF8cuV-VHDc7tp6u6D1gHz2Z7JI8grmQDTs,42
  ultralytics/solutions/ai_gym.py,sha256=YnBeC8Vf3-ai4OQIebEXl5yDho6uRspY2XVL8Ipr-h8,6235
- ultralytics/solutions/heatmap.py,sha256=eK4qk4HMEvIYTW77HsTK4Z_vopvWZuFCmp84_v18TY0,6669
- ultralytics/solutions/object_counter.py,sha256=n3JpYcFVlqHGxtnFUQn0gLBdaOxqdJpghmgRUPa9Cq4,6576
+ ultralytics/solutions/heatmap.py,sha256=BKsFF3GbtWGHKCfIWXOcu54dZArRf8b1rIq6dAZLzeQ,10310
+ ultralytics/solutions/object_counter.py,sha256=-hEmw93gSz_cqZyHPp7w9nNLBXlPSLeP8pBj2cvNNi8,9338
  ultralytics/trackers/__init__.py,sha256=dR9unDaRBd6MgMnTKxqJZ0KsJ8BeFGg-LTYQvC7BnIY,227
  ultralytics/trackers/basetrack.py,sha256=Vbs76Zue_jYdJFudztTJaUnGgMMUwVqoa0BSOhyBh0o,3580
  ultralytics/trackers/bot_sort.py,sha256=orTkrMj2yHfEQVKaQVWbguTx98S2gvLnaOB0D2JN1Gc,8602
@@ -149,7 +149,7 @@ ultralytics/utils/loss.py,sha256=pYkQu-11idOM_6MDXrRS7PgJRlEH8qTxhYgOr4a_aq4,257
  ultralytics/utils/metrics.py,sha256=g_NgGDG5pFchoB0u6JxuTavLtrARphvfQEfY2KUksJc,47438
  ultralytics/utils/ops.py,sha256=h2nRGf6pAwO3muXx0SWi0p-ROrglQtGng0C_coDPhgQ,31297
  ultralytics/utils/patches.py,sha256=V3ARuy0sg-_yn6nzL7iOWSzR_RzFOuzsICy4P6qUegc,2233
- ultralytics/utils/plotting.py,sha256=mPy98up1uX1g1fNesQ4OG2XZx70aiEuGzgCk7b9hQa8,39543
+ ultralytics/utils/plotting.py,sha256=iy5r40PLPueacuEQn6N1fS-FtvW_MCqs1-vwbPxciQ8,41662
  ultralytics/utils/tal.py,sha256=WxW_J5QC8oYAXKDy_huJC3mijBtpWG7UR145IAXO5_I,13675
  ultralytics/utils/torch_utils.py,sha256=09M6zCz66_rR5NdbryDDiyT6-BUxKJ4l3OZCRHnfCkM,24553
  ultralytics/utils/triton.py,sha256=opbB1ndgwfmUJzyvUH9vvMe2SrDW6FqmFxKEeNDaALQ,3932
@@ -165,9 +165,9 @@ ultralytics/utils/callbacks/neptune.py,sha256=qIN0gJipB1f3Di7bw0Rb28jLYoCzJSWSqF
  ultralytics/utils/callbacks/raytune.py,sha256=PGZvW_haVq8Cqha3GgvL7iBMAaxfn8_3u_IIdYCNMPo,608
  ultralytics/utils/callbacks/tensorboard.py,sha256=XXnpkIJrI_A_68JLRvYvRMHzekn-US1uIcru7vRs_e0,2896
  ultralytics/utils/callbacks/wb.py,sha256=x_j4ZH4Klp0_Ld13f0UezFluUTS5Ovfgk9hcjwqeruU,6762
- ultralytics-8.0.228.dist-info/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
- ultralytics-8.0.228.dist-info/METADATA,sha256=qci4P37ud8A4_W5pNWqZMVrwnfuBQaiGYNK8u_238tI,32271
- ultralytics-8.0.228.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
- ultralytics-8.0.228.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
- ultralytics-8.0.228.dist-info/top_level.txt,sha256=aNSJehhoYKycM3X4Tj38Q-BrmWFFm3hFuEXfPIR89eI,784
- ultralytics-8.0.228.dist-info/RECORD,,
+ ultralytics-8.0.229.dist-info/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
+ ultralytics-8.0.229.dist-info/METADATA,sha256=Vol7oXlFIUdEVAb6xjU8dcFfrILh9PPSRMwyoGgarLs,32271
+ ultralytics-8.0.229.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+ ultralytics-8.0.229.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
+ ultralytics-8.0.229.dist-info/top_level.txt,sha256=aNSJehhoYKycM3X4Tj38Q-BrmWFFm3hFuEXfPIR89eI,784
+ ultralytics-8.0.229.dist-info/RECORD,,