ultralytics 8.3.145__py3-none-any.whl → 8.3.146__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,6 +2,7 @@
2
2
 
3
3
  from multiprocessing.pool import ThreadPool
4
4
  from pathlib import Path
5
+ from typing import Any, Dict, List, Optional, Tuple
5
6
 
6
7
  import numpy as np
7
8
  import torch
@@ -35,7 +36,7 @@ class SegmentationValidator(DetectionValidator):
35
36
  >>> validator()
36
37
  """
37
38
 
38
- def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
39
+ def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None) -> None:
39
40
  """
40
41
  Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
41
42
 
@@ -52,13 +53,21 @@ class SegmentationValidator(DetectionValidator):
52
53
  self.args.task = "segment"
53
54
  self.metrics = SegmentMetrics(save_dir=self.save_dir)
54
55
 
55
- def preprocess(self, batch):
56
- """Preprocess batch by converting masks to float and sending to device."""
56
+ def preprocess(self, batch: Dict[str, Any]) -> Dict[str, Any]:
57
+ """
58
+ Preprocess batch of images for YOLO segmentation validation.
59
+
60
+ Args:
61
+ batch (Dict[str, Any]): Batch containing images and annotations.
62
+
63
+ Returns:
64
+ (Dict[str, Any]): Preprocessed batch.
65
+ """
57
66
  batch = super().preprocess(batch)
58
67
  batch["masks"] = batch["masks"].to(self.device).float()
59
68
  return batch
60
69
 
61
- def init_metrics(self, model):
70
+ def init_metrics(self, model: torch.nn.Module) -> None:
62
71
  """
63
72
  Initialize metrics and select mask processing function based on save_json flag.
64
73
 
@@ -73,7 +82,7 @@ class SegmentationValidator(DetectionValidator):
73
82
  self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask
74
83
  self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
75
84
 
76
- def get_desc(self):
85
+ def get_desc(self) -> str:
77
86
  """Return a formatted description of evaluation metrics."""
78
87
  return ("%22s" + "%11s" * 10) % (
79
88
  "Class",
@@ -89,44 +98,46 @@ class SegmentationValidator(DetectionValidator):
89
98
  "mAP50-95)",
90
99
  )
91
100
 
92
- def postprocess(self, preds):
101
+ def postprocess(self, preds: List[torch.Tensor]) -> Tuple[List[torch.Tensor], torch.Tensor]:
93
102
  """
94
103
  Post-process YOLO predictions and return output detections with proto.
95
104
 
96
105
  Args:
97
- preds (list): Raw predictions from the model.
106
+ preds (List[torch.Tensor]): Raw predictions from the model.
98
107
 
99
108
  Returns:
100
- p (torch.Tensor): Processed detection predictions.
109
+ p (List[torch.Tensor]): Processed detection predictions.
101
110
  proto (torch.Tensor): Prototype masks for segmentation.
102
111
  """
103
112
  p = super().postprocess(preds[0])
104
113
  proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
105
114
  return p, proto
106
115
 
107
- def _prepare_batch(self, si, batch):
116
+ def _prepare_batch(self, si: int, batch: Dict[str, Any]) -> Dict[str, Any]:
108
117
  """
109
118
  Prepare a batch for training or inference by processing images and targets.
110
119
 
111
120
  Args:
112
121
  si (int): Batch index.
113
- batch (dict): Batch data containing images and targets.
122
+ batch (Dict[str, Any]): Batch data containing images and annotations.
114
123
 
115
124
  Returns:
116
- (dict): Prepared batch with processed images and targets.
125
+ (Dict[str, Any]): Prepared batch with processed annotations.
117
126
  """
118
127
  prepared_batch = super()._prepare_batch(si, batch)
119
128
  midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si
120
129
  prepared_batch["masks"] = batch["masks"][midx]
121
130
  return prepared_batch
122
131
 
123
- def _prepare_pred(self, pred, pbatch, proto):
132
+ def _prepare_pred(
133
+ self, pred: torch.Tensor, pbatch: Dict[str, Any], proto: torch.Tensor
134
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
124
135
  """
