ultralytics 8.3.221__py3-none-any.whl → 8.3.223__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tests/test_python.py CHANGED
@@ -136,23 +136,23 @@ def test_predict_visualize(model):
     YOLO(WEIGHTS_DIR / model)(SOURCE, imgsz=32, visualize=True)
 
 
-def test_predict_grey_and_4ch(tmp_path):
-    """Test YOLO prediction on SOURCE converted to greyscale and 4-channel images with various filenames."""
+def test_predict_gray_and_4ch(tmp_path):
+    """Test YOLO prediction on SOURCE converted to grayscale and 4-channel images with various filenames."""
     im = Image.open(SOURCE)
 
-    source_greyscale = tmp_path / "greyscale.jpg"
+    source_grayscale = tmp_path / "grayscale.jpg"
     source_rgba = tmp_path / "4ch.png"
     source_non_utf = tmp_path / "non_UTF_测试文件_tést_image.jpg"
     source_spaces = tmp_path / "image with spaces.jpg"
 
-    im.convert("L").save(source_greyscale)  # greyscale
+    im.convert("L").save(source_grayscale)  # grayscale
     im.convert("RGBA").save(source_rgba)  # 4-ch PNG with alpha
     im.save(source_non_utf)  # non-UTF characters in filename
     im.save(source_spaces)  # spaces in filename
 
     # Inference
     model = YOLO(MODEL)
-    for f in source_rgba, source_greyscale, source_non_utf, source_spaces:
+    for f in source_rgba, source_grayscale, source_non_utf, source_spaces:
         for source in Image.open(f), cv2.imread(str(f)), f:
             results = model(source, save=True, verbose=True, imgsz=32)
             assert len(results) == 1  # verify that an image was run
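
Note: the renamed test doubles as a usage reference. A minimal standalone sketch of the same flow (assuming a local `bus.jpg` and `yolo11n.pt` weights; any detect model and image would do):

    from PIL import Image
    from ultralytics import YOLO

    model = YOLO("yolo11n.pt")
    gray = Image.open("bus.jpg").convert("L")  # single-channel grayscale
    rgba = Image.open("bus.jpg").convert("RGBA")  # 4-channel with alpha
    for source in (gray, rgba):
        results = model(source, imgsz=32, verbose=False)  # channels are normalized internally
        assert len(results) == 1  # one Results object per image
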
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
-__version__ = "8.3.221"
+__version__ = "8.3.223"
 
 import importlib
 import os
ultralytics/cfg/datasets/ImageNet.yaml CHANGED
@@ -342,7 +342,7 @@ names:
   322: ringlet
   323: monarch butterfly
   324: small white
-  325: sulphur butterfly
+  325: sulfur butterfly
   326: gossamer-winged butterfly
   327: starfish
   328: sea urchin
ultralytics/cfg/datasets/lvis.yaml CHANGED
@@ -35,7 +35,7 @@ names:
   17: armband
   18: armchair
   19: armoire
-  20: armor/armour
+  20: armor
   21: artichoke
   22: trash can/garbage can/wastebin/dustbin/trash barrel/trash bin
   23: ashtray
@@ -245,7 +245,7 @@ names:
   227: CD player
   228: celery
   229: cellular telephone/cellular phone/cellphone/mobile phone/smart phone
-  230: chain mail/ring mail/chain armor/chain armour/ring armor/ring armour
+  230: chain mail/ring mail/chain armor/ring armor
   231: chair
   232: chaise longue/chaise/daybed
   233: chalice
@@ -305,7 +305,7 @@ names:
   287: coin
   288: colander/cullender
   289: coleslaw/slaw
-  290: coloring material/colouring material
+  290: coloring material
   291: combination lock
   292: pacifier/teething ring
   293: comic book
@@ -401,7 +401,7 @@ names:
   383: domestic ass/donkey
   384: doorknob/doorhandle
   385: doormat/welcome mat
-  386: doughnut/donut
+  386: donut
   387: dove
   388: dragonfly
   389: drawer
@@ -1072,7 +1072,7 @@ names:
   1054: tag
   1055: taillight/rear light
   1056: tambourine
-  1057: army tank/armored combat vehicle/armoured combat vehicle
+  1057: army tank/armored combat vehicle
   1058: tank/tank storage vessel/storage tank
   1059: tank top/tank top clothing
   1060: tape/tape sticky cloth or paper
ultralytics/cfg/datasets/open-images-v7.yaml CHANGED
@@ -182,7 +182,7 @@ names:
   163: Dolphin
   164: Door
   165: Door handle
-  166: Doughnut
+  166: Donut
   167: Dragonfly
   168: Drawer
   169: Dress
ultralytics/data/base.py CHANGED
@@ -307,7 +307,7 @@ class BaseDataset(Dataset):
             b += im.nbytes
             if not os.access(Path(im_file).parent, os.W_OK):
                 self.cache = None
-                LOGGER.warning(f"{self.prefix}Skipping caching images to disk, directory not writeable")
+                LOGGER.warning(f"{self.prefix}Skipping caching images to disk, directory not writable")
                 return False
         disk_required = b * self.ni / n * (1 + safety_margin)  # bytes required to cache dataset to disk
         total, _used, free = shutil.disk_usage(Path(self.im_files[0]).parent)
ultralytics/data/utils.py CHANGED
@@ -804,4 +804,4 @@ def save_dataset_cache_file(prefix: str, path: Path, x: dict, version: str):
         np.save(file, x)
         LOGGER.info(f"{prefix}New cache created: {path}")
     else:
-        LOGGER.warning(f"{prefix}Cache directory {path.parent} is not writeable, cache not saved.")
+        LOGGER.warning(f"{prefix}Cache directory {path.parent} is not writable, cache not saved.")
ultralytics/engine/exporter.py CHANGED
@@ -107,9 +107,17 @@ from ultralytics.utils.checks import (
     is_intel,
     is_sudo_available,
 )
-from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
-from ultralytics.utils.export import onnx2engine, torch2imx, torch2onnx
-from ultralytics.utils.files import file_size, spaces_in_path
+from ultralytics.utils.downloads import get_github_assets, safe_download
+from ultralytics.utils.export import (
+    keras2pb,
+    onnx2engine,
+    onnx2saved_model,
+    pb2tfjs,
+    tflite2edgetpu,
+    torch2imx,
+    torch2onnx,
+)
+from ultralytics.utils.files import file_size
 from ultralytics.utils.metrics import batch_probiou
 from ultralytics.utils.nms import TorchNMS
 from ultralytics.utils.ops import Profile
@@ -206,15 +214,6 @@ def validate_args(format, passed_args, valid_args):
         assert arg in valid_args, f"ERROR ❌️ argument '{arg}' is not supported for format='{format}'"
 
 
-def gd_outputs(gd):
-    """Return TensorFlow GraphDef model output node names."""
-    name_list, input_list = [], []
-    for node in gd.node:  # tensorflow.core.framework.node_def_pb2.NodeDef
-        name_list.append(node.name)
-        input_list.extend(node.input)
-    return sorted(f"{x}:0" for x in list(set(name_list) - set(input_list)) if not x.startswith("NoOp"))
-
-
 def try_export(inner_func):
     """YOLO export decorator, i.e. @try_export."""
     inner_args = get_default_args(inner_func)
@@ -367,11 +366,11 @@ class Exporter:
             if not self.args.int8:
                 LOGGER.warning("IMX export requires int8=True, setting int8=True.")
                 self.args.int8 = True
-            if not self.args.nms:
+            if not self.args.nms and model.task in {"detect", "pose"}:
                 LOGGER.warning("IMX export requires nms=True, setting nms=True.")
                 self.args.nms = True
-            if model.task not in {"detect", "pose"}:
-                raise ValueError("IMX export only supported for detection and pose estimation models.")
+            if model.task not in {"detect", "pose", "classify"}:
+                raise ValueError("IMX export only supported for detection, pose estimation, and classification models.")
         if not hasattr(model, "names"):
             model.names = default_class_names()
         model.names = check_class_names(model.names)
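
Reviewer note: with this hunk, classification joins detect/pose behind the IMX gate, int8 remains forced on, and nms is now only forced for detect/pose. A hedged usage sketch (assumes `yolo11n-cls.pt` weights and the Sony IMX toolchain installed):

    from ultralytics import YOLO

    model = YOLO("yolo11n-cls.pt")  # classify task, newly accepted by this check
    model.export(format="imx")  # int8=True is applied automatically; unsupported tasks raise ValueError
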
@@ -396,8 +395,6 @@ class Exporter:
             assert self.args.name in RKNN_CHIPS, (
                 f"Invalid processor name '{self.args.name}' for Rockchip RKNN export. Valid names are {RKNN_CHIPS}."
             )
-        if self.args.int8 and tflite:
-            assert not getattr(model, "end2end", False), "TFLite INT8 export not supported for end2end models."
         if self.args.nms:
             assert not isinstance(model, ClassificationModel), "'nms=True' is not valid for classification models."
             assert not tflite or not ARM64 or not LINUX, "TFLite export with NMS unsupported on ARM64 Linux"
@@ -463,6 +460,10 @@ class Exporter:
             from ultralytics.utils.export.imx import FXModel
 
             model = FXModel(model, self.imgsz)
+        if tflite or edgetpu:
+            from ultralytics.utils.export.tensorflow import tf_wrapper
+
+            model = tf_wrapper(model)
         for m in model.modules():
             if isinstance(m, Classify):
                 m.export = True
@@ -644,7 +645,7 @@ class Exporter:
             assert TORCH_1_13, f"'nms=True' ONNX export requires torch>=1.13 (found torch=={TORCH_VERSION})"
 
         f = str(self.file.with_suffix(".onnx"))
-        output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
+        output_names = ["output0", "output1"] if self.model.task == "segment" else ["output0"]
         dynamic = self.args.dynamic
         if dynamic:
             dynamic = {"images": {0: "batch", 2: "height", 3: "width"}}  # shape(1,3,640,640)
@@ -1055,75 +1056,43 @@ class Exporter:
         if f.is_dir():
             shutil.rmtree(f)  # delete output folder
 
-        # Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545
-        onnx2tf_file = Path("calibration_image_sample_data_20x128x128x3_float32.npy")
-        if not onnx2tf_file.exists():
-            attempt_download_asset(f"{onnx2tf_file}.zip", unzip=True, delete=True)
+        # Export to TF
+        images = None
+        if self.args.int8 and self.args.data:
+            images = [batch["img"] for batch in self.get_int8_calibration_dataloader(prefix)]
+            images = (
+                torch.nn.functional.interpolate(torch.cat(images, 0).float(), size=self.imgsz)
+                .permute(0, 2, 3, 1)
+                .numpy()
+                .astype(np.float32)
+            )
 
         # Export to ONNX
         if isinstance(self.model.model[-1], RTDETRDecoder):
             self.args.opset = self.args.opset or 19
             assert 16 <= self.args.opset <= 19, "RTDETR export requires opset>=16;<=19"
             self.args.simplify = True
-        f_onnx = self.export_onnx()
-
-        # Export to TF
-        np_data = None
-        if self.args.int8:
-            tmp_file = f / "tmp_tflite_int8_calibration_images.npy"  # int8 calibration images file
-            if self.args.data:
-                f.mkdir()
-                images = [batch["img"] for batch in self.get_int8_calibration_dataloader(prefix)]
-                images = torch.nn.functional.interpolate(torch.cat(images, 0).float(), size=self.imgsz).permute(
-                    0, 2, 3, 1
-                )
-                np.save(str(tmp_file), images.numpy().astype(np.float32))  # BHWC
-                np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]]
-
-        import onnx2tf  # scoped for after ONNX export for reduced conflict during import
-
-        LOGGER.info(f"{prefix} starting TFLite export with onnx2tf {onnx2tf.__version__}...")
-        keras_model = onnx2tf.convert(
-            input_onnx_file_path=f_onnx,
-            output_folder_path=str(f),
-            not_use_onnxsim=True,
-            verbosity="error",  # note INT8-FP16 activation bug https://github.com/ultralytics/ultralytics/issues/15873
-            output_integer_quantized_tflite=self.args.int8,
-            custom_input_op_name_np_data_path=np_data,
-            enable_batchmatmul_unfold=True and not self.args.int8,  # fix lower no. of detected objects on GPU delegate
-            output_signaturedefs=True,  # fix error with Attention block group convolution
-            disable_group_convolution=self.args.format in {"tfjs", "edgetpu"},  # fix error with group convolution
+        f_onnx = self.export_onnx()  # ensure ONNX is available
+        keras_model = onnx2saved_model(
+            f_onnx,
+            f,
+            int8=self.args.int8,
+            images=images,
+            disable_group_convolution=self.args.format in {"tfjs", "edgetpu"},
+            prefix=prefix,
         )
         YAML.save(f / "metadata.yaml", self.metadata)  # add metadata.yaml
-
-        # Remove/rename TFLite models
-        if self.args.int8:
-            tmp_file.unlink(missing_ok=True)
-            for file in f.rglob("*_dynamic_range_quant.tflite"):
-                file.rename(file.with_name(file.stem.replace("_dynamic_range_quant", "_int8") + file.suffix))
-            for file in f.rglob("*_integer_quant_with_int16_act.tflite"):
-                file.unlink()  # delete extra fp16 activation TFLite files
-
         # Add TFLite metadata
         for file in f.rglob("*.tflite"):
-            f.unlink() if "quant_with_int16_act.tflite" in str(f) else self._add_tflite_metadata(file)
+            file.unlink() if "quant_with_int16_act.tflite" in str(file) else self._add_tflite_metadata(file)
 
         return str(f), keras_model  # or keras_model = tf.saved_model.load(f, tags=None, options=None)
 
     @try_export
     def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
         """Export YOLO model to TensorFlow GraphDef *.pb format https://github.com/leimao/Frozen-Graph-TensorFlow."""
-        import tensorflow as tf
-        from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
-
-        LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
         f = self.file.with_suffix(".pb")
-
-        m = tf.function(lambda x: keras_model(x))  # full model
-        m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
-        frozen_func = convert_variables_to_constants_v2(m)
-        frozen_func.graph.as_graph_def()
-        tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
+        keras2pb(keras_model, f, prefix)
         return f
 
     @try_export
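
The onnx2tf invocation and the frozen-graph boilerplate now live behind `ultralytics.utils.export` helpers. A sketch of the new surface, with keyword names inferred from the call sites in this diff rather than from separate documentation:

    from pathlib import Path
    from ultralytics.utils.export import keras2pb, onnx2saved_model

    out_dir = Path("yolo11n_saved_model")  # hypothetical paths
    keras_model = onnx2saved_model(
        "yolo11n.onnx",  # ONNX file produced by export_onnx()
        out_dir,  # SavedModel output folder
        int8=False,  # True additionally needs `images`, a BHWC float32 calibration array
        images=None,
        disable_group_convolution=False,  # the exporter sets True for tfjs/edgetpu targets
        prefix="TensorFlow SavedModel:",
    )
    keras2pb(keras_model, Path("yolo11n.pb"), "TensorFlow GraphDef:")  # freeze to GraphDef
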
@@ -1191,22 +1160,11 @@ class Exporter:
                 "sudo apt-get install edgetpu-compiler",
             ):
                 subprocess.run(c if is_sudo_available() else c.replace("sudo ", ""), shell=True, check=True)
-        ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().rsplit(maxsplit=1)[-1]
 
+        ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().rsplit(maxsplit=1)[-1]
         LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...")
+        tflite2edgetpu(tflite_file=tflite_model, output_dir=tflite_model.parent, prefix=prefix)
         f = str(tflite_model).replace(".tflite", "_edgetpu.tflite")  # Edge TPU model
-
-        cmd = (
-            "edgetpu_compiler "
-            f'--out_dir "{Path(f).parent}" '
-            "--show_operations "
-            "--search_delegate "
-            "--delegate_search_step 30 "
-            "--timeout_sec 180 "
-            f'"{tflite_model}"'
-        )
-        LOGGER.info(f"{prefix} running '{cmd}'")
-        subprocess.run(cmd, shell=True)
         self._add_tflite_metadata(f)
         return f
 
@@ -1214,31 +1172,10 @@ class Exporter:
     def export_tfjs(self, prefix=colorstr("TensorFlow.js:")):
         """Export YOLO model to TensorFlow.js format."""
         check_requirements("tensorflowjs")
-        import tensorflow as tf
-        import tensorflowjs as tfjs
 
-        LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...")
         f = str(self.file).replace(self.file.suffix, "_web_model")  # js dir
         f_pb = str(self.file.with_suffix(".pb"))  # *.pb path
-
-        gd = tf.Graph().as_graph_def()  # TF GraphDef
-        with open(f_pb, "rb") as file:
-            gd.ParseFromString(file.read())
-        outputs = ",".join(gd_outputs(gd))
-        LOGGER.info(f"\n{prefix} output node names: {outputs}")
-
-        quantization = "--quantize_float16" if self.args.half else "--quantize_uint8" if self.args.int8 else ""
-        with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_:  # exporter can not handle spaces in path
-            cmd = (
-                "tensorflowjs_converter "
-                f'--input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" "{f_}"'
-            )
-            LOGGER.info(f"{prefix} running '{cmd}'")
-            subprocess.run(cmd, shell=True)
-
-        if " " in f:
-            LOGGER.warning(f"{prefix} your model may not work correctly with spaces in path '{f}'.")
-
+        pb2tfjs(pb_file=f_pb, output_dir=f, half=self.args.half, int8=self.args.int8, prefix=prefix)
         # Add metadata
         YAML.save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
         return f
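
`export_edgetpu` and `export_tfjs` shrink the same way: the `edgetpu_compiler` and `tensorflowjs_converter` command lines move into helpers. Mirroring the call sites above, under the same inferred-signature caveat:

    from pathlib import Path
    from ultralytics.utils.export import pb2tfjs, tflite2edgetpu

    pb2tfjs(pb_file="yolo11n.pb", output_dir="yolo11n_web_model", half=False, int8=False, prefix="TensorFlow.js:")
    tflite = Path("yolo11n_saved_model/yolo11n_float32.tflite")  # hypothetical TFLite file
    tflite2edgetpu(tflite_file=tflite, output_dir=tflite.parent, prefix="Edge TPU:")  # writes *_edgetpu.tflite
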
@@ -1510,17 +1447,16 @@ class NMSModel(torch.nn.Module):
             box, score, cls, extra = box[mask], score[mask], cls[mask], extra[mask]
             nmsbox = box.clone()
             # `8` is the minimum value experimented to get correct NMS results for obb
-            multiplier = 8 if self.obb else 1
+            multiplier = (8 if self.obb else 1) / max(len(self.model.names), 1)
             # Normalize boxes for NMS since large values for class offset causes issue with int8 quantization
             if self.args.format == "tflite":  # TFLite is already normalized
                 nmsbox *= multiplier
             else:
-                nmsbox = multiplier * nmsbox / torch.tensor(x.shape[2:], **kwargs).max()
-            if not self.args.agnostic_nms:  # class-specific NMS
+                nmsbox = multiplier * (nmsbox / torch.tensor(x.shape[2:], **kwargs).max())
+            if not self.args.agnostic_nms:  # class-wise NMS
                 end = 2 if self.obb else 4
                 # fully explicit expansion otherwise reshape error
-                # large max_wh causes issues when quantizing
-                cls_offset = cls.reshape(-1, 1).expand(nmsbox.shape[0], end)
+                cls_offset = cls.view(cls.shape[0], 1).expand(cls.shape[0], end)
                 offbox = nmsbox[:, :end] + cls_offset * multiplier
                 nmsbox = torch.cat((offbox, nmsbox[:, end:]), dim=-1)
             nms_fn = (
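
Dividing the multiplier by the class count keeps the class-offset trick (shifting boxes so class-wise NMS never mixes classes) inside a small numeric range, which quantizes cleanly to int8. A worked check with illustrative numbers, following the TFLite branch where boxes are already in [0, 1]:

    import torch

    nc = 80  # e.g. COCO
    multiplier = 1 / max(nc, 1)  # 0.0125 for a non-OBB model
    nmsbox = torch.tensor([[0.2, 0.3, 0.4, 0.5]]) * multiplier  # normalized box, scaled down
    cls = torch.tensor([7.0])  # winning class index
    offbox = nmsbox[:, :4] + cls.view(1, 1).expand(1, 4) * multiplier
    print(offbox.max().item())  # ~0.094; even class 79 offsets to ~0.99 instead of 79 with the old multiplier of 1
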
ultralytics/engine/model.py CHANGED
@@ -877,7 +877,7 @@ class Model(torch.nn.Module):
         >>> model = model._apply(lambda t: t.cuda())  # Move model to GPU
         """
         self._check_is_pytorch_model()
-        self = super()._apply(fn)  # noqa
+        self = super()._apply(fn)
         self.predictor = None  # reset predictor as device may have changed
         self.overrides["device"] = self.device  # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0'
         return self
ultralytics/engine/trainer.py CHANGED
@@ -727,7 +727,7 @@ class BaseTrainer:
 
     def label_loss_items(self, loss_items=None, prefix="train"):
         """
-        Return a loss dict with labelled training loss items tensor.
+        Return a loss dict with labeled training loss items tensor.
 
         Note:
             This is not needed for classification but necessary for segmentation & detection
ultralytics/models/rtdetr/val.py CHANGED
@@ -89,7 +89,7 @@ class RTDETRDataset(YOLODataset):
             transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
         else:
             # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scale_fill=True)])
-            transforms = Compose([])
+            transforms = Compose([lambda x: {**x, **{"ratio_pad": [x["ratio_pad"], [0, 0]]}}])
         transforms.append(
             Format(
                 bbox_format="xywh",
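
Review note: the previous `Compose([])` passed samples through without the `(ratio, pad)` structure that downstream scaling code expects, while the lambda now supplies a zero-pad entry alongside the existing ratio. A tiny illustration with hypothetical values:

    sample = {"ratio_pad": (0.5, 0.5)}  # hypothetical resize ratios, no letterbox applied
    fix = lambda x: {**x, **{"ratio_pad": [x["ratio_pad"], [0, 0]]}}
    print(fix(sample)["ratio_pad"])  # [(0.5, 0.5), [0, 0]] -> (ratio, pad) with zero padding
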
ultralytics/models/yolo/classify/train.py CHANGED
@@ -38,7 +38,7 @@ class ClassificationTrainer(BaseTrainer):
         preprocess_batch: Preprocess a batch of images and classes.
         progress_string: Return a formatted string showing training progress.
         get_validator: Return an instance of ClassificationValidator.
-        label_loss_items: Return a loss dict with labelled training loss items.
+        label_loss_items: Return a loss dict with labeled training loss items.
         final_eval: Evaluate trained model and save validation results.
         plot_training_samples: Plot training samples with their annotations.
 
@@ -178,7 +178,7 @@ class ClassificationTrainer(BaseTrainer):
 
     def label_loss_items(self, loss_items: torch.Tensor | None = None, prefix: str = "train"):
         """
-        Return a loss dict with labelled training loss items tensor.
+        Return a loss dict with labeled training loss items tensor.
 
         Args:
             loss_items (torch.Tensor, optional): Loss tensor items.
ultralytics/nn/autobackend.py CHANGED
@@ -428,7 +428,7 @@ class AutoBackend(nn.Module):
             LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
             import tensorflow as tf
 
-            from ultralytics.engine.exporter import gd_outputs
+            from ultralytics.utils.export.tensorflow import gd_outputs
 
             def wrap_frozen_graph(gd, inputs, outputs):
                 """Wrap frozen graphs for deployment."""
ultralytics/nn/modules/head.py CHANGED
@@ -166,22 +166,8 @@ class Detect(nn.Module):
             self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
             self.shape = shape
 
-        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
-            box = x_cat[:, : self.reg_max * 4]
-            cls = x_cat[:, self.reg_max * 4 :]
-        else:
-            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
-
-        if self.export and self.format in {"tflite", "edgetpu"}:
-            # Precompute normalization factor to increase numerical stability
-            # See https://github.com/ultralytics/ultralytics/issues/7371
-            grid_h = shape[2]
-            grid_w = shape[3]
-            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
-            norm = self.strides / (self.stride[0] * grid_size)
-            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
-        else:
-            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
+        box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
+        dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
         return torch.cat((dbox, cls.sigmoid()), 1)
 
     def bias_init(self):
@@ -391,20 +377,9 @@ class Pose(Detect):
         """Decode keypoints from predictions."""
         ndim = self.kpt_shape[1]
         if self.export:
-            if self.format in {
-                "tflite",
-                "edgetpu",
-            }:  # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
-                # Precompute normalization factor to increase numerical stability
-                y = kpts.view(bs, *self.kpt_shape, -1)
-                grid_h, grid_w = self.shape[2], self.shape[3]
-                grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)
-                norm = self.strides / (self.stride[0] * grid_size)
-                a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm
-            else:
-                # NCNN fix
-                y = kpts.view(bs, *self.kpt_shape, -1)
-                a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
+            # NCNN fix
+            y = kpts.view(bs, *self.kpt_shape, -1)
+            a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
             if ndim == 3:
                 a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
             return a.view(bs, self.nk, -1)
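
With the TFLite/Edge TPU grid normalization removed from `Detect` and `Pose` (that responsibility now sits with the `tf_wrapper` applied in the exporter hunk above), every backend decodes through the single anchors-and-strides path. A minimal numeric check of that path, using `dist2bbox` from `ultralytics.utils.tal`:

    import torch
    from ultralytics.utils.tal import dist2bbox

    anchor = torch.tensor([[[4.5], [4.5]]])  # one anchor point (x, y), shape (1, 2, 1)
    ltrb = torch.tensor([[[1.0], [1.0], [2.0], [2.0]]])  # decoded DFL distances, shape (1, 4, 1)
    stride = torch.tensor([8.0])
    dbox = dist2bbox(ltrb, anchor, xywh=True, dim=1) * stride
    print(dbox.squeeze())  # tensor([40., 40., 24., 24.]) -> xywh in input-image pixels
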
ultralytics/utils/__init__.py CHANGED
@@ -795,13 +795,13 @@ def is_pip_package(filepath: str = __name__) -> bool:
 
 def is_dir_writeable(dir_path: str | Path) -> bool:
     """
-    Check if a directory is writeable.
+    Check if a directory is writable.
 
     Args:
         dir_path (str | Path): The path to the directory.
 
     Returns:
-        (bool): True if the directory is writeable, False otherwise.
+        (bool): True if the directory is writable, False otherwise.
     """
     return os.access(str(dir_path), os.W_OK)
 
@@ -882,14 +882,14 @@ def get_user_config_dir(sub_dir="Ultralytics"):
         p.mkdir(parents=True, exist_ok=True)
         return p
 
-    # Fallbacks for Docker, GCP/AWS functions where only /tmp is writeable
+    # Fallbacks for Docker, GCP/AWS functions where only /tmp is writable
     for alt in [Path("/tmp") / sub_dir, Path.cwd() / sub_dir]:
         if alt.exists():
             return alt
         if is_dir_writeable(alt.parent):
             alt.mkdir(parents=True, exist_ok=True)
             LOGGER.warning(
-                f"user config directory '{p}' is not writeable, using '{alt}'. Set YOLO_CONFIG_DIR to override."
+                f"user config directory '{p}' is not writable, using '{alt}'. Set YOLO_CONFIG_DIR to override."
             )
             return alt
 
ultralytics/utils/benchmarks.py CHANGED
@@ -144,7 +144,9 @@ def benchmark(
         if format == "imx":
             assert not is_end2end
             assert not isinstance(model, YOLOWorld), "YOLOWorldv2 IMX exports not supported"
-            assert model.task == "detect", "IMX only supported for detection task"
+            assert model.task in {"detect", "classify", "pose"}, (
+                "IMX export is only supported for detection, classification and pose estimation tasks"
+            )
             assert "C2f" in model.__str__(), "IMX only supported for YOLOv8n and YOLO11n"
         if format == "rknn":
             assert not isinstance(model, YOLOWorld), "YOLOWorldv2 RKNN exports not supported yet"
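
A quick way to exercise the relaxed gate end to end, assuming a classification checkpoint on disk and a working IMX toolchain:

    from ultralytics.utils.benchmarks import benchmark

    # classify and pose now pass the IMX assertion alongside detect (C2f-based YOLOv8n/YOLO11n only)
    benchmark(model="yolo11n-cls.pt", format="imx", imgsz=32)
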