dgenerate-ultralytics-headless 8.3.145__py3-none-any.whl → 8.3.147__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (32):
  1. {dgenerate_ultralytics_headless-8.3.145.dist-info → dgenerate_ultralytics_headless-8.3.147.dist-info}/METADATA +1 -1
  2. {dgenerate_ultralytics_headless-8.3.145.dist-info → dgenerate_ultralytics_headless-8.3.147.dist-info}/RECORD +32 -31
  3. {dgenerate_ultralytics_headless-8.3.145.dist-info → dgenerate_ultralytics_headless-8.3.147.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +3 -0
  5. tests/test_cli.py +2 -7
  6. tests/test_python.py +55 -18
  7. ultralytics/__init__.py +1 -1
  8. ultralytics/cfg/__init__.py +0 -1
  9. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  10. ultralytics/data/augment.py +2 -2
  11. ultralytics/engine/model.py +4 -4
  12. ultralytics/engine/validator.py +1 -1
  13. ultralytics/models/nas/model.py +0 -8
  14. ultralytics/models/yolo/classify/val.py +5 -9
  15. ultralytics/models/yolo/detect/val.py +8 -17
  16. ultralytics/models/yolo/obb/val.py +24 -17
  17. ultralytics/models/yolo/pose/val.py +19 -14
  18. ultralytics/models/yolo/segment/val.py +52 -44
  19. ultralytics/nn/tasks.py +3 -0
  20. ultralytics/solutions/analytics.py +17 -9
  21. ultralytics/solutions/object_counter.py +2 -4
  22. ultralytics/trackers/bot_sort.py +4 -2
  23. ultralytics/utils/__init__.py +2 -3
  24. ultralytics/utils/benchmarks.py +15 -15
  25. ultralytics/utils/checks.py +10 -5
  26. ultralytics/utils/downloads.py +1 -0
  27. ultralytics/utils/metrics.py +52 -33
  28. ultralytics/utils/plotting.py +10 -7
  29. ultralytics/utils/torch_utils.py +2 -2
  30. {dgenerate_ultralytics_headless-8.3.145.dist-info → dgenerate_ultralytics_headless-8.3.147.dist-info}/entry_points.txt +0 -0
  31. {dgenerate_ultralytics_headless-8.3.145.dist-info → dgenerate_ultralytics_headless-8.3.147.dist-info}/licenses/LICENSE +0 -0
  32. {dgenerate_ultralytics_headless-8.3.145.dist-info → dgenerate_ultralytics_headless-8.3.147.dist-info}/top_level.txt +0 -0
@@ -205,7 +205,7 @@ class DataExportMixin:
205
205
  to_sql: Export results to an SQLite database.
206
206
 
207
207
  Examples:
208
- >>> model = YOLO("yolov8n.pt")
208
+ >>> model = YOLO("yolo11n.pt")
209
209
  >>> results = model("image.jpg")
210
210
  >>> df = results.to_df()
211
211
  >>> print(df)
@@ -841,8 +841,7 @@ def is_docker() -> bool:
841
841
  (bool): True if the script is running inside a Docker container, False otherwise.
842
842
  """
843
843
  try:
844
- with open("/proc/self/cgroup") as f:
845
- return "docker" in f.read()
844
+ return os.path.exists("/.dockerenv")
846
845
  except Exception:
847
846
  return False
848
847
 
@@ -106,41 +106,41 @@ def benchmark(
106
106
  if format_arg:
107
107
  formats = frozenset(export_formats()["Argument"])
108
108
  assert format in formats, f"Expected format to be one of {formats}, but got '{format_arg}'."
109
- for i, (name, format, suffix, cpu, gpu, _) in enumerate(zip(*export_formats().values())):
109
+ for name, format, suffix, cpu, gpu, _ in zip(*export_formats().values()):
110
110
  emoji, filename = "❌", None # export defaults
111
111
  try:
112
112
  if format_arg and format_arg != format:
113
113
  continue
114
114
 
115
115
  # Checks
116
- if i == 7: # TF GraphDef
116
+ if format == "pb":
117
117
  assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task"
118
- elif i == 9: # Edge TPU
118
+ elif format == "edgetpu":
119
119
  assert LINUX and not ARM64, "Edge TPU export only supported on non-aarch64 Linux"
120
- elif i in {5, 10}: # CoreML and TF.js
120
+ elif format in {"coreml", "tfjs"}:
121
121
  assert MACOS or (LINUX and not ARM64), (
122
122
  "CoreML and TF.js export only supported on macOS and non-aarch64 Linux"
123
123
  )
124
- if i in {5}: # CoreML
124
+ if format == "coreml":
125
125
  assert not IS_PYTHON_3_13, "CoreML not supported on Python 3.13"
126
- if i in {6, 7, 8, 9, 10}: # TF SavedModel, TF GraphDef, and TFLite, TF EdgeTPU and TF.js
126
+ if format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:
127
127
  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
128
128
  # assert not IS_PYTHON_MINIMUM_3_12, "TFLite exports not supported on Python>=3.12 yet"
129
- if i == 11: # Paddle
129
+ if format == "paddle":
130
130
  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
131
131
  assert model.task != "obb", "Paddle OBB bug https://github.com/PaddlePaddle/Paddle/issues/72024"
132
132
  assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet"
133
133
  assert (LINUX and not IS_JETSON) or MACOS, "Windows and Jetson Paddle exports not supported yet"
134
- if i == 12: # MNN
134
+ if format == "mnn":
135
135
  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 MNN exports not supported yet"
136
- if i == 13: # NCNN
136
+ if format == "ncnn":
137
137
  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
138
- if i == 14: # IMX
138
+ if format == "imx":
139
139
  assert not is_end2end
140
140
  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 IMX exports not supported"
141
141
  assert model.task == "detect", "IMX only supported for detection task"
142
142
  assert "C2f" in model.__str__(), "IMX only supported for YOLOv8" # TODO: enable for YOLO11
143
- if i == 15: # RKNN
143
+ if format == "rknn":
144
144
  assert not isinstance(model, YOLOWorld), "YOLOWorldv2 RKNN exports not supported yet"
145
145
  assert not is_end2end, "End-to-end models not supported by RKNN yet"
146
146
  assert LINUX, "RKNN only supported on Linux"
@@ -163,10 +163,10 @@ def benchmark(
163
163
  emoji = "❎" # indicates export succeeded
164
164
 
165
165
  # Predict
166
- assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported"
167
- assert i not in {9, 10}, "inference not supported" # Edge TPU and TF.js are unsupported
168
- assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13" # CoreML
169
- if i in {13}:
166
+ assert model.task != "pose" or format != "pb", "GraphDef Pose inference is not supported"
167
+ assert format not in {"edgetpu", "tfjs"}, "inference not supported"
168
+ assert format != "coreml" or platform.system() == "Darwin", "inference only supported on macOS>=10.13"
169
+ if format == "ncnn":
170
170
  assert not is_end2end, "End-to-end torch.topk operation is not supported for NCNN prediction yet"
171
171
  exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half, verbose=False)
172
172
 
@@ -401,11 +401,16 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
401
401
  def attempt_install(packages, commands, use_uv):
402
402
  """Attempt package installation with uv if available, falling back to pip."""
403
403
  if use_uv:
404
- # Note requires --break-system-packages on ARM64 dockerfile
405
- cmd = f"uv pip install --system --no-cache-dir {packages} {commands} --index-strategy=unsafe-best-match --break-system-packages --prerelease=allow"
406
- else:
407
- cmd = f"pip install --no-cache-dir {packages} {commands}"
408
- return subprocess.check_output(cmd, shell=True).decode()
404
+ base = f"uv pip install --no-cache-dir {packages} {commands} --index-strategy=unsafe-best-match --break-system-packages --prerelease=allow"
405
+ try:
406
+ return subprocess.check_output(base, shell=True, stderr=subprocess.PIPE).decode()
407
+ except subprocess.CalledProcessError as e:
408
+ if e.stderr and "No virtual environment found" in e.stderr.decode():
409
+ return subprocess.check_output(
410
+ base.replace("uv pip install", "uv pip install --system"), shell=True
411
+ ).decode()
412
+ raise
413
+ return subprocess.check_output(f"pip install --no-cache-dir {packages} {commands}", shell=True).decode()
409
414
 
410
415
  s = " ".join(f'"{x}"' for x in pkgs) # console string
411
416
  if s:
@@ -36,6 +36,7 @@ GITHUB_ASSETS_NAMES = frozenset(
36
36
  + [
37
37
  "mobile_sam.pt",
38
38
  "mobileclip_blt.ts",
39
+ "yolo11n-grayscale.pt",
39
40
  "calibration_image_sample_data_20x128x128x3_float32.npy.zip",
40
41
  ]
41
42
  )
@@ -309,7 +309,7 @@ def smooth_bce(eps: float = 0.1) -> Tuple[float, float]:
309
309
  return 1.0 - 0.5 * eps, 0.5 * eps
310
310
 
311
311
 
312
- class ConfusionMatrix:
312
+ class ConfusionMatrix(DataExportMixin):
313
313
  """
314
314
  A class for calculating and updating a confusion matrix for object detection and classification tasks.
315
315
 
@@ -321,7 +321,7 @@ class ConfusionMatrix:
321
321
  iou_thres (float): The Intersection over Union threshold.
322
322
  """
323
323
 
324
- def __init__(self, nc: int, conf: float = 0.25, iou_thres: float = 0.45, task: str = "detect"):
324
+ def __init__(self, nc: int, conf: float = 0.25, iou_thres: float = 0.45, names: tuple = (), task: str = "detect"):
325
325
  """
326
326
  Initialize a ConfusionMatrix instance.
327
327
 
@@ -329,11 +329,13 @@ class ConfusionMatrix:
329
329
  nc (int): Number of classes.
330
330
  conf (float, optional): Confidence threshold for detections.
331
331
  iou_thres (float, optional): IoU threshold for matching detections to ground truth.
332
+ names (tuple, optional): Names of classes, used as labels on the plot.
332
333
  task (str, optional): Type of task, either 'detect' or 'classify'.
333
334
  """
334
335
  self.task = task
335
336
  self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
336
337
  self.nc = nc # number of classes
338
+ self.names = list(names) # name of classes
337
339
  self.conf = 0.25 if conf in {None, 0.001} else conf # apply 0.25 if default val conf is passed
338
340
  self.iou_thres = iou_thres
339
341
 
@@ -426,14 +428,13 @@ class ConfusionMatrix:
426
428
 
427
429
  @TryExcept(msg="ConfusionMatrix plot failure")
428
430
  @plt_settings()
429
- def plot(self, normalize: bool = True, save_dir: str = "", names: tuple = (), on_plot=None):
431
+ def plot(self, normalize: bool = True, save_dir: str = "", on_plot=None):
430
432
  """
431
433
  Plot the confusion matrix using matplotlib and save it to a file.
432
434
 
433
435
  Args:
434
436
  normalize (bool, optional): Whether to normalize the confusion matrix.
435
437
  save_dir (str, optional): Directory where the plot will be saved.
436
- names (tuple, optional): Names of classes, used as labels on the plot.
437
438
  on_plot (callable, optional): An optional callback to pass plots path and data when they are rendered.
438
439
  """
439
440
  import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
@@ -441,18 +442,17 @@ class ConfusionMatrix:
441
442
  array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns
442
443
  array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
443
444
 
444
- names = list(names)
445
445
  fig, ax = plt.subplots(1, 1, figsize=(12, 9))
446
446
  if self.nc >= 100: # downsample for large class count
447
447
  k = max(2, self.nc // 60) # step size for downsampling, always > 1
448
448
  keep_idx = slice(None, None, k) # create slice instead of array
449
- names = names[keep_idx] # slice class names
449
+ self.names = self.names[keep_idx] # slice class names
450
450
  array = array[keep_idx, :][:, keep_idx] # slice matrix rows and cols
451
451
  n = (self.nc + k - 1) // k # number of retained classes
452
452
  nc = nn = n if self.task == "classify" else n + 1 # adjust for background if needed
453
453
  else:
454
454
  nc = nn = self.nc if self.task == "classify" else self.nc + 1
455
- ticklabels = (names + ["background"]) if (0 < nn < 99) and (nn == nc) else "auto"
455
+ ticklabels = (self.names + ["background"]) if (0 < nn < 99) and (nn == nc) else "auto"
456
456
  xy_ticks = np.arange(len(ticklabels))
457
457
  tick_fontsize = max(6, 15 - 0.1 * nc) # Minimum size is 6
458
458
  label_fontsize = max(6, 12 - 0.1 * nc)
@@ -505,6 +505,26 @@ class ConfusionMatrix:
505
505
  for i in range(self.matrix.shape[0]):
506
506
  LOGGER.info(" ".join(map(str, self.matrix[i])))
507
507
 
508
+ def summary(self, **kwargs):
509
+ """Returns summary of the confusion matrix for export in different formats CSV, XML, HTML."""
510
+ import re
511
+
512
+ names = self.names if self.task == "classify" else self.names + ["background"]
513
+ clean_names, seen = [], set()
514
+ for name in names:
515
+ clean_name = re.sub(r"[^a-zA-Z0-9_]", "_", name)
516
+ original_clean = clean_name
517
+ counter = 1
518
+ while clean_name.lower() in seen:
519
+ clean_name = f"{original_clean}_{counter}"
520
+ counter += 1
521
+ seen.add(clean_name.lower())
522
+ clean_names.append(clean_name)
523
+ return [
524
+ dict({"Predicted": clean_names[i]}, **{clean_names[j]: self.matrix[i, j] for j in range(len(clean_names))})
525
+ for i in range(len(clean_names))
526
+ ]
527
+
508
528
 
509
529
  def smooth(y: np.ndarray, f: float = 0.05) -> np.ndarray:
510
530
  """Box filter of fraction f."""
@@ -520,7 +540,7 @@ def plot_pr_curve(
520
540
  py: np.ndarray,
521
541
  ap: np.ndarray,
522
542
  save_dir: Path = Path("pr_curve.png"),
523
- names: dict = {},
543
+ names: Dict[int, str] = {},
524
544
  on_plot=None,
525
545
  ):
526
546
  """
@@ -531,7 +551,7 @@ def plot_pr_curve(
531
551
  py (np.ndarray): Y values for the PR curve.
532
552
  ap (np.ndarray): Average precision values.
533
553
  save_dir (Path, optional): Path to save the plot.
534
- names (dict, optional): Dictionary mapping class indices to class names.
554
+ names (Dict[int, str], optional): Dictionary mapping class indices to class names.
535
555
  on_plot (callable, optional): Function to call after plot is saved.
536
556
  """
537
557
  import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
@@ -563,7 +583,7 @@ def plot_mc_curve(
563
583
  px: np.ndarray,
564
584
  py: np.ndarray,
565
585
  save_dir: Path = Path("mc_curve.png"),
566
- names: dict = {},
586
+ names: Dict[int, str] = {},
567
587
  xlabel: str = "Confidence",
568
588
  ylabel: str = "Metric",
569
589
  on_plot=None,
@@ -575,7 +595,7 @@ def plot_mc_curve(
575
595
  px (np.ndarray): X values for the metric-confidence curve.
576
596
  py (np.ndarray): Y values for the metric-confidence curve.
577
597
  save_dir (Path, optional): Path to save the plot.
578
- names (dict, optional): Dictionary mapping class indices to class names.
598
+ names (Dict[int, str], optional): Dictionary mapping class indices to class names.
579
599
  xlabel (str, optional): X-axis label.
580
600
  ylabel (str, optional): Y-axis label.
581
601
  on_plot (callable, optional): Function to call after plot is saved.
@@ -645,7 +665,7 @@ def ap_per_class(
645
665
  plot: bool = False,
646
666
  on_plot=None,
647
667
  save_dir: Path = Path(),
648
- names: dict = {},
668
+ names: Dict[int, str] = {},
649
669
  eps: float = 1e-16,
650
670
  prefix: str = "",
651
671
  ) -> Tuple:
@@ -660,7 +680,7 @@ def ap_per_class(
660
680
  plot (bool, optional): Whether to plot PR curves or not.
661
681
  on_plot (callable, optional): A callback to pass plots path and data when they are rendered.
662
682
  save_dir (Path, optional): Directory to save the PR curves.
663
- names (dict, optional): Dict of class names to plot PR curves.
683
+ names (Dict[int, str], optional): Dictionary of class names to plot PR curves.
664
684
  eps (float, optional): A small value to avoid division by zero.
665
685
  prefix (str, optional): A prefix string for saving the plot files.
666
686
 
@@ -720,8 +740,7 @@ def ap_per_class(
720
740
 
721
741
  # Compute F1 (harmonic mean of precision and recall)
722
742
  f1_curve = 2 * p_curve * r_curve / (p_curve + r_curve + eps)
723
- names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
724
- names = dict(enumerate(names)) # to dict
743
+ names = {i: names[k] for i, k in enumerate(unique_classes) if k in names} # dict: only classes that have data
725
744
  if plot:
726
745
  plot_pr_curve(x, prec_values, ap, save_dir / f"{prefix}PR_curve.png", names, on_plot=on_plot)
727
746
  plot_mc_curve(x, f1_curve, save_dir / f"{prefix}F1_curve.png", names, ylabel="F1", on_plot=on_plot)
@@ -915,20 +934,20 @@ class DetMetrics(SimpleClass, DataExportMixin):
915
934
  Attributes:
916
935
  save_dir (Path): A path to the directory where the output plots will be saved.
917
936
  plot (bool): A flag that indicates whether to plot precision-recall curves for each class.
918
- names (dict): A dictionary of class names.
937
+ names (Dict[int, str]): A dictionary of class names.
919
938
  box (Metric): An instance of the Metric class for storing detection results.
920
- speed (dict): A dictionary for storing execution times of different parts of the detection process.
939
+ speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
921
940
  task (str): The task type, set to 'detect'.
922
941
  """
923
942
 
924
- def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: dict = {}) -> None:
943
+ def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
925
944
  """
926
945
  Initialize a DetMetrics instance with a save directory, plot flag, and class names.
927
946
 
928
947
  Args:
929
948
  save_dir (Path, optional): Directory to save plots.
930
949
  plot (bool, optional): Whether to plot precision-recall curves.
931
- names (dict, optional): Dictionary mapping class indices to names.
950
+ names (Dict[int, str], optional): Dictionary of class names.
932
951
  """
933
952
  self.save_dir = save_dir
934
953
  self.plot = plot
@@ -1033,21 +1052,21 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
1033
1052
  Attributes:
1034
1053
  save_dir (Path): Path to the directory where the output plots should be saved.
1035
1054
  plot (bool): Whether to save the detection and segmentation plots.
1036
- names (dict): Dictionary of class names.
1037
- box (Metric): An instance of the Metric class to calculate box detection metrics.
1055
+ names (Dict[int, str]): Dictionary of class names.
1056
+ box (Metric): An instance of the Metric class for storing detection results.
1038
1057
  seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
1039
- speed (dict): Dictionary to store the time taken in different phases of inference.
1058
+ speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1040
1059
  task (str): The task type, set to 'segment'.
1041
1060
  """
1042
1061
 
1043
- def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: tuple = ()) -> None:
1062
+ def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
1044
1063
  """
1045
1064
  Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
1046
1065
 
1047
1066
  Args:
1048
1067
  save_dir (Path, optional): Directory to save plots.
1049
1068
  plot (bool, optional): Whether to plot precision-recall curves.
1050
- names (tuple, optional): Tuple mapping class indices to names.
1069
+ names (Dict[int, str], optional): Dictionary of class names.
1051
1070
  """
1052
1071
  self.save_dir = save_dir
1053
1072
  self.plot = plot
@@ -1196,10 +1215,10 @@ class PoseMetrics(SegmentMetrics):
1196
1215
  Attributes:
1197
1216
  save_dir (Path): Path to the directory where the output plots should be saved.
1198
1217
  plot (bool): Whether to save the detection and pose plots.
1199
- names (dict): Dictionary of class names.
1200
- box (Metric): An instance of the Metric class to calculate box detection metrics.
1218
+ names (Dict[int, str]): Dictionary of class names.
1201
1219
  pose (Metric): An instance of the Metric class to calculate pose metrics.
1202
- speed (dict): Dictionary to store the time taken in different phases of inference.
1220
+ box (Metric): An instance of the Metric class for storing detection results.
1221
+ speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1203
1222
  task (str): The task type, set to 'pose'.
1204
1223
 
1205
1224
  Methods:
@@ -1212,14 +1231,14 @@ class PoseMetrics(SegmentMetrics):
1212
1231
  results_dict: Return the dictionary containing all the detection and segmentation metrics and fitness score.
1213
1232
  """
1214
1233
 
1215
- def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: tuple = ()) -> None:
1234
+ def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
1216
1235
  """
1217
1236
  Initialize the PoseMetrics class with directory path, class names, and plotting options.
1218
1237
 
1219
1238
  Args:
1220
1239
  save_dir (Path, optional): Directory to save plots.
1221
1240
  plot (bool, optional): Whether to plot precision-recall curves.
1222
- names (tuple, optional): Tuple mapping class indices to names.
1241
+ names (Dict[int, str], optional): Dictionary of class names.
1223
1242
  """