125
136
  Prepare predictions for evaluation by processing bounding boxes and masks.
126
137
 
127
138
  Args:
128
139
  pred (torch.Tensor): Raw predictions from the model.
129
- pbatch (dict): Prepared batch data.
140
+ pbatch (Dict[str, Any]): Prepared batch information.
130
141
  proto (torch.Tensor): Prototype masks for segmentation.
131
142
 
132
143
  Returns:
@@ -137,13 +148,13 @@ class SegmentationValidator(DetectionValidator):
137
148
  pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"])
138
149
  return predn, pred_masks
139
150
 
140
- def update_metrics(self, preds, batch):
151
+ def update_metrics(self, preds: Tuple[List[torch.Tensor], torch.Tensor], batch: Dict[str, Any]) -> None:
141
152
  """
142
153
  Update metrics with the current batch predictions and targets.
143
154
 
144
155
  Args:
145
- preds (list): Predictions from the model.
146
- batch (dict): Batch data containing images and targets.
156
+ preds (Tuple[List[torch.Tensor], torch.Tensor]): List of predictions from the model.
157
+ batch (Dict[str, Any]): Batch data containing ground truth.
147
158
  """
148
159
  for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
149
160
  self.seen += 1
@@ -214,21 +225,16 @@ class SegmentationValidator(DetectionValidator):
214
225
  self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
215
226
  )
216
227
 
217
- def finalize_metrics(self, *args, **kwargs):
218
- """
219
- Finalize evaluation metrics by setting the speed attribute in the metrics object.
220
-
221
- This method is called at the end of validation to set the processing speed for the metrics calculations.
222
- It transfers the validator's speed measurement to the metrics object for reporting.
223
-
224
- Args:
225
- *args (Any): Variable length argument list.
226
- **kwargs (Any): Arbitrary keyword arguments.
227
- """
228
- self.metrics.speed = self.speed
229
- self.metrics.confusion_matrix = self.confusion_matrix
230
-
231
- def _process_batch(self, detections, gt_bboxes, gt_cls, pred_masks=None, gt_masks=None, overlap=False, masks=False):
228
+ def _process_batch(
229
+ self,
230
+ detections: torch.Tensor,
231
+ gt_bboxes: torch.Tensor,
232
+ gt_cls: torch.Tensor,
233
+ pred_masks: Optional[torch.Tensor] = None,
234
+ gt_masks: Optional[torch.Tensor] = None,
235
+ overlap: Optional[bool] = False,
236
+ masks: Optional[bool] = False,
237
+ ) -> torch.Tensor:
232
238
  """
233
239
  Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
234
240
 
@@ -241,8 +247,8 @@ class SegmentationValidator(DetectionValidator):
241
247
  pred_masks (torch.Tensor, optional): Tensor representing predicted masks, if available. The shape should
242
248
  match the ground truth masks.
243
249
  gt_masks (torch.Tensor, optional): Tensor of shape (M, H, W) representing ground truth masks, if available.
244
- overlap (bool): Flag indicating if overlapping masks should be considered.
245
- masks (bool): Flag indicating if the batch contains mask data.
250
+ overlap (bool, optional): Flag indicating if overlapping masks should be considered.
251
+ masks (bool, optional): Flag indicating if the batch contains mask data.
246
252
 
247
253
  Returns:
248
254
  (torch.Tensor): A correct prediction matrix of shape (N, 10), where 10 represents different IoU levels.
@@ -272,12 +278,12 @@ class SegmentationValidator(DetectionValidator):
272
278
 
273
279
  return self.match_predictions(detections[:, 5], gt_cls, iou)
274
280
 
275
- def plot_val_samples(self, batch, ni):
281
+ def plot_val_samples(self, batch: Dict[str, Any], ni: int) -> None:
276
282
  """
277
283
  Plot validation samples with bounding box labels and masks.
278
284
 
279
285
  Args:
