ultralytics 8.0.227__py3-none-any.whl → 8.0.229__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of ultralytics might be problematic.

ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license
 
- __version__ = '8.0.227'
+ __version__ = '8.0.229'
 
  from ultralytics.models import RTDETR, SAM, YOLO
  from ultralytics.models.fastsam import FastSAM
ultralytics/cfg/__init__.py CHANGED
@@ -63,7 +63,7 @@ CLI_HELP_MSG = \
  """
 
  # Define keys for arg type checks
- CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'
+ CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear', 'time'
  CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr',
                       'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud',
                       'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou', 'fraction')  # fraction floats 0.0 - 1.0
ultralytics/cfg/default.yaml CHANGED
@@ -8,6 +8,7 @@ mode: train # (str) YOLO mode, i.e. train, val, predict, export, track, benchmark
  model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml
  data: # (str, optional) path to data file, i.e. coco128.yaml
  epochs: 100 # (int) number of epochs to train for
+ time: # (float, optional) number of hours to train for, overrides epochs if supplied
  patience: 50 # (int) epochs to wait for no observable improvement for early stopping of training
  batch: 16 # (int) number of images per batch (-1 for AutoBatch)
  imgsz: 640 # (int | list) input images size as int for train and val modes, or list[w,h] for predict and export modes
@@ -60,6 +61,7 @@ augment: False # (bool) apply image augmentation to prediction sources
  agnostic_nms: False # (bool) class-agnostic NMS
  classes: # (int | list[int], optional) filter results by class, i.e. classes=0, or classes=[0,2,3]
  retina_masks: False # (bool) use high-resolution segmentation masks
+ embed: # (list[int], optional) return feature vectors/embeddings from given layers
 
  # Visualize settings ---------------------------------------------------------------------------------------------------
  show: False # (bool) show predicted images and videos if environment allows
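
Both new keys are ordinary arguments that pass straight through the Python API. A minimal usage sketch (model name, dataset and layer index are illustrative):

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
model.train(data='coco128.yaml', time=0.25)     # stop after ~15 minutes, overriding epochs
vectors = model.predict('bus.jpg', embed=[21])  # pooled feature vectors from layer 21 instead of detections
```
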
ultralytics/data/build.py CHANGED
@@ -100,7 +100,7 @@ def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
      """Return an InfiniteDataLoader or DataLoader for training or validation set."""
      batch = min(batch, len(dataset))
      nd = torch.cuda.device_count()  # number of CUDA devices
-     nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers])  # number of workers
+     nw = min([os.cpu_count() // max(nd, 1), batch, workers])  # number of workers
      sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
      generator = torch.Generator()
      generator.manual_seed(6148914691236517205 + RANK)
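
A quick numeric check of the updated worker calculation; the helper name is hypothetical and simply mirrors the expression above:

```python
import os

def num_workers(batch, workers, num_cuda_devices):
    """Mirror of the updated build_dataloader() worker formula."""
    return min(os.cpu_count() // max(num_cuda_devices, 1), batch, workers)

# The old formula forced 0 workers whenever batch == 1; the new one allows 1
print(num_workers(batch=1, workers=8, num_cuda_devices=1))
```
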
ultralytics/engine/exporter.py CHANGED
@@ -459,11 +459,14 @@ class Exporter:
          f'{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from '
          'https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory '
          f'or in {ROOT}. See PNNX repo for full installation instructions.')
-     _, assets = get_github_assets(repo='pnnx/pnnx', retry=True)
-     system = 'macos' if MACOS else 'ubuntu' if LINUX else 'windows'  # operating system
-     asset = [x for x in assets if system in x][0] if assets else \
-         f'https://github.com/pnnx/pnnx/releases/download/20230816/pnnx-20230816-{system}.zip'  # fallback
-     asset = attempt_download_asset(asset, repo='pnnx/pnnx', release='latest')
+     system = ['macos'] if MACOS else ['windows'] if WINDOWS else ['ubuntu', 'linux']  # operating system
+     try:
+         _, assets = get_github_assets(repo='pnnx/pnnx', retry=True)
+         url = [x for x in assets if any(s in x for s in system)][0]
+     except Exception as e:
+         url = f'https://github.com/pnnx/pnnx/releases/download/20231127/pnnx-20231127-{system[0]}.zip'
+         LOGGER.warning(f'{prefix} WARNING ⚠️ PNNX GitHub assets not found: {e}, using default {url}')
+     asset = attempt_download_asset(url, repo='pnnx/pnnx', release='latest')
      if check_is_path_safe(Path.cwd(), asset):  # avoid path traversal security vulnerability
          unzip_dir = Path(asset).with_suffix('')
          (unzip_dir / name).rename(pnnx)  # move binary to ROOT
@@ -781,7 +784,8 @@ class Exporter:
      @try_export
      def export_tfjs(self, prefix=colorstr('TensorFlow.js:')):
          """YOLOv8 TensorFlow.js export."""
-         check_requirements('tensorflowjs')
+         # JAX bug requiring install constraints in https://github.com/google/jax/issues/18978
+         check_requirements(['jax<=0.4.21', 'jaxlib<=0.4.21', 'tensorflowjs'])
          import tensorflow as tf
          import tensorflowjs as tfjs  # noqa
 
@@ -795,8 +799,9 @@ class Exporter:
          outputs = ','.join(gd_outputs(gd))
          LOGGER.info(f'\n{prefix} output node names: {outputs}')
 
+         quantization = '--quantize_float16' if self.args.half else '--quantize_uint8' if self.args.int8 else ''
          with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_:  # exporter can not handle spaces in path
-             cmd = f'tensorflowjs_converter --input_format=tf_frozen_model --output_node_names={outputs} "{fpb_}" "{f_}"'
+             cmd = f'tensorflowjs_converter --input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" "{f_}"'
          LOGGER.info(f"{prefix} running '{cmd}'")
          subprocess.run(cmd, shell=True)

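The new quantization flag is driven by the existing half/int8 export arguments; a hedged usage sketch (local yolov8n.pt weights assumed):

```python
from ultralytics import YOLO

# half=True appends --quantize_float16 to the tensorflowjs_converter command;
# int8=True would append --quantize_uint8 instead
YOLO('yolov8n.pt').export(format='tfjs', half=True)
```
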
ultralytics/engine/model.py CHANGED
@@ -94,7 +94,7 @@ class Model(nn.Module):
          self._load(model, task)
 
      def __call__(self, source=None, stream=False, **kwargs):
-         """Calls the 'predict' function with given arguments to perform object detection."""
+         """Calls the predict() method with given arguments to perform object detection."""
          return self.predict(source, stream, **kwargs)
 
      @staticmethod
@@ -201,6 +201,24 @@ class Model(nn.Module):
          self._check_is_pytorch_model()
          self.model.fuse()
 
+     def embed(self, source=None, stream=False, **kwargs):
+         """
+         Calls the predict() method and returns image embeddings.
+
+         Args:
+             source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
+                 Accepts all source types accepted by the YOLO model.
+             stream (bool): Whether to stream the predictions or not. Defaults to False.
+             **kwargs: Additional keyword arguments passed to the predictor.
+                 Check the 'configuration' section in the documentation for all available options.
+
+         Returns:
+             (List[torch.Tensor]): A list of image embeddings.
+         """
+         if not kwargs.get('embed'):
+             kwargs['embed'] = [len(self.model.model) - 2]  # embed second-to-last layer if no indices passed
+         return self.predict(source, stream, **kwargs)
+
      def predict(self, source=None, stream=False, predictor=None, **kwargs):
          """
          Perform prediction using the YOLO model.
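
A minimal sketch of the new embed() entry point (image paths are illustrative):

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
# Defaults to the second-to-last model layer when no layer indices are passed
vectors = model.embed(['bus.jpg', 'zidane.jpg'])
print(len(vectors), vectors[0].shape)  # one pooled 1-D feature tensor per image
```
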
ultralytics/engine/predictor.py CHANGED
@@ -134,7 +134,7 @@ class BasePredictor:
          """Runs inference on a given image using the specified model and arguments."""
          visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
                                     mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
-         return self.model(im, augment=self.args.augment, visualize=visualize)
+         return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
 
      def pre_transform(self, im):
          """
@@ -263,6 +263,9 @@ class BasePredictor:
              # Inference
              with profilers[1]:
                  preds = self.inference(im, *args, **kwargs)
+                 if self.args.embed:
+                     yield from [preds] if isinstance(preds, torch.Tensor) else preds  # yield embedding tensors
+                     continue
 
              # Postprocess
              with profilers[2]:
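
With args.embed set, the loop above yields raw tensors and skips postprocessing entirely, so streaming inference produces embeddings instead of Results objects. A hedged sketch (source path and layer index are illustrative):

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
for vector in model.predict('video.mp4', stream=True, embed=[21]):
    print(vector.shape)  # one embedding tensor per image/frame
```
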
ultralytics/engine/trainer.py CHANGED
@@ -189,6 +189,14 @@ class BaseTrainer:
          else:
              self._do_train(world_size)
 
+     def _setup_scheduler(self):
+         """Initialize training learning rate scheduler."""
+         if self.args.cos_lr:
+             self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
+         else:
+             self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf  # linear
+         self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
+
      def _setup_ddp(self, world_size):
          """Initializes and sets the DistributedDataParallel parameters for training."""
          torch.cuda.set_device(RANK)
@@ -269,11 +277,7 @@ class BaseTrainer:
                                                decay=weight_decay,
                                                iterations=iterations)
          # Scheduler
-         if self.args.cos_lr:
-             self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
-         else:
-             self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
-         self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
+         self._setup_scheduler()
          self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
          self.resume_training(ckpt)
          self.scheduler.last_epoch = self.start_epoch - 1  # do not move
@@ -285,17 +289,18 @@ class BaseTrainer:
              self._setup_ddp(world_size)
          self._setup_train(world_size)
 
-         self.epoch_time = None
-         self.epoch_time_start = time.time()
-         self.train_time_start = time.time()
          nb = len(self.train_loader)  # number of batches
          nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
          last_opt_step = -1
+         self.epoch_time = None
+         self.epoch_time_start = time.time()
+         self.train_time_start = time.time()
          self.run_callbacks('on_train_start')
          LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
                      f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
                      f"Logging results to {colorstr('bold', self.save_dir)}\n"
-                     f'Starting training for {self.epochs} epochs...')
+                     f'Starting training for '
+                     f'{self.args.time} hours...' if self.args.time else f'{self.epochs} epochs...')
          if self.args.close_mosaic:
              base_idx = (self.epochs - self.args.close_mosaic) * nb
              self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
@@ -323,7 +328,7 @@ class BaseTrainer:
                  ni = i + nb * epoch
                  if ni <= nw:
                      xi = [0, nw]  # x interp
-                     self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())
+                     self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
                      for j, x in enumerate(self.optimizer.param_groups):
                          # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                          x['lr'] = np.interp(
@@ -348,6 +353,16 @@ class BaseTrainer:
                      self.optimizer_step()
                      last_opt_step = ni
 
+                     # Timed stopping
+                     if self.args.time:
+                         self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
+                         if RANK != -1:  # if DDP training
+                             broadcast_list = [self.stop if RANK == 0 else None]
+                             dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
+                             self.stop = broadcast_list[0]
+                         if self.stop:  # training time exceeded
+                             break
+
                  # Log
                  mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                  loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
@@ -363,31 +378,37 @@ class BaseTrainer:
                  self.run_callbacks('on_train_batch_end')
 
              self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
-
-             with warnings.catch_warnings():
-                 warnings.simplefilter('ignore')  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
-                 self.scheduler.step()
              self.run_callbacks('on_train_epoch_end')
-
              if RANK in (-1, 0):
-
-                 # Validation
+                 final_epoch = epoch + 1 == self.epochs
                  self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
-                 final_epoch = (epoch + 1 == self.epochs) or self.stopper.possible_stop
 
-                 if self.args.val or final_epoch:
+                 # Validation
+                 if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
                      self.metrics, self.fitness = self.validate()
                  self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
-                 self.stop = self.stopper(epoch + 1, self.fitness)
+                 self.stop |= self.stopper(epoch + 1, self.fitness)
+                 if self.args.time:
+                     self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)
 
                  # Save model
-                 if self.args.save or (epoch + 1 == self.epochs):
+                 if self.args.save or final_epoch:
                      self.save_model()
                      self.run_callbacks('on_model_save')
 
-             tnow = time.time()
-             self.epoch_time = tnow - self.epoch_time_start
-             self.epoch_time_start = tnow
+             # Scheduler
+             t = time.time()
+             self.epoch_time = t - self.epoch_time_start
+             self.epoch_time_start = t
+             with warnings.catch_warnings():
+                 warnings.simplefilter('ignore')  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
+                 if self.args.time:
+                     mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
+                     self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
+                     self._setup_scheduler()
+                     self.scheduler.last_epoch = self.epoch  # do not move
+                 self.stop |= epoch >= self.epochs  # stop if exceeded epochs
+                 self.scheduler.step()
              self.run_callbacks('on_fit_epoch_end')
              torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
 
@@ -395,8 +416,7 @@ class BaseTrainer:
              if RANK != -1:  # if DDP training
                  broadcast_list = [self.stop if RANK == 0 else None]
                  dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
-                 if RANK != 0:
-                     self.stop = broadcast_list[0]
+                 self.stop = broadcast_list[0]
              if self.stop:
                  break  # must break all DDP ranks

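Taken together, the time argument turns the fixed epoch loop into a deadline. A hedged usage sketch (dataset name is illustrative):

```python
from ultralytics import YOLO

# Train for ~1 hour. After every epoch the trainer re-estimates the reachable epoch
# count as ceil(time * 3600 / mean_epoch_time) and rebuilds the LR schedule via
# _setup_scheduler(); the max(1 - x / self.epochs, 0) clamp keeps the linear LR
# factor non-negative if the estimate shrinks below the current epoch index.
YOLO('yolov8n.pt').train(data='coco128.yaml', epochs=500, time=1.0)
```
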
ultralytics/nn/autobackend.py CHANGED
@@ -333,7 +333,7 @@ class AutoBackend(nn.Module):
 
          self.__dict__.update(locals())  # assign all variables to self
 
-     def forward(self, im, augment=False, visualize=False):
+     def forward(self, im, augment=False, visualize=False, embed=None):
          """
          Runs inference on the YOLOv8 MultiBackend model.
 
@@ -341,6 +341,7 @@ class AutoBackend(nn.Module):
              im (torch.Tensor): The image tensor to perform inference on.
              augment (bool): whether to perform data augmentation during inference, defaults to False
              visualize (bool): whether to visualize the output predictions, defaults to False
+             embed (list, optional): A list of feature vectors/embeddings to return.
 
          Returns:
              (tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True)
@@ -352,7 +353,7 @@ class AutoBackend(nn.Module):
              im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)
 
          if self.pt or self.nn_module:  # PyTorch
-             y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
+             y = self.model(im, augment=augment, visualize=visualize, embed=embed)
          elif self.jit:  # TorchScript
              y = self.model(im)
          elif self.dnn:  # ONNX OpenCV DNN
ultralytics/nn/tasks.py CHANGED
@@ -41,7 +41,7 @@ class BaseModel(nn.Module):
              return self.loss(x, *args, **kwargs)
          return self.predict(x, *args, **kwargs)
 
-     def predict(self, x, profile=False, visualize=False, augment=False):
+     def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
          """
          Perform a forward pass through the network.
 
@@ -50,15 +50,16 @@ class BaseModel(nn.Module):
              profile (bool): Print the computation time of each layer if True, defaults to False.
              visualize (bool): Save the feature maps of the model if True, defaults to False.
              augment (bool): Augment image during prediction, defaults to False.
+             embed (list, optional): A list of feature vectors/embeddings to return.
 
          Returns:
              (torch.Tensor): The last output of the model.
          """
          if augment:
              return self._predict_augment(x)
-         return self._predict_once(x, profile, visualize)
+         return self._predict_once(x, profile, visualize, embed)
 
-     def _predict_once(self, x, profile=False, visualize=False):
+     def _predict_once(self, x, profile=False, visualize=False, embed=None):
          """
          Perform a forward pass through the network.
 
@@ -66,11 +67,12 @@ class BaseModel(nn.Module):
              x (torch.Tensor): The input tensor to the model.
              profile (bool): Print the computation time of each layer if True, defaults to False.
              visualize (bool): Save the feature maps of the model if True, defaults to False.
+             embed (list, optional): A list of feature vectors/embeddings to return.
 
          Returns:
              (torch.Tensor): The last output of the model.
          """
-         y, dt = [], []  # outputs
+         y, dt, embeddings = [], [], []  # outputs
          for m in self.model:
              if m.f != -1:  # if not from previous layer
                  x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
@@ -80,6 +82,10 @@ class BaseModel(nn.Module):
              y.append(x if m.i in self.save else None)  # save output
              if visualize:
                  feature_visualization(x, m.type, m.i, save_dir=visualize)
+             if embed and m.i in embed:
+                 embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
+                 if m.i == max(embed):
+                     return torch.unbind(torch.cat(embeddings, 1), dim=0)
          return x
 
      def _predict_augment(self, x):
@@ -454,7 +460,7 @@ class RTDETRDetectionModel(DetectionModel):
          return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']],
                                                     device=img.device)
 
-     def predict(self, x, profile=False, visualize=False, batch=None, augment=False):
+     def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
          """
          Perform a forward pass through the model.
 
@@ -464,11 +470,12 @@ class RTDETRDetectionModel(DetectionModel):
              visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
              batch (dict, optional): Ground truth data for evaluation. Defaults to None.
              augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
+             embed (list, optional): A list of feature vectors/embeddings to return.
 
          Returns:
              (torch.Tensor): Model's output tensor.
          """
-         y, dt = [], []  # outputs
+         y, dt, embeddings = [], [], []  # outputs
          for m in self.model[:-1]:  # except the head part
              if m.f != -1:  # if not from previous layer
                  x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
@@ -478,6 +485,10 @@ class RTDETRDetectionModel(DetectionModel):
              y.append(x if m.i in self.save else None)  # save output
              if visualize:
                  feature_visualization(x, m.type, m.i, save_dir=visualize)
+             if embed and m.i in embed:
+                 embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
+                 if m.i == max(embed):
+                     return torch.unbind(torch.cat(embeddings, 1), dim=0)
          head = self.model[-1]
          x = head([y[j] for j in head.f], batch)  # head inference
          return x
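
A self-contained sketch of the pooling used by both embed branches above (tensor sizes are illustrative):

```python
import torch
import torch.nn as nn

feat = torch.randn(2, 256, 20, 20)  # (batch, channels, h, w) feature map from layer m.i
vec = nn.functional.adaptive_avg_pool2d(feat, (1, 1)).squeeze(-1).squeeze(-1)
print(vec.shape)  # torch.Size([2, 256]) - one 256-d embedding per image
```
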
ultralytics/solutions/ai_gym.py CHANGED
@@ -62,7 +62,7 @@ class AIGym:
 
      def start_counting(self, im0, results, frame_count):
          """
-         function used to count the gym steps
+         Function used to count the gym steps
          Args:
              im0 (ndarray): Current frame from the video stream.
              results: Pose estimation data
ultralytics/solutions/heatmap.py CHANGED
@@ -10,8 +10,7 @@ from ultralytics.utils.plotting import Annotator
 
  check_requirements('shapely>=2.0.0')
 
- from shapely.geometry import Polygon
- from shapely.geometry.point import Point
+ from shapely.geometry import LineString, Point, Polygon
 
 
  class Heatmap:
@@ -23,6 +22,7 @@ class Heatmap:
          # Visual information
          self.annotator = None
          self.view_img = False
+         self.shape = 'circle'
 
          # Image information
          self.imw = None
@@ -38,17 +38,25 @@ class Heatmap:
          self.boxes = None
          self.track_ids = None
          self.clss = None
-         self.track_history = None
+         self.track_history = defaultdict(list)
 
-         # Counting info
+         # Region & Line Information
          self.count_reg_pts = None
-         self.count_region = None
+         self.counting_region = None
+         self.line_dist_thresh = 15
+         self.region_thickness = 5
+         self.region_color = (255, 0, 255)
+
+         # Object Counting Information
          self.in_counts = 0
          self.out_counts = 0
-         self.count_list = []
+         self.counting_list = []
          self.count_txt_thickness = 0
-         self.count_reg_color = (0, 255, 0)
-         self.region_thickness = 5
+         self.count_txt_color = (0, 0, 0)
+         self.count_color = (255, 255, 255)
+
+         # Decay factor
+         self.decay_factor = 0.99
 
          # Check if environment support imshow
          self.env_check = check_imshow(warn=True)
@@ -61,8 +69,13 @@ class Heatmap:
                   view_img=False,
                   count_reg_pts=None,
                   count_txt_thickness=2,
+                  count_txt_color=(0, 0, 0),
+                  count_color=(255, 255, 255),
                   count_reg_color=(255, 0, 255),
-                  region_thickness=5):
+                  region_thickness=5,
+                  line_dist_thresh=15,
+                  decay_factor=0.99,
+                  shape='circle'):
          """
          Configures the heatmap colormap, width, height and display parameters.
 
@@ -74,25 +87,55 @@ class Heatmap:
              view_img (bool): Flag indicating frame display
              count_reg_pts (list): Object counting region points
              count_txt_thickness (int): Text thickness for object counting display
+             count_txt_color (RGB color): count text color value
+             count_color (RGB color): count text background color value
              count_reg_color (RGB color): Color of object counting region
              region_thickness (int): Object counting Region thickness
+             line_dist_thresh (int): Euclidean Distance threshold for line counter
+             decay_factor (float): value for removing heatmap area after object passed
+             shape (str): Heatmap shape, rect or circle shape supported
          """
          self.imw = imw
          self.imh = imh
-         self.colormap = colormap
          self.heatmap_alpha = heatmap_alpha
          self.view_img = view_img
+         self.colormap = colormap
 
-         self.heatmap = np.zeros((int(self.imw), int(self.imh)), dtype=np.float32)  # Heatmap new frame
-
+         # Region and line selection
          if count_reg_pts is not None:
-             self.track_history = defaultdict(list)
-             self.count_reg_pts = count_reg_pts
-             self.count_region = Polygon(self.count_reg_pts)
 
-         self.count_txt_thickness = count_txt_thickness  # Counting text thickness
-         self.count_reg_color = count_reg_color
+             if len(count_reg_pts) == 2:
+                 print('Line Counter Initiated.')
+                 self.count_reg_pts = count_reg_pts
+                 self.counting_region = LineString(count_reg_pts)
+
+             elif len(count_reg_pts) == 4:
+                 print('Region Counter Initiated.')
+                 self.count_reg_pts = count_reg_pts
+                 self.counting_region = Polygon(self.count_reg_pts)
+
+             else:
+                 print('Region or line points Invalid, 2 or 4 points supported')
+                 print('Using Line Counter Now')
+                 self.counting_region = Polygon([(20, 400), (1260, 400)])  # dummy points
+
+         # Heatmap new frame
+         self.heatmap = np.zeros((int(self.imw), int(self.imh)), dtype=np.float32)
+
+         self.count_txt_thickness = count_txt_thickness
+         self.count_txt_color = count_txt_color
+         self.count_color = count_color
+         self.region_color = count_reg_color
          self.region_thickness = region_thickness
+         self.decay_factor = decay_factor
+         self.line_dist_thresh = line_dist_thresh
+         self.shape = shape
+
+         # shape of heatmap, if not selected
+         if self.shape not in ['circle', 'rect']:
+             print("Unknown shape value provided, 'circle' & 'rect' supported")
+             print('Using Circular shape now')
+             self.shape = 'circle'
 
      def extract_results(self, tracks):
          """
@@ -117,17 +160,31 @@ class Heatmap:
          if tracks[0].boxes.id is None:
              return self.im0
 
+         self.heatmap *= self.decay_factor  # decay factor
          self.extract_results(tracks)
          self.annotator = Annotator(self.im0, self.count_txt_thickness, None)
 
          if self.count_reg_pts is not None:
+
              # Draw counting region
              self.annotator.draw_region(reg_pts=self.count_reg_pts,
-                                        color=self.count_reg_color,
+                                        color=self.region_color,
                                         thickness=self.region_thickness)
 
              for box, cls, track_id in zip(self.boxes, self.clss, self.track_ids):
-                 self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += 1
+
+                 if self.shape == 'circle':
+                     center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2))
+                     radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2
+
+                     y, x = np.ogrid[0:self.heatmap.shape[0], 0:self.heatmap.shape[1]]
+                     mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius ** 2
+
+                     self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += \
+                         (2 * mask[int(box[1]):int(box[3]), int(box[0]):int(box[2])])
+
+                 else:
+                     self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += 2
 
                  # Store tracking hist
                  track_line = self.track_history[track_id]
@@ -136,16 +193,39 @@ class Heatmap:
                      track_line.pop(0)
 
                  # Count objects
-                 if self.count_region.contains(Point(track_line[-1])):
-                     if track_id not in self.count_list:
-                         self.count_list.append(track_id)
-                         if box[0] < self.count_region.centroid.x:
-                             self.out_counts += 1
-                         else:
-                             self.in_counts += 1
+                 if len(self.count_reg_pts) == 4:
+                     if self.counting_region.contains(Point(track_line[-1])):
+                         if track_id not in self.counting_list:
+                             self.counting_list.append(track_id)
+                             if box[0] < self.counting_region.centroid.x:
+                                 self.out_counts += 1
+                             else:
+                                 self.in_counts += 1
+
+                 elif len(self.count_reg_pts) == 2:
+                     distance = Point(track_line[-1]).distance(self.counting_region)
+                     if distance < self.line_dist_thresh:
+                         if track_id not in self.counting_list:
+                             self.counting_list.append(track_id)
+                             if box[0] < self.counting_region.centroid.x:
+                                 self.out_counts += 1
+                             else:
+                                 self.in_counts += 1
          else:
              for box, cls in zip(self.boxes, self.clss):
-                 self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += 1
+
+                 if self.shape == 'circle':
+                     center = (int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2))
+                     radius = min(int(box[2]) - int(box[0]), int(box[3]) - int(box[1])) // 2
+
+                     y, x = np.ogrid[0:self.heatmap.shape[0], 0:self.heatmap.shape[1]]
+                     mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius ** 2
+
+                     self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += \
+                         (2 * mask[int(box[1]):int(box[3]), int(box[0]):int(box[2])])
+
+                 else:
+                     self.heatmap[int(box[1]):int(box[3]), int(box[0]):int(box[2])] += 2
 
          # Normalize, apply colormap to heatmap and combine with original image
          heatmap_normalized = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX)
@@ -154,7 +234,11 @@
          if self.count_reg_pts is not None:
              incount_label = 'InCount : ' + f'{self.in_counts}'
              outcount_label = 'OutCount : ' + f'{self.out_counts}'
-             self.annotator.count_labels(in_count=incount_label, out_count=outcount_label)
+             self.annotator.count_labels(in_count=incount_label,
+                                         out_count=outcount_label,
+                                         count_txt_size=self.count_txt_thickness,
+                                         txt_color=self.count_txt_color,
+                                         color=self.count_color)
 
          im0_with_heatmap = cv2.addWeighted(self.im0, 1 - self.heatmap_alpha, heatmap_colored, self.heatmap_alpha, 0)

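A standalone sketch of the new circular accumulation (frame size and box coordinates are illustrative):

```python
import numpy as np

heatmap = np.zeros((480, 640), dtype=np.float32)
box = (300, 200, 380, 260)  # x1, y1, x2, y2 of a tracked object

center = ((box[0] + box[2]) // 2, (box[1] + box[3]) // 2)
radius = min(box[2] - box[0], box[3] - box[1]) // 2

# Boolean disk over the full frame, then add heat only inside the box crop
y, x = np.ogrid[0:heatmap.shape[0], 0:heatmap.shape[1]]
mask = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= radius ** 2
heatmap[box[1]:box[3], box[0]:box[2]] += 2 * mask[box[1]:box[3], box[0]:box[2]]
```
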
ultralytics/solutions/object_counter.py CHANGED
@@ -9,8 +9,7 @@ from ultralytics.utils.plotting import Annotator, colors
 
  check_requirements('shapely>=2.0.0')
 
- from shapely.geometry import Polygon
- from shapely.geometry.point import Point
+ from shapely.geometry import LineString, Point, Polygon
 
 
  class ObjectCounter:
@@ -23,10 +22,12 @@ class ObjectCounter:
          self.is_drawing = False
          self.selected_point = None
 
-         # Region Information
-         self.reg_pts = None
+         # Region & Line Information
+         self.reg_pts = [(20, 400), (1260, 400)]
+         self.line_dist_thresh = 15
          self.counting_region = None
-         self.region_color = (255, 255, 255)
+         self.region_color = (255, 0, 255)
+         self.region_thickness = 5
 
          # Image and annotation Information
          self.im0 = None
@@ -40,11 +41,15 @@ class ObjectCounter:
          self.in_counts = 0
          self.out_counts = 0
          self.counting_list = []
+         self.count_txt_thickness = 0
+         self.count_txt_color = (0, 0, 0)
+         self.count_color = (255, 255, 255)
 
          # Tracks info
          self.track_history = defaultdict(list)
          self.track_thickness = 2
          self.draw_tracks = False
+         self.track_color = (0, 255, 0)
 
          # Check if environment support imshow
          self.env_check = check_imshow(warn=True)
@@ -52,11 +57,17 @@ class ObjectCounter:
      def set_args(self,
                   classes_names,
                   reg_pts,
-                  region_color=None,
+                  count_reg_color=(255, 0, 255),
                   line_thickness=2,
                   track_thickness=2,
                   view_img=False,
-                  draw_tracks=False):
+                  draw_tracks=False,
+                  count_txt_thickness=2,
+                  count_txt_color=(0, 0, 0),
+                  count_color=(255, 255, 255),
+                  track_color=(0, 255, 0),
+                  region_thickness=5,
+                  line_dist_thresh=15):
          """
          Configures the Counter's image, bounding box line thickness, and counting region points.
 
@@ -65,18 +76,43 @@ class ObjectCounter:
              view_img (bool): Flag to control whether to display the video stream.
              reg_pts (list): Initial list of points defining the counting region.
              classes_names (dict): Classes names
-             region_color (tuple): color for region line
              track_thickness (int): Track thickness
              draw_tracks (Bool): draw tracks
+             count_txt_thickness (int): Text thickness for object counting display
+             count_txt_color (RGB color): count text color value
+             count_color (RGB color): count text background color value
+             count_reg_color (RGB color): Color of object counting region
+             track_color (RGB color): color for tracks
+             region_thickness (int): Object counting Region thickness
+             line_dist_thresh (int): Euclidean Distance threshold for line counter
          """
          self.tf = line_thickness
          self.view_img = view_img
          self.track_thickness = track_thickness
          self.draw_tracks = draw_tracks
-         self.reg_pts = reg_pts
-         self.counting_region = Polygon(self.reg_pts)
+
+         # Region and line selection
+         if len(reg_pts) == 2:
+             print('Line Counter Initiated.')
+             self.reg_pts = reg_pts
+             self.counting_region = LineString(self.reg_pts)
+         elif len(reg_pts) == 4:
+             print('Region Counter Initiated.')
+             self.reg_pts = reg_pts
+             self.counting_region = Polygon(self.reg_pts)
+         else:
+             print('Invalid Region points provided, region_points can be 2 or 4')
+             print('Using Line Counter Now')
+             self.counting_region = LineString(self.reg_pts)
+
          self.names = classes_names
-         self.region_color = region_color if region_color else self.region_color
+         self.track_color = track_color
+         self.count_txt_thickness = count_txt_thickness
+         self.count_txt_color = count_txt_color
+         self.count_color = count_color
+         self.region_color = count_reg_color
+         self.region_thickness = region_thickness
+         self.line_dist_thresh = line_dist_thresh
 
      def mouse_event_for_region(self, event, x, y, flags, params):
          """
@@ -113,11 +149,14 @@ class ObjectCounter:
          clss = tracks[0].boxes.cls.cpu().tolist()
          track_ids = tracks[0].boxes.id.int().cpu().tolist()
 
+         # Annotator Init and region drawing
          self.annotator = Annotator(self.im0, self.tf, self.names)
-         self.annotator.draw_region(reg_pts=self.reg_pts, color=(0, 255, 0))
+         self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
 
+         # Extract tracks
          for box, track_id, cls in zip(boxes, track_ids, clss):
-             self.annotator.box_label(box, label=self.names[cls], color=colors(int(cls), True))  # Draw bounding box
+             self.annotator.box_label(box, label=str(track_id) + ':' + self.names[cls],
+                                      color=colors(int(cls), True))  # Draw bounding box
 
              # Draw Tracks
              track_line = self.track_history[track_id]
@@ -125,27 +164,45 @@ class ObjectCounter:
              if len(track_line) > 30:
                  track_line.pop(0)
 
+             # Draw track trails
              if self.draw_tracks:
                  self.annotator.draw_centroid_and_tracks(track_line,
-                                                         color=(0, 255, 0),
+                                                         color=self.track_color,
                                                          track_thickness=self.track_thickness)
 
              # Count objects
-             if self.counting_region.contains(Point(track_line[-1])):
-                 if track_id not in self.counting_list:
-                     self.counting_list.append(track_id)
-                     if box[0] < self.counting_region.centroid.x:
-                         self.out_counts += 1
-                     else:
-                         self.in_counts += 1
+             if len(self.reg_pts) == 4:
+                 if self.counting_region.contains(Point(track_line[-1])):
+                     if track_id not in self.counting_list:
+                         self.counting_list.append(track_id)
+                         if box[0] < self.counting_region.centroid.x:
+                             self.out_counts += 1
+                         else:
+                             self.in_counts += 1
+
+             elif len(self.reg_pts) == 2:
+                 distance = Point(track_line[-1]).distance(self.counting_region)
+                 if distance < self.line_dist_thresh:
+                     if track_id not in self.counting_list:
+                         self.counting_list.append(track_id)
+                         if box[0] < self.counting_region.centroid.x:
+                             self.out_counts += 1
+                         else:
+                             self.in_counts += 1
+
+         incount_label = 'In Count : ' + f'{self.in_counts}'
+         outcount_label = 'OutCount : ' + f'{self.out_counts}'
+         self.annotator.count_labels(in_count=incount_label,
+                                     out_count=outcount_label,
+                                     count_txt_size=self.count_txt_thickness,
+                                     txt_color=self.count_txt_color,
+                                     color=self.count_color)
 
          if self.env_check and self.view_img:
-             incount_label = 'InCount : ' + f'{self.in_counts}'
-             outcount_label = 'OutCount : ' + f'{self.out_counts}'
-             self.annotator.count_labels(in_count=incount_label, out_count=outcount_label)
              cv2.namedWindow('Ultralytics YOLOv8 Object Counter')
-             cv2.setMouseCallback('Ultralytics YOLOv8 Object Counter', self.mouse_event_for_region,
-                                  {'region_points': self.reg_pts})
+             if len(self.reg_pts) == 4:  # only add mouse event if user-drawn region
+                 cv2.setMouseCallback('Ultralytics YOLOv8 Object Counter', self.mouse_event_for_region,
+                                      {'region_points': self.reg_pts})
              cv2.imshow('Ultralytics YOLOv8 Object Counter', self.im0)
              # Break Window
              if cv2.waitKey(1) & 0xFF == ord('q'):
@@ -160,6 +217,7 @@ class ObjectCounter:
              tracks (list): List of tracks obtained from the object tracking process.
          """
          self.im0 = im0  # store image
+
          if tracks[0].boxes.id is None:
              return
          self.extract_and_process_tracks(tracks)
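
A standalone sketch of the Shapely distance test behind the new 2-point line counter (coordinates are illustrative):

```python
from shapely.geometry import LineString, Point

counting_line = LineString([(20, 400), (1260, 400)])  # 2-point counting line
line_dist_thresh = 15

track_point = Point(640, 392)  # latest centroid of a tracked object
if track_point.distance(counting_line) < line_dist_thresh:
    print('crossing the line - counted once per track id')
```
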
ultralytics/utils/plotting.py CHANGED
@@ -260,19 +260,41 @@ class Annotator:
 
      # Object Counting Annotator
      def draw_region(self, reg_pts=None, color=(0, 255, 0), thickness=5):
-         # Draw region line
+         """
+         Draw region line
+         Args:
+             reg_pts (list): Region Points (for line 2 points, for region 4 points)
+             color (tuple): Region Color value
+             thickness (int): Region area thickness value
+         """
          cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness)
 
      def draw_centroid_and_tracks(self, track, color=(255, 0, 255), track_thickness=2):
-         # Draw region line
+         """
+         Draw centroid point and track trails
+         Args:
+             track (list): object tracking points for trails display
+             color (tuple): tracks line color
+             track_thickness (int): track line thickness value
+         """
          points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
          cv2.polylines(self.im, [points], isClosed=False, color=color, thickness=track_thickness)
          cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1)
 
-     def count_labels(self, in_count=0, out_count=0, color=(255, 255, 255), txt_color=(0, 0, 0)):
+     def count_labels(self, in_count=0, out_count=0, count_txt_size=2, color=(255, 255, 255), txt_color=(0, 0, 0)):
+         """
+         Plot counts for object counter
+         Args:
+             in_count (int): in count value
+             out_count (int): out count value
+             count_txt_size (int): text size for counts display
+             color (tuple): background color of counts display
+             txt_color (tuple): text color of counts display
+         """
+         self.tf = count_txt_size
          tl = self.tf or round(0.002 * (self.im.shape[0] + self.im.shape[1]) / 2) + 1
          tf = max(tl - 1, 1)
-         gap = int(24 * tl)  # Calculate the gap between in_count and out_count based on line_thickness
+         gap = int(24 * tl)  # gap between in_count and out_count based on line_thickness
 
          # Get text size for in_count and out_count
          t_size_in = cv2.getTextSize(str(in_count), 0, fontScale=tl / 2, thickness=tf)[0]
@@ -306,14 +328,13 @@ class Annotator:
                      thickness=self.tf,
                      lineType=cv2.LINE_AA)
 
-     # AI GYM Annotator
-     def estimate_pose_angle(self, a, b, c):
+     @staticmethod
+     def estimate_pose_angle(a, b, c):
          """Calculate the pose angle for object
          Args:
              a (float) : The value of pose point a
              b (float): The value of pose point b
              c (float): The value of pose point c
-
          Returns:
              angle (degree): Degree value of angle between three points
          """
@@ -325,7 +346,15 @@ class Annotator:
          return angle
 
      def draw_specific_points(self, keypoints, indices=[2, 5, 7], shape=(640, 640), radius=2):
-         """Draw specific keypoints for gym steps counting."""
+         """
+         Draw specific keypoints for gym steps counting.
+
+         Args:
+             keypoints (list): list of keypoints data to be plotted
+             indices (list): keypoints ids list to be plotted
+             shape (tuple): imgsz for model inference
+             radius (int): Keypoint radius value
+         """
          nkpts, ndim = keypoints.shape
          nkpts == 17 and ndim == 3
          for i, k in enumerate(keypoints):
@@ -340,25 +369,31 @@ class Annotator:
          return self.im
 
      def plot_angle_and_count_and_stage(self, angle_text, count_text, stage_text, center_kpt, line_thickness=2):
-         """Plot the pose angle, count value and step stage."""
-         angle_text, count_text, stage_text = f' {angle_text:.2f}', 'Steps : ' + f'{count_text}', f' {stage_text}'
+         """
+         Plot the pose angle, count value and step stage.
+
+         Args:
+             angle_text (str): angle value for workout monitoring
+             count_text (str): counts value for workout monitoring
+             stage_text (str): stage decision for workout monitoring
+             center_kpt (int): centroid pose index for workout monitoring
+             line_thickness (int): thickness for text display
+         """
+         angle_text, count_text, stage_text = (f' {angle_text:.2f}', 'Steps : ' + f'{count_text}', f' {stage_text}')
          font_scale = 0.6 + (line_thickness / 10.0)
 
          # Draw angle
-         (angle_text_width, angle_text_height), _ = cv2.getTextSize(angle_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale,
-                                                                    line_thickness)
+         (angle_text_width, angle_text_height), _ = cv2.getTextSize(angle_text, 0, font_scale, line_thickness)
          angle_text_position = (int(center_kpt[0]), int(center_kpt[1]))
          angle_background_position = (angle_text_position[0], angle_text_position[1] - angle_text_height - 5)
          angle_background_size = (angle_text_width + 2 * 5, angle_text_height + 2 * 5 + (line_thickness * 2))
          cv2.rectangle(self.im, angle_background_position, (angle_background_position[0] + angle_background_size[0],
                                                             angle_background_position[1] + angle_background_size[1]),
                        (255, 255, 255), -1)
-         cv2.putText(self.im, angle_text, angle_text_position, cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0),
-                     line_thickness)
+         cv2.putText(self.im, angle_text, angle_text_position, 0, font_scale, (0, 0, 0), line_thickness)
 
          # Draw Counts
-         (count_text_width, count_text_height), _ = cv2.getTextSize(count_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale,
-                                                                    line_thickness)
+         (count_text_width, count_text_height), _ = cv2.getTextSize(count_text, 0, font_scale, line_thickness)
          count_text_position = (angle_text_position[0], angle_text_position[1] + angle_text_height + 20)
          count_background_position = (angle_background_position[0],
                                       angle_background_position[1] + angle_background_size[1] + 5)
@@ -367,12 +402,10 @@ class Annotator:
          cv2.rectangle(self.im, count_background_position, (count_background_position[0] + count_background_size[0],
                                                             count_background_position[1] + count_background_size[1]),
                        (255, 255, 255), -1)
-         cv2.putText(self.im, count_text, count_text_position, cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0),
-                     line_thickness)
+         cv2.putText(self.im, count_text, count_text_position, 0, font_scale, (0, 0, 0), line_thickness)
 
          # Draw Stage
-         (stage_text_width, stage_text_height), _ = cv2.getTextSize(stage_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale,
-                                                                    line_thickness)
+         (stage_text_width, stage_text_height), _ = cv2.getTextSize(stage_text, 0, font_scale, line_thickness)
          stage_text_position = (int(center_kpt[0]), int(center_kpt[1]) + angle_text_height + count_text_height + 40)
          stage_background_position = (stage_text_position[0], stage_text_position[1] - stage_text_height - 5)
          stage_background_size = (stage_text_width + 10, stage_text_height + 10)
@@ -380,8 +413,45 @@ class Annotator:
          cv2.rectangle(self.im, stage_background_position, (stage_background_position[0] + stage_background_size[0],
                                                             stage_background_position[1] + stage_background_size[1]),
                        (255, 255, 255), -1)
-         cv2.putText(self.im, stage_text, stage_text_position, cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0),
-                     line_thickness)
+         cv2.putText(self.im, stage_text, stage_text_position, 0, font_scale, (0, 0, 0), line_thickness)
+
+     def seg_bbox(self, mask, mask_color=(255, 0, 255), det_label=None, track_label=None):
+         """
+         Function for drawing segmented object in bounding box shape.
+
+         Args:
+             mask (list): masks data list for instance segmentation area plotting
+             mask_color (tuple): mask foreground color
+             det_label (str): Detection label text
+             track_label (str): Tracking label text
+         """
+         cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2)
+
+         label = f'Track ID: {track_label}' if track_label else det_label
+         text_size, _ = cv2.getTextSize(label, 0, 0.7, 1)
+
+         cv2.rectangle(self.im, (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10),
+                       (int(mask[0][0]) + text_size[0] // 2 + 5, int(mask[0][1] + 5)), mask_color, -1)
+
+         cv2.putText(self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1]) - 5), 0, 0.7,
+                     (255, 255, 255), 2)
+
+     def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255), thickness=2, pins_radius=10):
+         """
+         Function for pinpoint human-vision eye mapping and plotting.
+
+         Args:
+             box (list): Bounding box coordinates
+             center_point (tuple): center point for vision eye view
+             color (tuple): object centroid and line color value
+             pin_color (tuple): visioneye point color value
+             thickness (int): int value for line thickness
+             pins_radius (int): visioneye point radius value
+         """
+         center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
+         cv2.circle(self.im, center_point, pins_radius, pin_color, -1)
+         cv2.circle(self.im, center_bbox, pins_radius, color, -1)
+         cv2.line(self.im, center_point, center_bbox, color, thickness)
 
 
  @TryExcept()  # known issue https://github.com/ultralytics/yolov5/issues/5395
  @TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
@@ -363,7 +363,7 @@ def de_parallel(model):
363
363
 
364
364
  def one_cycle(y1=0.0, y2=1.0, steps=100):
365
365
  """Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf."""
366
- return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
366
+ return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1
367
367
 
368
368
 
369
369
  def init_seeds(seed=0, deterministic=False):
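
A tiny numeric check of the clamped ramp (lrf=0.01 mirrors the trainer's cosine schedule; values rounded):

```python
import math

def one_cycle(y1=0.0, y2=1.0, steps=100):
    """Sinusoidal ramp from y1 to y2 over `steps`, clamped at >= 0."""
    return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1

lf = one_cycle(1, 0.01, steps=100)  # cosine decay factor 1 -> 0.01
print([round(lf(x), 3) for x in (0, 25, 50, 75, 100)])  # [1.0, 0.855, 0.505, 0.155, 0.01]
```
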
ultralytics-8.0.227.dist-info/METADATA → ultralytics-8.0.229.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ultralytics
- Version: 8.0.227
+ Version: 8.0.229
  Summary: Ultralytics YOLOv8 for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification.
  Home-page: https://github.com/ultralytics/ultralytics
  Author: Ultralytics
@@ -60,6 +60,8 @@ Provides-Extra: export
  Requires-Dist: coremltools >=7.0 ; extra == 'export'
  Requires-Dist: openvino-dev >=2023.0 ; extra == 'export'
  Requires-Dist: tensorflow <=2.13.1 ; extra == 'export'
+ Requires-Dist: jax <=0.4.21 ; extra == 'export'
+ Requires-Dist: jaxlib <=0.4.21 ; extra == 'export'
  Requires-Dist: tensorflowjs ; extra == 'export'
 
  <div align="center">
ultralytics-8.0.227.dist-info/RECORD → ultralytics-8.0.229.dist-info/RECORD CHANGED
@@ -1,8 +1,8 @@
- ultralytics/__init__.py,sha256=QzCgepGoTUoeYPBmIF3x8iTZ4L5fCkbplNSn2zxuH5s,463
+ ultralytics/__init__.py,sha256=EhwYUqe_mS7jZksFDr-yyVHPSPorRPpkiAZ_w0K_MzU,463
  ultralytics/assets/bus.jpg,sha256=wCAZxJecGR63Od3ZRERe9Aja1Weayrb9Ug751DS_vGM,137419
  ultralytics/assets/zidane.jpg,sha256=Ftc4aeMmen1O0A3o6GCDO9FlfBslLpTAw0gnetx7bts,50427
- ultralytics/cfg/__init__.py,sha256=kuLZLdP7SKcvBYEvtKZPtXXqCEhkQyEZysYGH_jMx6E,19851
- ultralytics/cfg/default.yaml,sha256=CdgfcU2VAFjJXuCUrghQcupOlgqZQ08vSuTFmNSLY8A,7652
+ ultralytics/cfg/__init__.py,sha256=GszkldmONF8PI1J1o4TqNzq1Btzk8R7Y3_susNMHvpA,19859
+ ultralytics/cfg/default.yaml,sha256=-ejbKAG_xK9bke-Yr5w-1wcwHeA5vZ1FyOgBi27aShQ,7822
  ultralytics/cfg/datasets/Argoverse.yaml,sha256=TJhOiAm1QOsQnDkg1eEGYlaylgkvKLzBUdQ5gzyi_pY,2856
  ultralytics/cfg/datasets/DOTAv2.yaml,sha256=SmSpmbz_wRT8HMmPqsHpjep_b-nvckTutoEwVpGaUZM,1149
  ultralytics/cfg/datasets/GlobalWheat2020.yaml,sha256=Wd9spwO4HV48REvgqDUX-kM5a8rxceFalxkcvWDnJZI,1981
@@ -49,17 +49,17 @@ ultralytics/data/__init__.py,sha256=TWN-3tE7pPBkGkvAFZoSexBkCw24Fp49swcKeIylHlE,
  ultralytics/data/annotator.py,sha256=8Ui_4H4dAU09BQ-gDwW4uqDVMxiKZzNIfF1s5ZvhGk0,2122
  ultralytics/data/augment.py,sha256=T5CXNeO8o5q9dbskyYFeFZtmNcN2nV2qhRqpniptT6M,48264
  ultralytics/data/base.py,sha256=ltqBt-UFnnPlK_2E4nVvYjIAUkR9PAgW1kNdGN101m4,13309
- ultralytics/data/build.py,sha256=5CwYQ5qXcvtyMGEAZatP0m8BHWiEv1fyIo51U-p6w9o,6691
+ ultralytics/data/build.py,sha256=2QClXZyI6CEV-DYqR-7soXs-gc_4sF9AIQU9UTpW3As,6671
  ultralytics/data/converter.py,sha256=tbV_LVvkr4gkLTuNM2v0dQPBuOBH2pknKlB3ivagzQU,12505
  ultralytics/data/dataset.py,sha256=IZyVml86cLF2t8Y8rToEN-H-OSdt5QQgXAFHo6YAJ_U,16019
  ultralytics/data/loaders.py,sha256=yDI0Xtb6IxpkU-fxdlPiBOY1FYDPEPDahre0rcgy2T8,22200
  ultralytics/data/utils.py,sha256=1vKuCYOA_haro4tjzVSgOgwvdyEXw4UKmfYyPtfwXis,29699
  ultralytics/engine/__init__.py,sha256=mHtJuK4hwF8cuV-VHDc7tp6u6D1gHz2Z7JI8grmQDTs,42
- ultralytics/engine/exporter.py,sha256=_l4IPGRvaTtVjgiuCtVKrJnJDzJlJCiwNXtdfo7Hulo,50787
- ultralytics/engine/model.py,sha256=1cmagS8BskMzOay9uDlFIvS5m58GB2kxZoYnUfxIAbU,19236
- ultralytics/engine/predictor.py,sha256=tgwQ58bziem5rZXucVyK0LP5fzvAgJtICBaN7kLUM9s,17548
+ ultralytics/engine/exporter.py,sha256=8bttk0XZMo8tstnVdGBVPhGnPnirOor28xDLHQP4cz4,51226
+ ultralytics/engine/model.py,sha256=L1irDV83yBT2aWu053ukhsFGh5hZlsqKyJb_9-d5D0I,20107
+ ultralytics/engine/predictor.py,sha256=C6ZmZu5q8-6stuk1vE1P-E9LoPaO_85xo645SjemchA,17777
  ultralytics/engine/results.py,sha256=2GND_qGa8W8qJTyaSSt3qoPBqAS2JA5CDAB4y6wwdh8,23417
- ultralytics/engine/trainer.py,sha256=9IxGS3K3QE3pGhEd0JWBijIpHj1MSvRI38KtvyfS2Ck,32553
+ ultralytics/engine/trainer.py,sha256=G7WN1rKd-rasJ0cFdZu8cj795z_5o1mI_7eX6p54X3k,33886
  ultralytics/engine/tuner.py,sha256=_9MAsXQwDtmDznqb6_cgk1DIo8FTwLgM3OTEifCxRp0,11715
  ultralytics/engine/validator.py,sha256=1-N1Fh563A4sD-sB1c3MiYX9PtJliZ-ta0c-sObUDfc,14453
  ultralytics/hub/__init__.py,sha256=iZzEg98gDEr2bfPZopHwnFIfDVDZ9a-yuAAkPKnn2hw,3685
@@ -115,8 +115,8 @@ ultralytics/models/yolo/segment/predict.py,sha256=yUs60HFBn7PZ3mErtUAnT69ijPBzFd
  ultralytics/models/yolo/segment/train.py,sha256=o1q4ZTmZlSwUbFIFaT_T7LvYaKOLq_QXxB-z61YwHx8,2276
  ultralytics/models/yolo/segment/val.py,sha256=DT-z-XnxP77nTIu2VfmGlpUyeBnDmIszT4vpP7mkGNA,11956
  ultralytics/nn/__init__.py,sha256=7T_GW3YsPg1kA-74UklF2UcabcRyttRZYrCOXiNnJqU,555
- ultralytics/nn/autobackend.py,sha256=eFn23VKky5qEwXpAcK3VXJwC_kiXsFKVQcbO2C38v60,26957
- ultralytics/nn/tasks.py,sha256=wiT7k194SU8Ckb3hICaOwyQ5tdSyT5qbug0VlnW5kyA,36609
+ ultralytics/nn/autobackend.py,sha256=BRiDYbLrsIOF9DHoVB-IbLUZ1NOtzBdSy9xb420c2FQ,27022
+ ultralytics/nn/tasks.py,sha256=djDmgi5PFpUwuNhuxbUpCqZvtHfADvaBycGqY110pIo,37466
  ultralytics/nn/modules/__init__.py,sha256=vrndehuJuLdA3UMHgByPUSR8rz32naUN0LIZoPzF7YQ,1698
  ultralytics/nn/modules/block.py,sha256=_A24bZ1xSWvrvqk5RODeobBZL6ReI6ICk-vwilERTZs,14475
  ultralytics/nn/modules/conv.py,sha256=z_OQka9s5h0p3k1yWrq7SHg1BsA6PfN5lDSQubW2I_k,12774
@@ -124,9 +124,9 @@ ultralytics/nn/modules/head.py,sha256=GVory97vQms41CRExgEhMd5dJZTwleJVL5dWyU6pU2
  ultralytics/nn/modules/transformer.py,sha256=R7K_3r4aTlvghiTTRzh69NmNzlO_1SiiifbevHGllEE,17895
  ultralytics/nn/modules/utils.py,sha256=q-qfebnMD2iqZyTslZTHsZYW7hyrX62VRgUmHX683-U,3436
  ultralytics/solutions/__init__.py,sha256=mHtJuK4hwF8cuV-VHDc7tp6u6D1gHz2Z7JI8grmQDTs,42
- ultralytics/solutions/ai_gym.py,sha256=AkD2stdBQXETbXftZVCTmwHgZ6X_Ok5nS4wFazfJuDA,6235
- ultralytics/solutions/heatmap.py,sha256=3wjg31Mgt6fmDt-rXpx53v4XGsgMBjdaWHi7ENbpBnc,6391
- ultralytics/solutions/object_counter.py,sha256=tRl2G94v7qbQkJSBbl9vR95jMLKTKIMMMJp2t40Xp9s,6587
+ ultralytics/solutions/ai_gym.py,sha256=YnBeC8Vf3-ai4OQIebEXl5yDho6uRspY2XVL8Ipr-h8,6235
+ ultralytics/solutions/heatmap.py,sha256=BKsFF3GbtWGHKCfIWXOcu54dZArRf8b1rIq6dAZLzeQ,10310
+ ultralytics/solutions/object_counter.py,sha256=-hEmw93gSz_cqZyHPp7w9nNLBXlPSLeP8pBj2cvNNi8,9338
  ultralytics/trackers/__init__.py,sha256=dR9unDaRBd6MgMnTKxqJZ0KsJ8BeFGg-LTYQvC7BnIY,227
  ultralytics/trackers/basetrack.py,sha256=Vbs76Zue_jYdJFudztTJaUnGgMMUwVqoa0BSOhyBh0o,3580
  ultralytics/trackers/bot_sort.py,sha256=orTkrMj2yHfEQVKaQVWbguTx98S2gvLnaOB0D2JN1Gc,8602
@@ -149,9 +149,9 @@ ultralytics/utils/loss.py,sha256=pYkQu-11idOM_6MDXrRS7PgJRlEH8qTxhYgOr4a_aq4,257
  ultralytics/utils/metrics.py,sha256=g_NgGDG5pFchoB0u6JxuTavLtrARphvfQEfY2KUksJc,47438
  ultralytics/utils/ops.py,sha256=h2nRGf6pAwO3muXx0SWi0p-ROrglQtGng0C_coDPhgQ,31297
  ultralytics/utils/patches.py,sha256=V3ARuy0sg-_yn6nzL7iOWSzR_RzFOuzsICy4P6qUegc,2233
- ultralytics/utils/plotting.py,sha256=kMiJApYWRSCQsd92Dgh-mTSAg1ia27AofAbNA5Ma4P8,38711
+ ultralytics/utils/plotting.py,sha256=iy5r40PLPueacuEQn6N1fS-FtvW_MCqs1-vwbPxciQ8,41662
  ultralytics/utils/tal.py,sha256=WxW_J5QC8oYAXKDy_huJC3mijBtpWG7UR145IAXO5_I,13675
- ultralytics/utils/torch_utils.py,sha256=I0xpBXzehK1ZwlQeIjWHr0EQXBewTN4twqMAMllq-_k,24547
+ ultralytics/utils/torch_utils.py,sha256=09M6zCz66_rR5NdbryDDiyT6-BUxKJ4l3OZCRHnfCkM,24553
  ultralytics/utils/triton.py,sha256=opbB1ndgwfmUJzyvUH9vvMe2SrDW6FqmFxKEeNDaALQ,3932
  ultralytics/utils/tuner.py,sha256=8QfeoYdVtPZHSkg7o06DTlwFKQS-f_5XemDa1vKkums,6227
  ultralytics/utils/callbacks/__init__.py,sha256=nhrnMPpPDb5fgqw42w8e7fC5TjEPC-jp04dpQtaQtkU,214
@@ -165,9 +165,9 @@ ultralytics/utils/callbacks/neptune.py,sha256=qIN0gJipB1f3Di7bw0Rb28jLYoCzJSWSqF
  ultralytics/utils/callbacks/raytune.py,sha256=PGZvW_haVq8Cqha3GgvL7iBMAaxfn8_3u_IIdYCNMPo,608
  ultralytics/utils/callbacks/tensorboard.py,sha256=XXnpkIJrI_A_68JLRvYvRMHzekn-US1uIcru7vRs_e0,2896
  ultralytics/utils/callbacks/wb.py,sha256=x_j4ZH4Klp0_Ld13f0UezFluUTS5Ovfgk9hcjwqeruU,6762
- ultralytics-8.0.227.dist-info/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
- ultralytics-8.0.227.dist-info/METADATA,sha256=rIyhvhn4oc_R1dkikm65MTVK1paFKoxmhQlEg1pNJc8,32172
- ultralytics-8.0.227.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
- ultralytics-8.0.227.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
- ultralytics-8.0.227.dist-info/top_level.txt,sha256=aNSJehhoYKycM3X4Tj38Q-BrmWFFm3hFuEXfPIR89eI,784
- ultralytics-8.0.227.dist-info/RECORD,,
+ ultralytics-8.0.229.dist-info/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
+ ultralytics-8.0.229.dist-info/METADATA,sha256=Vol7oXlFIUdEVAb6xjU8dcFfrILh9PPSRMwyoGgarLs,32271
+ ultralytics-8.0.229.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+ ultralytics-8.0.229.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
+ ultralytics-8.0.229.dist-info/top_level.txt,sha256=aNSJehhoYKycM3X4Tj38Q-BrmWFFm3hFuEXfPIR89eI,784
+ ultralytics-8.0.229.dist-info/RECORD,,