dgenerate-ultralytics-headless 8.3.144__py3-none-any.whl → 8.3.146__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {dgenerate_ultralytics_headless-8.3.144.dist-info → dgenerate_ultralytics_headless-8.3.146.dist-info}/METADATA +2 -2
  2. {dgenerate_ultralytics_headless-8.3.144.dist-info → dgenerate_ultralytics_headless-8.3.146.dist-info}/RECORD +36 -35
  3. {dgenerate_ultralytics_headless-8.3.144.dist-info → dgenerate_ultralytics_headless-8.3.146.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +3 -0
  5. tests/test_cli.py +2 -7
  6. tests/test_python.py +42 -12
  7. ultralytics/__init__.py +1 -1
  8. ultralytics/cfg/__init__.py +0 -1
  9. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  10. ultralytics/data/augment.py +2 -2
  11. ultralytics/engine/model.py +14 -13
  12. ultralytics/engine/results.py +4 -4
  13. ultralytics/engine/validator.py +1 -1
  14. ultralytics/models/nas/model.py +0 -8
  15. ultralytics/models/yolo/classify/val.py +1 -5
  16. ultralytics/models/yolo/detect/val.py +9 -16
  17. ultralytics/models/yolo/obb/val.py +24 -17
  18. ultralytics/models/yolo/pose/val.py +19 -14
  19. ultralytics/models/yolo/segment/val.py +52 -44
  20. ultralytics/solutions/ai_gym.py +3 -5
  21. ultralytics/solutions/analytics.py +17 -9
  22. ultralytics/solutions/heatmap.py +1 -1
  23. ultralytics/solutions/instance_segmentation.py +1 -1
  24. ultralytics/solutions/object_counter.py +2 -8
  25. ultralytics/solutions/solutions.py +5 -4
  26. ultralytics/trackers/bot_sort.py +4 -2
  27. ultralytics/utils/__init__.py +1 -2
  28. ultralytics/utils/benchmarks.py +18 -16
  29. ultralytics/utils/checks.py +10 -5
  30. ultralytics/utils/downloads.py +1 -0
  31. ultralytics/utils/metrics.py +25 -26
  32. ultralytics/utils/plotting.py +10 -7
  33. ultralytics/utils/torch_utils.py +2 -2
  34. {dgenerate_ultralytics_headless-8.3.144.dist-info → dgenerate_ultralytics_headless-8.3.146.dist-info}/entry_points.txt +0 -0
  35. {dgenerate_ultralytics_headless-8.3.144.dist-info → dgenerate_ultralytics_headless-8.3.146.dist-info}/licenses/LICENSE +0 -0
  36. {dgenerate_ultralytics_headless-8.3.144.dist-info → dgenerate_ultralytics_headless-8.3.146.dist-info}/top_level.txt +0 -0
ultralytics/solutions/solutions.py
@@ -169,11 +169,12 @@ class BaseSolution:
         with self.profilers[0]:
             self.tracks = self.model.track(
                 source=im0, persist=True, classes=self.classes, verbose=False, **self.track_add_args
-            )
-        self.track_data = self.tracks[0].obb or self.tracks[0].boxes  # Extract tracks for OBB or object detection
+            )[0]
+        is_obb = self.tracks.obb is not None
+        self.track_data = self.tracks.obb if is_obb else self.tracks.boxes  # Extract tracks for OBB or object detection
 
-        if self.track_data and self.track_data.id is not None:
-            self.boxes = self.track_data.xyxy.cpu()
+        if self.track_data and self.track_data.is_track:
+            self.boxes = (self.track_data.xyxyxyxy if is_obb else self.track_data.xyxy).cpu()
             self.clss = self.track_data.cls.cpu().tolist()
             self.track_ids = self.track_data.id.int().cpu().tolist()
             self.confs = self.track_data.conf.cpu().tolist()
ultralytics/trackers/bot_sort.py
@@ -260,11 +260,13 @@ class ReID:
         from ultralytics import YOLO
 
         self.model = YOLO(model)
-        self.model(embed=[len(self.model.model.model) - 2 if ".pt" in model else -1], verbose=False)  # initialize
+        self.model(embed=[len(self.model.model.model) - 2 if ".pt" in model else -1], verbose=False, save=False)  # init
 
     def __call__(self, img: np.ndarray, dets: np.ndarray) -> List[np.ndarray]:
         """Extract embeddings for detected objects."""
-        feats = self.model([save_one_box(det, img, save=False) for det in xywh2xyxy(torch.from_numpy(dets[:, :4]))])
+        feats = self.model.predictor(
+            [save_one_box(det, img, save=False) for det in xywh2xyxy(torch.from_numpy(dets[:, :4]))]
+        )
         if len(feats) != dets.shape[0] and feats[0].shape[0] == dets.shape[0]:
             feats = feats[0]  # batched prediction with non-PyTorch backend
         return [f.cpu().numpy() for f in feats]
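From the caller's point of view the embedding helper is unchanged; only the internal call path moves from model.__call__ to the already-initialized model.predictor, and the warm-up call no longer writes files. A minimal sketch of how the helper might be exercised (the checkpoint name, dummy frame, and detection values below are illustrative assumptions, not part of the diff):

    import numpy as np
    from ultralytics.trackers.bot_sort import ReID

    reid = ReID("yolo11n.pt")  # embedding layer chosen in __init__ as shown above
    frame = np.zeros((640, 640, 3), dtype=np.uint8)  # dummy BGR frame
    dets = np.array([[320.0, 240.0, 80.0, 120.0]], dtype=np.float32)  # one xywh detection row
    feats = reid(frame, dets)  # list with one embedding array per detection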
ultralytics/utils/__init__.py
@@ -841,8 +841,7 @@ def is_docker() -> bool:
         (bool): True if the script is running inside a Docker container, False otherwise.
     """
     try:
-        with open("/proc/self/cgroup") as f:
-            return "docker" in f.read()
+        return os.path.exists("/.dockerenv")
     except Exception:
         return False
 
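The Docker check now keys off the /.dockerenv marker file instead of parsing /proc/self/cgroup, which no longer mentions "docker" on cgroup-v2 hosts. A standalone sketch of the same idea combining both heuristics (the helper name is hypothetical, not the library's API):

    import os

    def looks_like_docker() -> bool:
        # Prefer the marker file; fall back to the legacy cgroup probe.
        if os.path.exists("/.dockerenv"):
            return True
        try:
            with open("/proc/self/cgroup") as f:
                return "docker" in f.read()
        except OSError:
            return False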
ultralytics/utils/benchmarks.py
@@ -59,6 +59,7 @@ def benchmark(
     verbose=False,
     eps=1e-3,
     format="",
+    **kwargs,
 ):
     """
     Benchmark a YOLO model across different formats for speed and accuracy.
@@ -73,6 +74,7 @@ benchmark(
         verbose (bool | float): If True or a float, assert benchmarks pass with given metric.
         eps (float): Epsilon value for divide by zero prevention.
         format (str): Export format for benchmarking. If not supplied all formats are benchmarked.
+        **kwargs (Any): Additional keyword arguments for exporter.
 
     Returns:
         (pandas.DataFrame): A pandas DataFrame with benchmark results for each format, including file size, metric,
@@ -104,41 +106,41 @@ def benchmark(
     if format_arg:
         formats = frozenset(export_formats()["Argument"])
         assert format in formats, f"Expected format to be one of {formats}, but got '{format_arg}'."
-    for i, (name, format, suffix, cpu, gpu, _) in enumerate(zip(*export_formats().values())):
+    for name, format, suffix, cpu, gpu, _ in zip(*export_formats().values()):
         emoji, filename = "❌", None  # export defaults
         try:
             if format_arg and format_arg != format:
                 continue
 
             # Checks
-            if i == 7:  # TF GraphDef
+            if format == "pb":
                 assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task"
-            elif i == 9:  # Edge TPU
+            elif format == "edgetpu":
                 assert LINUX and not ARM64, "Edge TPU export only supported on non-aarch64 Linux"
-            elif i in {5, 10}:  # CoreML and TF.js
+            elif format in {"coreml", "tfjs"}:
                 assert MACOS or (LINUX and not ARM64), (
                     "CoreML and TF.js export only supported on macOS and non-aarch64 Linux"
                 )
-            if i in {5}:  # CoreML
+            if format == "coreml":
                 assert not IS_PYTHON_3_13, "CoreML not supported on Python 3.13"
-            if i in {6, 7, 8, 9, 10}:  # TF SavedModel, TF GraphDef, and TFLite, TF EdgeTPU and TF.js
+            if format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:
                 assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
                 # assert not IS_PYTHON_MINIMUM_3_12, "TFLite exports not supported on Python>=3.12 yet"
-            if i == 11:  # Paddle
+            if format == "paddle":
                 assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
                 assert model.task != "obb", "Paddle OBB bug https://github.com/PaddlePaddle/Paddle/issues/72024"
                 assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet"
                 assert (LINUX and not IS_JETSON) or MACOS, "Windows and Jetson Paddle exports not supported yet"
-            if i == 12:  # MNN
+            if format == "mnn":
                 assert not isinstance(model, YOLOWorld), "YOLOWorldv2 MNN exports not supported yet"
-            if i == 13:  # NCNN
+            if format == "ncnn":
                 assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
-            if i == 14:  # IMX
+            if format == "imx":
                 assert not is_end2end
                 assert not isinstance(model, YOLOWorld), "YOLOWorldv2 IMX exports not supported"
                 assert model.task == "detect", "IMX only supported for detection task"
                 assert "C2f" in model.__str__(), "IMX only supported for YOLOv8"  # TODO: enable for YOLO11
-            if i == 15:  # RKNN
+            if format == "rknn":
                 assert not isinstance(model, YOLOWorld), "YOLOWorldv2 RKNN exports not supported yet"
                 assert not is_end2end, "End-to-end models not supported by RKNN yet"
                 assert LINUX, "RKNN only supported on Linux"
@@ -154,17 +156,17 @@ def benchmark(
                 exported_model = model  # PyTorch format
             else:
                 filename = model.export(
-                    imgsz=imgsz, format=format, half=half, int8=int8, data=data, device=device, verbose=False
+                    imgsz=imgsz, format=format, half=half, int8=int8, data=data, device=device, verbose=False, **kwargs
                 )
                 exported_model = YOLO(filename, task=model.task)
                 assert suffix in str(filename), "export failed"
             emoji = "❎"  # indicates export succeeded
 
             # Predict
-            assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported"
-            assert i not in {9, 10}, "inference not supported"  # Edge TPU and TF.js are unsupported
-            assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13"  # CoreML
-            if i in {13}:
+            assert model.task != "pose" or format != "pb", "GraphDef Pose inference is not supported"
+            assert format not in {"edgetpu", "tfjs"}, "inference not supported"
+            assert format != "coreml" or platform.system() == "Darwin", "inference only supported on macOS>=10.13"
+            if format == "ncnn":
                 assert not is_end2end, "End-to-end torch.topk operation is not supported for NCNN prediction yet"
             exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half, verbose=False)
 
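Together these hunks let callers forward exporter options straight through benchmark(), and the per-format guards are now keyed on the format string rather than positional indices. A hedged sketch of a call exercising the pass-through (the model path and exporter arguments are illustrative, not taken from the diff):

    from ultralytics.utils.benchmarks import benchmark

    # `dynamic` and `nms` are forwarded to model.export() via **kwargs; only the ONNX row is run.
    results_df = benchmark(model="yolo11n.pt", format="onnx", imgsz=320, dynamic=True, nms=True)
    print(results_df)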
ultralytics/utils/checks.py
@@ -401,11 +401,16 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
     def attempt_install(packages, commands, use_uv):
         """Attempt package installation with uv if available, falling back to pip."""
         if use_uv:
-            # Note requires --break-system-packages on ARM64 dockerfile
-            cmd = f"uv pip install --system --no-cache-dir {packages} {commands} --index-strategy=unsafe-best-match --break-system-packages --prerelease=allow"
-        else:
-            cmd = f"pip install --no-cache-dir {packages} {commands}"
-        return subprocess.check_output(cmd, shell=True).decode()
+            base = f"uv pip install --no-cache-dir {packages} {commands} --index-strategy=unsafe-best-match --break-system-packages --prerelease=allow"
+            try:
+                return subprocess.check_output(base, shell=True, stderr=subprocess.PIPE).decode()
+            except subprocess.CalledProcessError as e:
+                if e.stderr and "No virtual environment found" in e.stderr.decode():
+                    return subprocess.check_output(
+                        base.replace("uv pip install", "uv pip install --system"), shell=True
+                    ).decode()
+                raise
+        return subprocess.check_output(f"pip install --no-cache-dir {packages} {commands}", shell=True).decode()
 
     s = " ".join(f'"{x}"' for x in pkgs)  # console string
     if s:
ultralytics/utils/downloads.py
@@ -36,6 +36,7 @@ GITHUB_ASSETS_NAMES = frozenset(
     + [
         "mobile_sam.pt",
         "mobileclip_blt.ts",
+        "yolo11n-grayscale.pt",
         "calibration_image_sample_data_20x128x128x3_float32.npy.zip",
     ]
 )
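Listing the name in GITHUB_ASSETS_NAMES means the new grayscale checkpoint resolves through the standard asset downloader like any other release file. A hedged sketch (assuming the asset is actually published in the ultralytics/assets release the downloader targets):

    from ultralytics.utils.downloads import attempt_download_asset

    weights = attempt_download_asset("yolo11n-grayscale.pt")  # downloads only if not already present locally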
ultralytics/utils/metrics.py
@@ -520,7 +520,7 @@ def plot_pr_curve(
     py: np.ndarray,
     ap: np.ndarray,
     save_dir: Path = Path("pr_curve.png"),
-    names: dict = {},
+    names: Dict[int, str] = {},
     on_plot=None,
 ):
     """
@@ -531,7 +531,7 @@
         py (np.ndarray): Y values for the PR curve.
         ap (np.ndarray): Average precision values.
         save_dir (Path, optional): Path to save the plot.
-        names (dict, optional): Dictionary mapping class indices to class names.
+        names (Dict[int, str], optional): Dictionary mapping class indices to class names.
         on_plot (callable, optional): Function to call after plot is saved.
     """
     import matplotlib.pyplot as plt  # scope for faster 'import ultralytics'
@@ -563,7 +563,7 @@ def plot_mc_curve(
     px: np.ndarray,
     py: np.ndarray,
     save_dir: Path = Path("mc_curve.png"),
-    names: dict = {},
+    names: Dict[int, str] = {},
     xlabel: str = "Confidence",
     ylabel: str = "Metric",
     on_plot=None,
@@ -575,7 +575,7 @@
         px (np.ndarray): X values for the metric-confidence curve.
         py (np.ndarray): Y values for the metric-confidence curve.
         save_dir (Path, optional): Path to save the plot.
-        names (dict, optional): Dictionary mapping class indices to class names.
+        names (Dict[int, str], optional): Dictionary mapping class indices to class names.
         xlabel (str, optional): X-axis label.
         ylabel (str, optional): Y-axis label.
         on_plot (callable, optional): Function to call after plot is saved.
@@ -645,7 +645,7 @@ def ap_per_class(
     plot: bool = False,
     on_plot=None,
     save_dir: Path = Path(),
-    names: dict = {},
+    names: Dict[int, str] = {},
     eps: float = 1e-16,
     prefix: str = "",
 ) -> Tuple:
@@ -660,7 +660,7 @@
         plot (bool, optional): Whether to plot PR curves or not.
         on_plot (callable, optional): A callback to pass plots path and data when they are rendered.
         save_dir (Path, optional): Directory to save the PR curves.
-        names (dict, optional): Dict of class names to plot PR curves.
+        names (Dict[int, str], optional): Dictionary of class names to plot PR curves.
         eps (float, optional): A small value to avoid division by zero.
         prefix (str, optional): A prefix string for saving the plot files.
 
@@ -720,8 +720,7 @@ def ap_per_class(
 
     # Compute F1 (harmonic mean of precision and recall)
     f1_curve = 2 * p_curve * r_curve / (p_curve + r_curve + eps)
-    names = [v for k, v in names.items() if k in unique_classes]  # list: only classes that have data
-    names = dict(enumerate(names))  # to dict
+    names = {i: names[k] for i, k in enumerate(unique_classes) if k in names}  # dict: only classes that have data
    if plot:
        plot_pr_curve(x, prec_values, ap, save_dir / f"{prefix}PR_curve.png", names, on_plot=on_plot)
        plot_mc_curve(x, f1_curve, save_dir / f"{prefix}F1_curve.png", names, ylabel="F1", on_plot=on_plot)
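The two-step list-then-dict rebuild collapses into a single comprehension that keeps only classes present in the data and re-indexes them from zero so they line up with the rows of the per-class curves. A small worked illustration with made-up values:

    names = {0: "person", 2: "car", 5: "bus"}
    unique_classes = [2, 5]  # classes that actually appear in the evaluated labels
    names = {i: names[k] for i, k in enumerate(unique_classes) if k in names}
    # -> {0: "car", 1: "bus"}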
@@ -915,20 +914,20 @@ class DetMetrics(SimpleClass, DataExportMixin):
     Attributes:
         save_dir (Path): A path to the directory where the output plots will be saved.
         plot (bool): A flag that indicates whether to plot precision-recall curves for each class.
-        names (dict): A dictionary of class names.
+        names (Dict[int, str]): A dictionary of class names.
         box (Metric): An instance of the Metric class for storing detection results.
-        speed (dict): A dictionary for storing execution times of different parts of the detection process.
+        speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'detect'.
     """
 
-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: dict = {}) -> None:
+    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
         """
         Initialize a DetMetrics instance with a save directory, plot flag, and class names.
 
         Args:
             save_dir (Path, optional): Directory to save plots.
             plot (bool, optional): Whether to plot precision-recall curves.
-            names (dict, optional): Dictionary mapping class indices to names.
+            names (Dict[int, str], optional): Dictionary of class names.
         """
         self.save_dir = save_dir
         self.plot = plot
@@ -1033,21 +1032,21 @@ class SegmentMetrics(SimpleClass, DataExportMixin):
     Attributes:
         save_dir (Path): Path to the directory where the output plots should be saved.
         plot (bool): Whether to save the detection and segmentation plots.
-        names (dict): Dictionary of class names.
-        box (Metric): An instance of the Metric class to calculate box detection metrics.
+        names (Dict[int, str]): Dictionary of class names.
+        box (Metric): An instance of the Metric class for storing detection results.
         seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
-        speed (dict): Dictionary to store the time taken in different phases of inference.
+        speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'segment'.
     """
 
-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: tuple = ()) -> None:
+    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
         """
         Initialize a SegmentMetrics instance with a save directory, plot flag, and class names.
 
         Args:
             save_dir (Path, optional): Directory to save plots.
             plot (bool, optional): Whether to plot precision-recall curves.
-            names (tuple, optional): Tuple mapping class indices to names.
+            names (Dict[int, str], optional): Dictionary of class names.
         """
         self.save_dir = save_dir
         self.plot = plot
@@ -1196,10 +1195,10 @@ class PoseMetrics(SegmentMetrics):
     Attributes:
         save_dir (Path): Path to the directory where the output plots should be saved.
         plot (bool): Whether to save the detection and pose plots.
-        names (dict): Dictionary of class names.
-        box (Metric): An instance of the Metric class to calculate box detection metrics.
+        names (Dict[int, str]): Dictionary of class names.
         pose (Metric): An instance of the Metric class to calculate pose metrics.
-        speed (dict): Dictionary to store the time taken in different phases of inference.
+        box (Metric): An instance of the Metric class for storing detection results.
+        speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'pose'.
 
     Methods:
@@ -1212,14 +1211,14 @@
         results_dict: Return the dictionary containing all the detection and segmentation metrics and fitness score.
     """
 
-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: tuple = ()) -> None:
+    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
         """
         Initialize the PoseMetrics class with directory path, class names, and plotting options.
 
         Args:
             save_dir (Path, optional): Directory to save plots.
             plot (bool, optional): Whether to plot precision-recall curves.
-            names (tuple, optional): Tuple mapping class indices to names.
+            names (Dict[int, str], optional): Dictionary of class names.
         """
         super().__init__(save_dir, plot, names)
         self.save_dir = save_dir
@@ -1420,23 +1419,23 @@ class OBBMetrics(SimpleClass, DataExportMixin):
     Attributes:
         save_dir (Path): Path to the directory where the output plots should be saved.
         plot (bool): Whether to save the detection plots.
-        names (dict): Dictionary of class names.
+        names (Dict[int, str]): Dictionary of class names.
         box (Metric): An instance of the Metric class for storing detection results.
-        speed (dict): A dictionary for storing execution times of different parts of the detection process.
+        speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
         task (str): The task type, set to 'obb'.
 
     References:
         https://arxiv.org/pdf/2106.06072.pdf
     """
 
-    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: tuple = ()) -> None:
+    def __init__(self, save_dir: Path = Path("."), plot: bool = False, names: Dict[int, str] = {}) -> None:
         """
         Initialize an OBBMetrics instance with directory, plotting, and class names.
 
         Args:
             save_dir (Path, optional): Directory to save plots.
             plot (bool, optional): Whether to plot precision-recall curves.
-            names (tuple, optional): Tuple mapping class indices to names.
+            names (Dict[int, str], optional): Dictionary of class names.
         """
         self.save_dir = save_dir
         self.plot = plot
ultralytics/utils/plotting.py
@@ -201,6 +201,11 @@ class Annotator:
         input_is_pil = isinstance(im, Image.Image)
         self.pil = pil or non_ascii or input_is_pil
         self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2)
+        if not input_is_pil:
+            if im.shape[2] == 1:  # handle grayscale
+                im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
+            elif im.shape[2] > 3:  # multispectral
+                im = np.ascontiguousarray(im[..., :3])
         if self.pil:  # use PIL
             self.im = im if input_is_pil else Image.fromarray(im)
             if self.im.mode not in {"RGB", "RGBA"}:  # multispectral
@@ -216,10 +221,6 @@
             if check_version(pil_version, "9.2.0"):
                 self.font.getsize = lambda x: self.font.getbbox(x)[2:4]  # text width, height
         else:  # use cv2
-            if im.shape[2] == 1:  # handle grayscale
-                im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
-            elif im.shape[2] > 3:  # multispectral
-                im = np.ascontiguousarray(im[..., :3])
             assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images."
             self.im = im if im.flags.writeable else im.copy()
             self.tf = max(self.lw - 1, 1)  # font thickness
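Moving the grayscale/multispectral conversion ahead of the PIL-vs-cv2 branch means single-channel arrays now work with either backend instead of only the cv2 path. A minimal sketch with a synthetic single-channel frame (shapes and the label are made up for illustration):

    import numpy as np
    from ultralytics.utils.plotting import Annotator

    gray = np.zeros((480, 640, 1), dtype=np.uint8)  # (H, W, 1) grayscale input
    annotator = Annotator(gray, line_width=2)        # converted to 3 channels before either backend
    annotator.box_label([100, 100, 220, 200], "object")
    annotated = annotator.result()                   # 3-channel numpy array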
@@ -644,7 +645,7 @@ def save_one_box(
         gain (float, optional): A multiplicative factor to increase the size of the bounding box.
         pad (int, optional): The number of pixels to add to the width and height of the bounding box.
         square (bool, optional): If True, the bounding box will be transformed into a square.
-        BGR (bool, optional): If True, the image will be saved in BGR format, otherwise in RGB.
+        BGR (bool, optional): If True, the image will be returned in BGR format, otherwise in RGB.
         save (bool, optional): If True, the cropped image will be saved to disk.
 
     Returns:
@@ -664,12 +665,14 @@
     b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
     xyxy = ops.xywh2xyxy(b).long()
     xyxy = ops.clip_boxes(xyxy, im.shape)
-    crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR else -1)]
+    grayscale = im.shape[2] == 1  # grayscale image
+    crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR or grayscale else -1)]
     if save:
         file.parent.mkdir(parents=True, exist_ok=True)  # make directory
         f = str(increment_path(file).with_suffix(".jpg"))
         # cv2.imwrite(f, crop)  # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
-        Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0)  # save RGB
+        crop = crop.squeeze(-1) if grayscale else crop[..., ::-1] if BGR else crop
+        Image.fromarray(crop).save(f, quality=95, subsampling=0)  # save RGB
     return crop
 
 
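With this change, single-channel crops skip the BGR/RGB channel flip and drop the channel axis only when written to disk. A hedged sketch cropping from a grayscale array (the array and box values are illustrative):

    import numpy as np
    import torch
    from ultralytics.utils.plotting import save_one_box

    gray = np.random.randint(0, 255, (480, 640, 1), dtype=np.uint8)
    crop = save_one_box(torch.tensor([[100, 100, 200, 220]]), gray, save=False)
    print(crop.shape)  # (H, W, 1): channel order left untouched for grayscale input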
ultralytics/utils/torch_utils.py
@@ -10,7 +10,7 @@ from contextlib import contextmanager
 from copy import deepcopy
 from datetime import datetime
 from pathlib import Path
-from typing import Union
+from typing import Any, Dict, Union
 
 import numpy as np
 import torch
@@ -704,7 +704,7 @@ class ModelEMA:
         copy_attr(self.ema, model, include, exclude)
 
 
-def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict:
+def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: Dict[str, Any] = None) -> Dict[str, Any]:
     """
     Strip optimizer from 'f' to finalize training, optionally save as 's'.