280
- batch (dict): Batch data containing images and targets.
286
+ batch (Dict[str, Any]): Batch containing images and annotations.
281
287
  ni (int): Batch index.
282
288
  """
283
289
  plot_images(
@@ -292,13 +298,13 @@ class SegmentationValidator(DetectionValidator):
292
298
  on_plot=self.on_plot,
293
299
  )
294
300
 
295
- def plot_predictions(self, batch, preds, ni):
301
+ def plot_predictions(self, batch: Dict[str, Any], preds: List[torch.Tensor], ni: int) -> None:
296
302
  """
297
303
  Plot batch predictions with masks and bounding boxes.
298
304
 
299
305
  Args:
300
- batch (dict): Batch data containing images.
301
- preds (list): Predictions from the model.
306
+ batch (Dict[str, Any]): Batch containing images and annotations.
307
+ preds (List[torch.Tensor]): List of predictions from the model.
302
308
  ni (int): Batch index.
303
309
  """
304
310
  plot_images(
@@ -312,15 +318,17 @@ class SegmentationValidator(DetectionValidator):
312
318
  ) # pred
313
319
  self.plot_masks.clear()
314
320
 
315
- def save_one_txt(self, predn, pred_masks, save_conf, shape, file):
321
+ def save_one_txt(
322
+ self, predn: torch.Tensor, pred_masks: torch.Tensor, save_conf: bool, shape: Tuple[int, int], file: Path
323
+ ) -> None:
316
324
  """
317
325
  Save YOLO detections to a txt file in normalized coordinates in a specific format.
318
326
 
319
327
  Args:
320
- predn (torch.Tensor): Predictions in the format [x1, y1, x2, y2, conf, cls].
328
+ predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
321
329
  pred_masks (torch.Tensor): Predicted masks.
322
330
  save_conf (bool): Whether to save confidence scores.
323
- shape (tuple): Original image shape.
331
+ shape (Tuple[int, int]): Shape of the original image.
324
332
  file (Path): File path to save the detections.
325
333
  """
326
334
  from ultralytics.engine.results import Results
@@ -333,7 +341,7 @@ class SegmentationValidator(DetectionValidator):
333
341
  masks=pred_masks,
334
342
  ).save_txt(file, save_conf=save_conf)
335
343
 
336
- def pred_to_json(self, predn, filename, pred_masks):
344
+ def pred_to_json(self, predn: torch.Tensor, filename: str, pred_masks: torch.Tensor) -> None:
337
345
  """
338
346
  Save one JSON result for COCO evaluation.
339
347
 
@@ -371,8 +379,8 @@ class SegmentationValidator(DetectionValidator):
371
379
  }
372
380
  )
373
381
 
374
- def eval_json(self, stats):
375
- """Return COCO-style object detection evaluation metrics."""
382
+ def eval_json(self, stats: Dict[str, Any]) -> Dict[str, Any]:
383
+ """Return COCO-style instance segmentation evaluation metrics."""
376
384
  if self.args.save_json and (self.is_lvis or self.is_coco) and len(self.jdict):
377
385
  pred_json = self.save_dir / "predictions.json" # predictions
378
386
 
@@ -68,6 +68,8 @@ class Analytics(BaseSolution):
68
68
 
69
69
  self.total_counts = 0 # count variable for storing total counts i.e. for line
70
70
  self.clswise_count = {} # dictionary for class-wise counts
71
+ self.update_every = kwargs.get("update_every", 30) # Only update graph every 30 frames by default
72
+ self.last_plot_im = None # Cache of the last rendered chart
71
73
 
72
74
  # Ensure line and area chart
73
75
  if self.type in {"line", "area"}:
@@ -111,16 +113,21 @@ class Analytics(BaseSolution):
111
113
  if self.type == "line":
112
114
  for _ in self.boxes:
113
115
  self.total_counts += 1