1224
1243
  super().__init__(save_dir, plot, names)
1225
1244
  self.save_dir = save_dir
@@ -1420,23 +1439,23 @@ class OBBMetrics(SimpleClass, DataExportMixin):
1420
1439
  Attributes:
1421
1440
  save_dir (Path): Path to the directory where the output plots should be saved.
1422
1441
  plot (bool): Whether to save the detection plots.
1423
- names (dict): Dictionary of class names.
1442
+ names (Dict[int, str]): Dictionary of class names.
1424
1443
  box (Metric): An instance of the Metric class for storing detection results.
1425
- speed (dict): A dictionary for storing execution times of different parts of the detection process.
1444
+ speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
1426
1445
  task (str): The task type, set to 'obb'.
1427
1446
 
1428
1447
  References:
1429
1448
  https://arxiv.org/pdf/2106.06072.pdf
1430
1449
  """
1431
1450
 
1432
- def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: tuple = ()) -> None:
1451
+ def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
1433
1452
  """
1434
1453
  Initialize an OBBMetrics instance with directory, plotting, and class names.
1435
1454
 
1436
1455
  Args:
1437
1456
  save_dir (Path, optional): Directory to save plots.
1438
1457
  plot (bool, optional): Whether to plot precision-recall curves.
1439
- names (tuple, optional): Tuple mapping class indices to names.
1458
+ names (Dict[int, str], optional): Dictionary of class names.
1440
1459
  """
1441
1460
  self.save_dir = save_dir
1442
1461
  self.plot = plot
@@ -201,6 +201,11 @@ class Annotator:
201
201
  input_is_pil = isinstance(im, Image.Image)
202
202
  self.pil = pil or non_ascii or input_is_pil
203
203
  self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2)
204
+ if not input_is_pil:
205
+ if im.shape[2] == 1: # handle grayscale
206
+ im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
207
+ elif im.shape[2] > 3: # multispectral
208
+ im = np.ascontiguousarray(im[..., :3])
204
209
  if self.pil: # use PIL
205
210
  self.im = im if input_is_pil else Image.fromarray(im)
206
211
  if self.im.mode not in {"RGB", "RGBA"}: # multispectral
@@ -216,10 +221,6 @@ class Annotator:
216
221
  if check_version(pil_version, "9.2.0"):
217
222
  self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height
218
223
  else: # use cv2
219
- if im.shape[2] == 1: # handle grayscale
220
- im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
221
- elif im.shape[2] > 3: # multispectral
222
- im = np.ascontiguousarray(im[..., :3])
223
224
  assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images."
224
225
  self.im = im if im.flags.writeable else im.copy()
225
226
  self.tf = max(self.lw - 1, 1) # font thickness
@@ -644,7 +645,7 @@ def save_one_box(
644
645
  gain (float, optional): A multiplicative factor to increase the size of the bounding box.
645
646
  pad (int, optional): The number of pixels to add to the width and height of the bounding box.
646
647
  square (bool, optional): If True, the bounding box will be transformed into a square.
647
- BGR (bool, optional): If True, the image will be saved in BGR format, otherwise in RGB.
648
+ BGR (bool, optional): If True, the image will be returned in BGR format, otherwise in RGB.
648
649
  save (bool, optional): If True, the cropped image will be saved to disk.
649
650
 
650
651
  Returns:
@@ -664,12 +665,14 @@ def save_one_box(
664
665
  b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
665
666
  xyxy = ops.xywh2xyxy(b).long()
666
667
  xyxy = ops.clip_boxes(xyxy, im.shape)
667
- crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR else -1)]
668
+ grayscale = im.shape[2] == 1 # grayscale image
669
+ crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR or grayscale else -1)]
668
670
  if save:
669
671
  file.parent.mkdir(parents=True, exist_ok=True) # make directory
670
672
  f = str(increment_path(file).with_suffix(".jpg"))
671
673
  # cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
672
- Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
674
+ crop = crop.squeeze(-1) if grayscale else crop[..., ::-1] if BGR else crop
675
+ Image.fromarray(crop).save(f, quality=95, subsampling=0) # save RGB
673
676
  return crop
674
677
 
675
678
 
@@ -10,7 +10,7 @@ from contextlib import contextmanager
10
10
  from copy import deepcopy
11
11
  from datetime import datetime
12
12
  from pathlib import Path
13
- from typing import Union
13
+ from typing import Any, Dict, Union
14
14
 
15
15
  import numpy as np
16
16
  import torch
@@ -704,7 +704,7 @@ class ModelEMA:
704
704
  copy_attr(self.ema, model, include, exclude)
705
705
 
706
706
 
707
- def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict:
707
+ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: Dict[str, Any] = None) -> Dict[str, Any]:
708
708
  """
709
709
  Strip optimizer from 'f' to finalize training, optionally save as 's'.
710
710