114
- plot_im = self.update_graph(frame_number=frame_number)
116
+ update_required = frame_number % self.update_every == 0 or self.last_plot_im is None
117
+ if update_required:
118
+ self.last_plot_im = self.update_graph(frame_number=frame_number)
119
+ plot_im = self.last_plot_im
115
120
  self.total_counts = 0
116
121
  elif self.type in {"pie", "bar", "area"}:
117
- self.clswise_count = {}
118
- for cls in self.clss:
119
- if self.names[int(cls)] in self.clswise_count:
120
- self.clswise_count[self.names[int(cls)]] += 1
121
- else:
122
- self.clswise_count[self.names[int(cls)]] = 1
123
- plot_im = self.update_graph(frame_number=frame_number, count_dict=self.clswise_count, plot=self.type)
122
+ from collections import Counter
123
+
124
+ self.clswise_count = Counter(self.names[int(cls)] for cls in self.clss)
125
+ update_required = frame_number % self.update_every == 0 or self.last_plot_im is None
126
+ if update_required:
127
+ self.last_plot_im = self.update_graph(
128
+ frame_number=frame_number, count_dict=self.clswise_count, plot=self.type
129
+ )
130
+ plot_im = self.last_plot_im
124
131
  else:
125
132
  raise ModuleNotFoundError(f"{self.type} chart is not supported ❌")
126
133
 
@@ -187,7 +194,7 @@ class Analytics(BaseSolution):
187
194
  self.ax.clear()
188
195
  for key, y_data in y_data_dict.items():
189
196
  color = next(color_cycle)
190
- self.ax.fill_between(x_data, y_data, color=color, alpha=0.7)
197
+ self.ax.fill_between(x_data, y_data, color=color, alpha=0.55)
191
198
  self.ax.plot(
192
199
  x_data,
193
200
  y_data,
@@ -235,6 +242,7 @@ class Analytics(BaseSolution):
235
242
 
236
243
  # Common plot settings
237
244
  self.ax.set_facecolor("#f0f0f0") # Set to light gray or any other color you like
245
+ self.ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.5) # Display grid for more data insights
238
246
  self.ax.set_title(self.title, color=self.fg_color, fontsize=self.fontsize)
239
247
  self.ax.set_xlabel(self.x_label, color=self.fg_color, fontsize=self.fontsize - 3)
240
248
  self.ax.set_ylabel(self.y_label, color=self.fg_color, fontsize=self.fontsize - 3)
@@ -79,8 +79,7 @@ class ObjectCounter(BaseSolution):
79
79
  return
80
80
 
81
81
  if len(self.region) == 2: # Linear region (defined as a line segment)
82
- line = self.LineString(self.region) # Check if the line intersects the trajectory of the object
83
- if line.intersects(self.LineString([prev_position, current_centroid])):
82
+ if self.r_s.intersects(self.LineString([prev_position, current_centroid])):
84
83
  # Determine orientation of the region (vertical or horizontal)
85
84
  if abs(self.region[0][0] - self.region[1][0]) < abs(self.region[0][1] - self.region[1][1]):
86
85
  # Vertical region: Compare x-coordinates to determine direction
@@ -100,8 +99,7 @@ class ObjectCounter(BaseSolution):
100
99
  self.counted_ids.append(track_id)
101
100
 
102
101
  elif len(self.region) > 2: # Polygonal region
103
- polygon = self.Polygon(self.region)
104
- if polygon.contains(self.Point(current_centroid)):
102
+ if self.r_s.contains(self.Point(current_centroid)):
105
103
  # Determine motion direction for vertical or horizontal polygons
106
104
  region_width = max(p[0] for p in self.region) - min(p[0] for p in self.region)
107
105
  region_height = max(p[1] for p in self.region) - min(p[1] for p in self.region)
@@ -260,11 +260,13 @@ class ReID:
260
260
  from ultralytics import YOLO
261
261
 
262
262
  self.model = YOLO(model)
263
- self.model(embed=[len(self.model.model.model) - 2 if ".pt" in model else -1], verbose=False) # initialize
263
+ self.model(embed=[len(self.model.model.model) - 2 if ".pt" in model else -1], verbose=False, save=False) # init
264
264
 
265
265
  def __call__(self, img: np.ndarray, dets: np.ndarray) -> List[np.ndarray]:
266
266
  """Extract embeddings for detected objects."""
267
- feats = self.model([save_one_box(det, img, save=False) for det in xywh2xyxy(torch.from_numpy(dets[:, :4]))])
267
+ feats = self.model.predictor(
268
+ [save_one_box(det, img, save=False) for det in xywh2xyxy(torch.from_numpy(dets[:, :4]))]
269
+ )
268
270
  if len(feats) != dets.shape[0] and feats[0].shape[0] == dets.shape[0]:
269
271
  feats = feats[0] # batched prediction with non-PyTorch backend
270
272
  return [f.cpu().numpy() for f in feats]
@@ -841,8 +841,7 @@ def is_docker() -> bool:
841
841
  (bool): True if the script is running inside a Docker container, False otherwise.
842
842
  """
843
843
  try:
844
- with open("/proc/self/cgroup") as f:
845
- return "docker" in f.read()
844
+ return os.path.exists("/.dockerenv")
846
845
  except Exception:
847
846
  return False
848
847
 
@@ -106,41 +106,41 @@ def benchmark(
106
106
  if format_arg:
107
107
  formats = frozenset(export_formats()["Argument"])
108
108
  assert format in formats, f"Expected format to be one of {formats}, but got '{format_arg}'."
109
- for i, (name, format, suffix, cpu, gpu, _) in enumerate(zip(*export_formats().values())):
109
+ for name, format, suffix, cpu, gpu, _ in zip(*export_formats().values()):
110
110
  emoji, filename = "❌", None # export defaults
111
111
  try:
112
112
  if format_arg and format_arg != format:
113
113
  continue
114
114
 
115
115
  # Checks
116
- if i == 7: # TF GraphDef
116
+ if format == "pb":
117
117
  assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task"
118
- elif i == 9: # Edge TPU
118
+ elif format == "edgetpu":
119
119
  assert LINUX and not ARM64, "Edge TPU export only supported on non-aarch64 Linux"
120
- elif i in {5, 10}: # CoreML and TF.js
120
+ elif format in {"coreml", "tfjs"}:
121
121
  assert MACOS or (LINUX and not ARM64), (
122
122
  "CoreML and TF.js export only supported on macOS and non-aarch64 Linux"
123
123
  )
124
- if i in {5}: # CoreML
124
+ if format == "coreml":
125
125
  assert not IS_PYTHON_3_13, "CoreML not supported on Python 3.13"
126
- if i in {6, 7, 8, 9, 10}: # TF SavedModel, TF GraphDef, and TFLite, TF EdgeTPU and TF.js
126
+ if format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:
127
127
  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
128
128
  # assert not IS_PYTHON_MINIMUM_3_12, "TFLite exports not supported on Python>=3.12 yet"
129
- if i == 11: # Paddle
129
+ if format == "paddle":
130
130
  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
131
131
  assert model.task != "obb", "Paddle OBB bug https://github.com/PaddlePaddle/Paddle/issues/72024"
132
132
  assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet"
133
133
  assert (LINUX and not IS_JETSON) or MACOS, "Windows and Jetson Paddle exports not supported yet"
134
- if i == 12: # MNN
134
+ if format == "mnn":
135
135
  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 MNN exports not supported yet"
136
- if i == 13: # NCNN
136
+ if format == "ncnn":
137
137
  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
138
- if i == 14: # IMX
138
+ if format == "imx":
139
139
  assert not is_end2end
140
140
  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 IMX exports not supported"
141
141
  assert model.task == "detect", "IMX only supported for detection task"
142
142
  assert "C2f" in model.__str__(), "IMX only supported for YOLOv8" # TODO: enable for YOLO11
143
- if i == 15: # RKNN
143
+ if format == "rknn":
144
144
  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 RKNN exports not supported yet"
145
145
  assert not is_end2end, "End-to-end models not supported by RKNN yet"
146
146
  assert LINUX, "RKNN only supported on Linux"
@@ -163,10 +163,10 @@ def benchmark(
163
163
  emoji = "❎" # indicates export succeeded
164
164
 
165
165
  # Predict
166
- assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported"
167
- assert i not in {9, 10}, "inference not supported" # Edge TPU and TF.js are unsupported
168
- assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13" # CoreML
169
- if i in {13}:
166
+ assert model.task != "pose" or format != "pb", "GraphDef Pose inference is not supported"
167
+ assert format not in {"edgetpu", "tfjs"}, "inference not supported"
168
+ assert format != "coreml" or platform.system() == "Darwin", "inference only supported on macOS>=10.13"
169
+ if format == "ncnn":
170
170
  assert not is_end2end, "End-to-end torch.topk operation is not supported for NCNN prediction yet"
171
171
  exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half, verbose=False)
172
172
 
@@ -401,11 +401,16 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
401
401
  def attempt_install(packages, commands, use_uv):
402
402
  """Attempt package installation with uv if available, falling back to pip."""
403
403
  if use_uv:
404
- # Note requires --break-system-packages on ARM64 dockerfile
405
- cmd = f"uv pip install --system --no-cache-dir {packages} {commands} --index-strategy=unsafe-best-match --break-system-packages --prerelease=allow"
406
- else:
407
- cmd = f"pip install --no-cache-dir {packages} {commands}"
408
- return subprocess.check_output(cmd, shell=True).decode()
404
+ base = f"uv pip install --no-cache-dir {packages} {commands} --index-strategy=unsafe-best-match --break-system-packages --prerelease=allow"
405
+ try:
406
+ return subprocess.check_output(base, shell=True, stderr=subprocess.PIPE).decode()
407
+ except subprocess.CalledProcessError as e:
408
+ if e.stderr and "No virtual environment found" in e.stderr.decode():
409
+ return subprocess.check_output(
410
+ base.replace("uv pip install", "uv pip install --system"), shell=True
411
+ ).decode()
412
+ raise
413
+ return subprocess.check_output(f"pip install --no-cache-dir {packages} {commands}", shell=True).decode()
409
414
 
410
415
  s = " ".join(f'"{x}"' for x in pkgs) # console string
411
416
  if s:
@@ -36,6 +36,7 @@ GITHUB_ASSETS_NAMES = frozenset(
36
36
  + [
37
37
  "mobile_sam.pt",
38
38
  "mobileclip_blt.ts",
39
+ "yolo11n-grayscale.pt",
39
40
  "calibration_image_sample_data_20x128x128x3_float32.npy.zip",
40
41
  ]
41
42
  )
@@ -520,7 +520,7 @@ def plot_pr_curve(
520
520
  py: np.ndarray,
521
521
  ap: np.ndarray,
522
522
  save_dir: Path = Path("pr_curve.png"),
523
- names: dict = {},
523
+ names: Dict[int, str] = {},
524
524
  on_plot=None,
525
525
  ):
526
526
  """
@@ -531,7 +531,7 @@ def plot_pr_curve(
531
531
  py (np.ndarray): Y values for the PR curve.
532
532
  ap (np.ndarray): Average precision values.
533
533
  save_dir (Path, optional): Path to save the plot.
534
- names (dict, optional): Dictionary mapping class indices to class names.
534
+ names (Dict[int, str], optional): Dictionary mapping class indices to class names.
535
535
  on_plot (callable, optional): Function to call after plot is saved.
536
536
  """
537
537
  import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
@@ -563,7 +563,7 @@ def plot_mc_curve(
563
563
  px: np.ndarray,
564
564
  py: np.ndarray,
565
565
  save_dir: Path = Path("mc_curve.png"),
566
- names: dict = {},
566
+ names: Dict[int, str] = {},
567
567
  xlabel: str = "Confidence",
568
568
  ylabel: str = "Metric",
569
569
  on_plot=None,
@@ -575,7 +575,7 @@ def plot_mc_curve(
575
575
  px (np.ndarray): X values for the metric-confidence curve.
576
576
  py (np.ndarray): Y values for the metric-confidence curve.
577
577
  save_dir (Path, optional): Path to save the plot.
578
- names (dict, optional): Dictionary mapping class indices to class names.
578
+ names (Dict[int, str], optional): Dictionary mapping class indices to class names.
579
579
  xlabel (str, optional): X-axis label.
580
580
  ylabel (str, optional): Y-axis label.
581
581
  on_plot (callable, optional): Function to call after plot is saved.
@@ -645,7 +645,7 @@ def ap_per_class(
645
645
  plot: bool = False,
646
646
  on_plot=None,
647
647
  save_dir: Path = Path(),
648
- names: dict = {},
648
+ names: Dict[int, str] = {},
649
649
  eps: float = 1e-16,
650
650
  prefix: str = "",
651
651
  ) -> Tuple:
@@ -660,7 +660,7 @@ def ap_per_class(
660
660
  plot (bool, optional): Whether to plot PR curves or not.
661
661
  on_plot (callable, optional): A callback to pass plots path and data when they are rendered.
662
662
  save_dir (Path, optional): Directory to save the PR curves.
663
- names (dict, optional): Dict of class names to plot PR curves.
663
+ names (Dict[int, str], optional): Dictionary of class names to plot PR curves.
664
664
  eps (float, optional): A small value to avoid division by zero.
665
665
  prefix (str, optional): A prefix string for saving the plot files.
666
666
 
@@ -720,8 +720,7 @@ def ap_per_class(
720
720
 
721
721
  # Compute F1 (harmonic mean of precision and recall)
722
722
  f1_curve = 2 * p_curve * r_curve / (p_curve + r_curve + eps)
723
- names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
724
- names = dict(enumerate(names)) # to dict
723
+ names = {i: names[k] for i, k in enumerate(unique_classes) if k in names} # dict: only classes that have data
725
724
  if plot:
726
725
  plot_pr_curve(x, prec_values, ap, save_dir / f"{prefix}PR_curve.png", names, on_plot=on_plot)
727
726
  plot_mc_curve(x, f1_curve, save_dir / f"{prefix}F1_curve.png", names, ylabel="F1", on_plot=on_plot)
@@ -915,20 +914,20 @@ class DetMetrics(SimpleClass, DataExportMixin):
915
914
  Attributes:
916
915
  save_dir (Path): A path to the directory where the output plots will be saved.
917
916
  plot (bool): A flag that indicates whether to plot precision-recall curves for each class.
918
- names (dict): A dictionary of class names.
917
+ names (Dict[int, str]): A dictionary of class names.
919
918
  box (Metric): An instance of the Metric class for storing detection results.
920
- speed (dict): A dictionary for storing execution times of different parts of the detection process.
919
+ speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
921
920
  task (str): The task type, set to 'detect'.
922
921
  """
923
922
 
924
- def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: dict = {}) -> None:
923
+ def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
925
924
  """
926
925
  Initialize a DetMetrics instance with a save directory, plot flag, and class names.
927
926
 
928
927
  Args:
929
928
  save_dir (Path, optional): Directory to save plots.
930
929
  plot (bool, optional): Whether to plot precision-recall curves.
931
- names (dict, optional): Dictionary mapping class indices to names.
930
+ names (Dict[int, str], optional): Dictionary of class names.
932
931
  """
933
932
  self.save_dir = save_dir
934
933
  self.plot = plot
@@ -1033,21 +1032,21 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
1033
1032
  Attributes:
1034
1033
  save_dir (Path): Path to the directory where the output plots should be saved.
1035
1034
  plot (bool): Whether to save the detection and segmentation plots.
1036
- names (dict): Dictionary of class names.
1037
- box (Metric): An instance of the Metric class to calculate box detection metrics.
1035
+ names (Dict[int, str]): Dictionary of class names.
1036
+ box (Metric): An instance of the Metric class for storing detection results.
1038
1037
  seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
1039
- speed (dict): Dictionary to store the time taken in different phases of inference.
1038
+ speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1040
1039
  task (str): The task type, set to 'segment'.
1041
1040
  """
1042
1041
 
1043
- def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: tuple = ()) -> None:
1042
+ def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
1044
1043
  """
1045
1044
  Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
1046
1045
 
1047
1046
  Args:
1048
1047
  save_dir (Path, optional): Directory to save plots.
1049
1048
  plot (bool, optional): Whether to plot precision-recall curves.
1050
- names (tuple, optional): Tuple mapping class indices to names.
1049
+ names (Dict[int, str], optional): Dictionary of class names.
1051
1050
  """
1052
1051
  self.save_dir = save_dir
1053
1052
  self.plot = plot
@@ -1196,10 +1195,10 @@ class PoseMetrics(SegmentMetrics):
1196
1195
  Attributes:
1197
1196
  save_dir (Path): Path to the directory where the output plots should be saved.
1198
1197
  plot (bool): Whether to save the detection and pose plots.
1199
- names (dict): Dictionary of class names.
1200
- box (Metric): An instance of the Metric class to calculate box detection metrics.
1198
+ names (Dict[int, str]): Dictionary of class names.
1201
1199
  pose (Metric): An instance of the Metric class to calculate pose metrics.
1202
- speed (dict): Dictionary to store the time taken in different phases of inference.
1200
+ box (Metric): An instance of the Metric class for storing detection results.
1201
+ speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1203
1202
  task (str): The task type, set to 'pose'.
1204
1203
 
1205
1204
  Methods:
@@ -1212,14 +1211,14 @@ class PoseMetrics(SegmentMetrics):
1212
1211
  results_dict: Return the dictionary containing all the detection and segmentation metrics and fitness score.
1213
1212
  """
1214
1213
 
1215
- def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: tuple = ()) -> None:
1214
+ def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
1216
1215
  """
1217
1216
  Initialize the PoseMetrics class with directory path, class names, and plotting options.
1218
1217
 
1219
1218
  Args:
1220
1219
  save_dir (Path, optional): Directory to save plots.
1221
1220
  plot (bool, optional): Whether to plot precision-recall curves.
1222
- names (tuple, optional): Tuple mapping class indices to names.
1221
+ names (Dict[int, str], optional): Dictionary of class names.
1223
1222
  """
1224
1223
  super().__init__(save_dir, plot, names)
1225
1224
  self.save_dir = save_dir
@@ -1420,23 +1419,23 @@ class OBBMetrics(SimpleClass, DataExportMixin):
1420
1419
  Attributes:
1421
1420
  save_dir (Path): Path to the directory where the output plots should be saved.
1422
1421
  plot (bool): Whether to save the detection plots.
1423
- names (dict): Dictionary of class names.
1422
+ names (Dict[int, str]): Dictionary of class names.
1424
1423
  box (Metric): An instance of the Metric class for storing detection results.
1425
- speed (dict): A dictionary for storing execution times of different parts of the detection process.
1424
+ speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1426
1425
  task (str): The task type, set to 'obb'.
1427
1426
 
1428
1427
  References:
1429
1428
  https://arxiv.org/pdf/2106.06072.pdf
1430
1429
  """
1431
1430
 
1432
- def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: tuple = ()) -> None:
1431
+ def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
1433
1432
  """
1434
1433
  Initialize an OBBMetrics instance with directory, plotting, and class names.
1435
1434
 
1436
1435
  Args:
1437
1436
  save_dir (Path, optional): Directory to save plots.
1438
1437
  plot (bool, optional): Whether to plot precision-recall curves.
1439
- names (tuple, optional): Tuple mapping class indices to names.
1438
+ names (Dict[int, str], optional): Dictionary of class names.
1440
1439
  """
1441
1440
  self.save_dir = save_dir
1442
1441
  self.plot = plot