dgenerate-ultralytics-headless 8.3.222-py3-none-any.whl → 8.3.225-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/METADATA +2 -2
  2. dgenerate_ultralytics_headless-8.3.225.dist-info/RECORD +286 -0
  3. tests/conftest.py +5 -8
  4. tests/test_cli.py +1 -8
  5. tests/test_python.py +1 -2
  6. ultralytics/__init__.py +1 -1
  7. ultralytics/cfg/__init__.py +34 -49
  8. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  9. ultralytics/cfg/datasets/kitti.yaml +27 -0
  10. ultralytics/cfg/datasets/lvis.yaml +5 -5
  11. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  12. ultralytics/data/annotator.py +3 -4
  13. ultralytics/data/augment.py +244 -323
  14. ultralytics/data/base.py +12 -22
  15. ultralytics/data/build.py +47 -40
  16. ultralytics/data/converter.py +32 -42
  17. ultralytics/data/dataset.py +43 -71
  18. ultralytics/data/loaders.py +22 -34
  19. ultralytics/data/split.py +5 -6
  20. ultralytics/data/split_dota.py +8 -15
  21. ultralytics/data/utils.py +27 -36
  22. ultralytics/engine/exporter.py +49 -116
  23. ultralytics/engine/model.py +144 -180
  24. ultralytics/engine/predictor.py +18 -29
  25. ultralytics/engine/results.py +165 -231
  26. ultralytics/engine/trainer.py +11 -19
  27. ultralytics/engine/tuner.py +13 -23
  28. ultralytics/engine/validator.py +6 -10
  29. ultralytics/hub/__init__.py +7 -12
  30. ultralytics/hub/auth.py +6 -12
  31. ultralytics/hub/google/__init__.py +7 -10
  32. ultralytics/hub/session.py +15 -25
  33. ultralytics/hub/utils.py +3 -6
  34. ultralytics/models/fastsam/model.py +6 -8
  35. ultralytics/models/fastsam/predict.py +5 -10
  36. ultralytics/models/fastsam/utils.py +1 -2
  37. ultralytics/models/fastsam/val.py +2 -4
  38. ultralytics/models/nas/model.py +5 -8
  39. ultralytics/models/nas/predict.py +7 -9
  40. ultralytics/models/nas/val.py +1 -2
  41. ultralytics/models/rtdetr/model.py +5 -8
  42. ultralytics/models/rtdetr/predict.py +15 -18
  43. ultralytics/models/rtdetr/train.py +10 -13
  44. ultralytics/models/rtdetr/val.py +13 -20
  45. ultralytics/models/sam/amg.py +12 -18
  46. ultralytics/models/sam/build.py +6 -9
  47. ultralytics/models/sam/model.py +16 -23
  48. ultralytics/models/sam/modules/blocks.py +62 -84
  49. ultralytics/models/sam/modules/decoders.py +17 -24
  50. ultralytics/models/sam/modules/encoders.py +40 -56
  51. ultralytics/models/sam/modules/memory_attention.py +10 -16
  52. ultralytics/models/sam/modules/sam.py +41 -47
  53. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  54. ultralytics/models/sam/modules/transformer.py +17 -27
  55. ultralytics/models/sam/modules/utils.py +31 -42
  56. ultralytics/models/sam/predict.py +172 -209
  57. ultralytics/models/utils/loss.py +14 -26
  58. ultralytics/models/utils/ops.py +13 -17
  59. ultralytics/models/yolo/classify/predict.py +8 -11
  60. ultralytics/models/yolo/classify/train.py +8 -16
  61. ultralytics/models/yolo/classify/val.py +13 -20
  62. ultralytics/models/yolo/detect/predict.py +4 -8
  63. ultralytics/models/yolo/detect/train.py +11 -20
  64. ultralytics/models/yolo/detect/val.py +38 -48
  65. ultralytics/models/yolo/model.py +35 -47
  66. ultralytics/models/yolo/obb/predict.py +5 -8
  67. ultralytics/models/yolo/obb/train.py +11 -14
  68. ultralytics/models/yolo/obb/val.py +20 -28
  69. ultralytics/models/yolo/pose/predict.py +5 -8
  70. ultralytics/models/yolo/pose/train.py +4 -8
  71. ultralytics/models/yolo/pose/val.py +31 -39
  72. ultralytics/models/yolo/segment/predict.py +9 -14
  73. ultralytics/models/yolo/segment/train.py +3 -6
  74. ultralytics/models/yolo/segment/val.py +16 -26
  75. ultralytics/models/yolo/world/train.py +8 -14
  76. ultralytics/models/yolo/world/train_world.py +11 -16
  77. ultralytics/models/yolo/yoloe/predict.py +16 -23
  78. ultralytics/models/yolo/yoloe/train.py +30 -43
  79. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  80. ultralytics/models/yolo/yoloe/val.py +15 -20
  81. ultralytics/nn/autobackend.py +10 -18
  82. ultralytics/nn/modules/activation.py +4 -6
  83. ultralytics/nn/modules/block.py +99 -185
  84. ultralytics/nn/modules/conv.py +45 -90
  85. ultralytics/nn/modules/head.py +44 -98
  86. ultralytics/nn/modules/transformer.py +44 -76
  87. ultralytics/nn/modules/utils.py +14 -19
  88. ultralytics/nn/tasks.py +86 -146
  89. ultralytics/nn/text_model.py +25 -40
  90. ultralytics/solutions/ai_gym.py +10 -16
  91. ultralytics/solutions/analytics.py +7 -10
  92. ultralytics/solutions/config.py +4 -5
  93. ultralytics/solutions/distance_calculation.py +9 -12
  94. ultralytics/solutions/heatmap.py +7 -13
  95. ultralytics/solutions/instance_segmentation.py +5 -8
  96. ultralytics/solutions/object_blurrer.py +7 -10
  97. ultralytics/solutions/object_counter.py +8 -12
  98. ultralytics/solutions/object_cropper.py +5 -8
  99. ultralytics/solutions/parking_management.py +12 -14
  100. ultralytics/solutions/queue_management.py +4 -6
  101. ultralytics/solutions/region_counter.py +7 -10
  102. ultralytics/solutions/security_alarm.py +14 -19
  103. ultralytics/solutions/similarity_search.py +7 -12
  104. ultralytics/solutions/solutions.py +31 -53
  105. ultralytics/solutions/speed_estimation.py +6 -9
  106. ultralytics/solutions/streamlit_inference.py +2 -4
  107. ultralytics/solutions/trackzone.py +7 -10
  108. ultralytics/solutions/vision_eye.py +5 -8
  109. ultralytics/trackers/basetrack.py +2 -4
  110. ultralytics/trackers/bot_sort.py +6 -11
  111. ultralytics/trackers/byte_tracker.py +10 -15
  112. ultralytics/trackers/track.py +3 -6
  113. ultralytics/trackers/utils/gmc.py +6 -12
  114. ultralytics/trackers/utils/kalman_filter.py +35 -43
  115. ultralytics/trackers/utils/matching.py +6 -10
  116. ultralytics/utils/__init__.py +61 -100
  117. ultralytics/utils/autobatch.py +2 -4
  118. ultralytics/utils/autodevice.py +11 -13
  119. ultralytics/utils/benchmarks.py +25 -35
  120. ultralytics/utils/callbacks/base.py +8 -10
  121. ultralytics/utils/callbacks/clearml.py +2 -4
  122. ultralytics/utils/callbacks/comet.py +30 -44
  123. ultralytics/utils/callbacks/dvc.py +13 -18
  124. ultralytics/utils/callbacks/mlflow.py +4 -5
  125. ultralytics/utils/callbacks/neptune.py +4 -6
  126. ultralytics/utils/callbacks/raytune.py +3 -4
  127. ultralytics/utils/callbacks/tensorboard.py +4 -6
  128. ultralytics/utils/callbacks/wb.py +10 -13
  129. ultralytics/utils/checks.py +29 -56
  130. ultralytics/utils/cpu.py +1 -2
  131. ultralytics/utils/dist.py +8 -12
  132. ultralytics/utils/downloads.py +17 -27
  133. ultralytics/utils/errors.py +6 -8
  134. ultralytics/utils/events.py +2 -4
  135. ultralytics/utils/export/__init__.py +4 -239
  136. ultralytics/utils/export/engine.py +237 -0
  137. ultralytics/utils/export/imx.py +11 -17
  138. ultralytics/utils/export/tensorflow.py +217 -0
  139. ultralytics/utils/files.py +10 -15
  140. ultralytics/utils/git.py +5 -7
  141. ultralytics/utils/instance.py +30 -51
  142. ultralytics/utils/logger.py +11 -15
  143. ultralytics/utils/loss.py +8 -14
  144. ultralytics/utils/metrics.py +98 -138
  145. ultralytics/utils/nms.py +13 -16
  146. ultralytics/utils/ops.py +47 -74
  147. ultralytics/utils/patches.py +11 -18
  148. ultralytics/utils/plotting.py +29 -42
  149. ultralytics/utils/tal.py +25 -39
  150. ultralytics/utils/torch_utils.py +45 -73
  151. ultralytics/utils/tqdm.py +6 -8
  152. ultralytics/utils/triton.py +9 -12
  153. ultralytics/utils/tuner.py +1 -2
  154. dgenerate_ultralytics_headless-8.3.222.dist-info/RECORD +0 -283
  155. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/WHEEL +0 -0
  156. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/entry_points.txt +0 -0
  157. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/licenses/LICENSE +0 -0
  158. {dgenerate_ultralytics_headless-8.3.222.dist-info → dgenerate_ultralytics_headless-8.3.225.dist-info}/top_level.txt +0 -0
ultralytics/data/utils.py CHANGED
@@ -51,11 +51,10 @@ def img2label_paths(img_paths: list[str]) -> list[str]:
51
51
  def check_file_speeds(
52
52
  files: list[str], threshold_ms: float = 10, threshold_mb: float = 50, max_files: int = 5, prefix: str = ""
53
53
  ):
54
- """
55
- Check dataset file access speed and provide performance feedback.
54
+ """Check dataset file access speed and provide performance feedback.
56
55
 
57
- This function tests the access speed of dataset files by measuring ping (stat call) time and read speed.
58
- It samples up to 5 files from the provided list and warns if access times exceed the threshold.
56
+ This function tests the access speed of dataset files by measuring ping (stat call) time and read speed. It samples
57
+ up to 5 files from the provided list and warns if access times exceed the threshold.
59
58
 
60
59
  Args:
61
60
  files (list[str]): List of file paths to check for access speed.
@@ -251,13 +250,12 @@ def verify_image_label(args: tuple) -> list:
251
250
 
252
251
 
253
252
  def visualize_image_annotations(image_path: str, txt_path: str, label_map: dict[int, str]):
254
- """
255
- Visualize YOLO annotations (bounding boxes and class labels) on an image.
253
+ """Visualize YOLO annotations (bounding boxes and class labels) on an image.
256
254
 
257
- This function reads an image and its corresponding annotation file in YOLO format, then
258
- draws bounding boxes around detected objects and labels them with their respective class names.
259
- The bounding box colors are assigned based on the class ID, and the text color is dynamically
260
- adjusted for readability, depending on the background color's luminance.
255
+ This function reads an image and its corresponding annotation file in YOLO format, then draws bounding boxes around
256
+ detected objects and labels them with their respective class names. The bounding box colors are assigned based on
257
+ the class ID, and the text color is dynamically adjusted for readability, depending on the background color's
258
+ luminance.
261
259
 
262
260
  Args:
263
261
  image_path (str): The path to the image file to annotate, and it can be in formats supported by PIL.
@@ -297,13 +295,12 @@ def visualize_image_annotations(image_path: str, txt_path: str, label_map: dict[
297
295
  def polygon2mask(
298
296
  imgsz: tuple[int, int], polygons: list[np.ndarray], color: int = 1, downsample_ratio: int = 1
299
297
  ) -> np.ndarray:
300
- """
301
- Convert a list of polygons to a binary mask of the specified image size.
298
+ """Convert a list of polygons to a binary mask of the specified image size.
302
299
 
303
300
  Args:
304
301
  imgsz (tuple[int, int]): The size of the image as (height, width).
305
- polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape (N, M), where
306
- N is the number of polygons, and M is the number of points such that M % 2 = 0.
302
+ polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape (N, M), where N is the
303
+ number of polygons, and M is the number of points such that M % 2 = 0.
307
304
  color (int, optional): The color value to fill in the polygons on the mask.
308
305
  downsample_ratio (int, optional): Factor by which to downsample the mask.
309
306
 
@@ -322,13 +319,12 @@ def polygon2mask(
322
319
  def polygons2masks(
323
320
  imgsz: tuple[int, int], polygons: list[np.ndarray], color: int, downsample_ratio: int = 1
324
321
  ) -> np.ndarray:
325
- """
326
- Convert a list of polygons to a set of binary masks of the specified image size.
322
+ """Convert a list of polygons to a set of binary masks of the specified image size.
327
323
 
328
324
  Args:
329
325
  imgsz (tuple[int, int]): The size of the image as (height, width).
330
- polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape (N, M), where
331
- N is the number of polygons, and M is the number of points such that M % 2 = 0.
326
+ polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape (N, M), where N is the
327
+ number of polygons, and M is the number of points such that M % 2 = 0.
332
328
  color (int): The color value to fill in the polygons on the masks.
333
329
  downsample_ratio (int, optional): Factor by which to downsample each mask.
334
330
 
@@ -368,8 +364,7 @@ def polygons2masks_overlap(
368
364
 
369
365
 
370
366
  def find_dataset_yaml(path: Path) -> Path:
371
- """
372
- Find and return the YAML file associated with a Detect, Segment or Pose dataset.
367
+ """Find and return the YAML file associated with a Detect, Segment or Pose dataset.
373
368
 
374
369
  This function searches for a YAML file at the root level of the provided directory first, and if not found, it
375
370
  performs a recursive search. It prefers YAML files that have the same stem as the provided path.
@@ -389,8 +384,7 @@ def find_dataset_yaml(path: Path) -> Path:
389
384
 
390
385
 
391
386
  def check_det_dataset(dataset: str, autodownload: bool = True) -> dict[str, Any]:
392
- """
393
- Download, verify, and/or unzip a dataset if not found locally.
387
+ """Download, verify, and/or unzip a dataset if not found locally.
394
388
 
395
389
  This function checks the availability of a specified dataset, and if not found, it has the option to download and
396
390
  unzip the dataset. It then reads and parses the accompanying YAML data, ensuring key requirements are met and also
@@ -484,11 +478,10 @@ def check_det_dataset(dataset: str, autodownload: bool = True) -> dict[str, Any]
484
478
 
485
479
 
486
480
  def check_cls_dataset(dataset: str | Path, split: str = "") -> dict[str, Any]:
487
- """
488
- Check a classification dataset such as Imagenet.
481
+ """Check a classification dataset such as Imagenet.
489
482
 
490
- This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
491
- If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.
483
+ This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information. If the
484
+ dataset is not found locally, it attempts to download the dataset from the internet and save it locally.
492
485
 
493
486
  Args:
494
487
  dataset (str | Path): The name of the dataset.
@@ -581,8 +574,7 @@ def check_cls_dataset(dataset: str | Path, split: str = "") -> dict[str, Any]:
581
574
 
582
575
 
583
576
  class HUBDatasetStats:
584
- """
585
- A class for generating HUB dataset JSON and `-hub` dataset directory.
577
+ """A class for generating HUB dataset JSON and `-hub` dataset directory.
586
578
 
587
579
  Args:
588
580
  path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip).
@@ -600,10 +592,6 @@ class HUBDatasetStats:
600
592
  get_json: Return dataset JSON for Ultralytics HUB.
601
593
  process_images: Compress images for Ultralytics HUB.
602
594
 
603
- Note:
604
- Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
605
- i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
606
-
607
595
  Examples:
608
596
  >>> from ultralytics.data.utils import HUBDatasetStats
609
597
  >>> stats = HUBDatasetStats("path/to/coco8.zip", task="detect") # detect dataset
@@ -613,6 +601,10 @@ class HUBDatasetStats:
613
601
  >>> stats = HUBDatasetStats("path/to/imagenet10.zip", task="classify") # classification dataset
614
602
  >>> stats.get_json(save=True)
615
603
  >>> stats.process_images()
604
+
605
+ Notes:
606
+ Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
607
+ i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
616
608
  """
617
609
 
618
610
  def __init__(self, path: str = "coco8.yaml", task: str = "detect", autodownload: bool = False):
@@ -748,10 +740,9 @@ class HUBDatasetStats:
748
740
 
749
741
 
750
742
  def compress_one_image(f: str, f_new: str | None = None, max_dim: int = 1920, quality: int = 50):
751
- """
752
- Compress a single image file to reduced size while preserving its aspect ratio and quality using either the Python
753
- Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will not be
754
- resized.
743
+ """Compress a single image file to reduced size while preserving its aspect ratio and quality using either the
744
+ Python Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it
745
+ will not be resized.
755
746
 
756
747
  Args:
757
748
  f (str): The path to the input image file.
@@ -107,9 +107,17 @@ from ultralytics.utils.checks import (
107
107
  is_intel,
108
108
  is_sudo_available,
109
109
  )
110
- from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
111
- from ultralytics.utils.export import onnx2engine, torch2imx, torch2onnx
112
- from ultralytics.utils.files import file_size, spaces_in_path
110
+ from ultralytics.utils.downloads import get_github_assets, safe_download
111
+ from ultralytics.utils.export import (
112
+ keras2pb,
113
+ onnx2engine,
114
+ onnx2saved_model,
115
+ pb2tfjs,
116
+ tflite2edgetpu,
117
+ torch2imx,
118
+ torch2onnx,
119
+ )
120
+ from ultralytics.utils.files import file_size
113
121
  from ultralytics.utils.metrics import batch_probiou
114
122
  from ultralytics.utils.nms import TorchNMS
115
123
  from ultralytics.utils.ops import Profile
@@ -150,7 +158,7 @@ def export_formats():
150
158
  ["NCNN", "ncnn", "_ncnn_model", True, True, ["batch", "half"]],
151
159
  ["IMX", "imx", "_imx_model", True, True, ["int8", "fraction", "nms"]],
152
160
  ["RKNN", "rknn", "_rknn_model", False, False, ["batch", "name"]],
153
- ["ExecuTorch", "executorch", "_executorch_model", False, False, ["batch"]],
161
+ ["ExecuTorch", "executorch", "_executorch_model", True, False, ["batch"]],
154
162
  ]
155
163
  return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU", "Arguments"], zip(*x)))
156
164
 
@@ -184,8 +192,7 @@ def best_onnx_opset(onnx, cuda=False) -> int:
184
192
 
185
193
 
186
194
  def validate_args(format, passed_args, valid_args):
187
- """
188
- Validate arguments based on the export format.
195
+ """Validate arguments based on the export format.
189
196
 
190
197
  Args:
191
198
  format (str): The export format.
@@ -206,15 +213,6 @@ def validate_args(format, passed_args, valid_args):
206
213
  assert arg in valid_args, f"ERROR ❌️ argument '{arg}' is not supported for format='{format}'"
207
214
 
208
215
 
209
- def gd_outputs(gd):
210
- """Return TensorFlow GraphDef model output node names."""
211
- name_list, input_list = [], []
212
- for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
213
- name_list.append(node.name)
214
- input_list.extend(node.input)
215
- return sorted(f"{x}:0" for x in list(set(name_list) - set(input_list)) if not x.startswith("NoOp"))
216
-
217
-
218
216
  def try_export(inner_func):
219
217
  """YOLO export decorator, i.e. @try_export."""
220
218
  inner_args = get_default_args(inner_func)
@@ -239,8 +237,7 @@ def try_export(inner_func):
239
237
 
240
238
 
241
239
  class Exporter:
242
- """
243
- A class for exporting YOLO models to various formats.
240
+ """A class for exporting YOLO models to various formats.
244
241
 
245
242
  This class provides functionality to export YOLO models to different formats including ONNX, TensorRT, CoreML,
246
243
  TensorFlow, and others. It handles format validation, device selection, model preparation, and the actual export
@@ -290,8 +287,7 @@ class Exporter:
290
287
  """
291
288
 
292
289
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
293
- """
294
- Initialize the Exporter class.
290
+ """Initialize the Exporter class.
295
291
 
296
292
  Args:
297
293
  cfg (str, optional): Path to a configuration file.
@@ -371,7 +367,7 @@ class Exporter:
371
367
  LOGGER.warning("IMX export requires nms=True, setting nms=True.")
372
368
  self.args.nms = True
373
369
  if model.task not in {"detect", "pose", "classify"}:
374
- raise ValueError("IMX export only supported for detection and pose estimation models.")
370
+ raise ValueError("IMX export only supported for detection, pose estimation, and classification models.")
375
371
  if not hasattr(model, "names"):
376
372
  model.names = default_class_names()
377
373
  model.names = check_class_names(model.names)
@@ -461,6 +457,10 @@ class Exporter:
461
457
  from ultralytics.utils.export.imx import FXModel
462
458
 
463
459
  model = FXModel(model, self.imgsz)
460
+ if tflite or edgetpu:
461
+ from ultralytics.utils.export.tensorflow import tf_wrapper
462
+
463
+ model = tf_wrapper(model)
464
464
  for m in model.modules():
465
465
  if isinstance(m, Classify):
466
466
  m.export = True
@@ -642,7 +642,7 @@ class Exporter:
642
642
  assert TORCH_1_13, f"'nms=True' ONNX export requires torch>=1.13 (found torch=={TORCH_VERSION})"
643
643
 
644
644
  f = str(self.file.with_suffix(".onnx"))
645
- output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
645
+ output_names = ["output0", "output1"] if self.model.task == "segment" else ["output0"]
646
646
  dynamic = self.args.dynamic
647
647
  if dynamic:
648
648
  dynamic = {"images": {0: "batch", 2: "height", 3: "width"}} # shape(1,3,640,640)
@@ -1053,75 +1053,43 @@ class Exporter:
1053
1053
  if f.is_dir():
1054
1054
  shutil.rmtree(f) # delete output folder
1055
1055
 
1056
- # Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545
1057
- onnx2tf_file = Path("calibration_image_sample_data_20x128x128x3_float32.npy")
1058
- if not onnx2tf_file.exists():
1059
- attempt_download_asset(f"{onnx2tf_file}.zip", unzip=True, delete=True)
1056
+ # Export to TF
1057
+ images = None
1058
+ if self.args.int8 and self.args.data:
1059
+ images = [batch["img"] for batch in self.get_int8_calibration_dataloader(prefix)]
1060
+ images = (
1061
+ torch.nn.functional.interpolate(torch.cat(images, 0).float(), size=self.imgsz)
1062
+ .permute(0, 2, 3, 1)
1063
+ .numpy()
1064
+ .astype(np.float32)
1065
+ )
1060
1066
 
1061
1067
  # Export to ONNX
1062
1068
  if isinstance(self.model.model[-1], RTDETRDecoder):
1063
1069
  self.args.opset = self.args.opset or 19
1064
1070
  assert 16 <= self.args.opset <= 19, "RTDETR export requires opset>=16;<=19"
1065
1071
  self.args.simplify = True
1066
- f_onnx = self.export_onnx()
1067
-
1068
- # Export to TF
1069
- np_data = None
1070
- if self.args.int8:
1071
- tmp_file = f / "tmp_tflite_int8_calibration_images.npy" # int8 calibration images file
1072
- if self.args.data:
1073
- f.mkdir()
1074
- images = [batch["img"] for batch in self.get_int8_calibration_dataloader(prefix)]
1075
- images = torch.nn.functional.interpolate(torch.cat(images, 0).float(), size=self.imgsz).permute(
1076
- 0, 2, 3, 1
1077
- )
1078
- np.save(str(tmp_file), images.numpy().astype(np.float32)) # BHWC
1079
- np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]]
1080
-
1081
- import onnx2tf # scoped for after ONNX export for reduced conflict during import
1082
-
1083
- LOGGER.info(f"{prefix} starting TFLite export with onnx2tf {onnx2tf.__version__}...")
1084
- keras_model = onnx2tf.convert(
1085
- input_onnx_file_path=f_onnx,
1086
- output_folder_path=str(f),
1087
- not_use_onnxsim=True,
1088
- verbosity="error", # note INT8-FP16 activation bug https://github.com/ultralytics/ultralytics/issues/15873
1089
- output_integer_quantized_tflite=self.args.int8,
1090
- custom_input_op_name_np_data_path=np_data,
1091
- enable_batchmatmul_unfold=True and not self.args.int8, # fix lower no. of detected objects on GPU delegate
1092
- output_signaturedefs=True, # fix error with Attention block group convolution
1093
- disable_group_convolution=self.args.format in {"tfjs", "edgetpu"}, # fix error with group convolution
1072
+ f_onnx = self.export_onnx() # ensure ONNX is available
1073
+ keras_model = onnx2saved_model(
1074
+ f_onnx,
1075
+ f,
1076
+ int8=self.args.int8,
1077
+ images=images,
1078
+ disable_group_convolution=self.args.format in {"tfjs", "edgetpu"},
1079
+ prefix=prefix,
1094
1080
  )
1095
1081
  YAML.save(f / "metadata.yaml", self.metadata) # add metadata.yaml
1096
-
1097
- # Remove/rename TFLite models
1098
- if self.args.int8:
1099
- tmp_file.unlink(missing_ok=True)
1100
- for file in f.rglob("*_dynamic_range_quant.tflite"):
1101
- file.rename(file.with_name(file.stem.replace("_dynamic_range_quant", "_int8") + file.suffix))
1102
- for file in f.rglob("*_integer_quant_with_int16_act.tflite"):
1103
- file.unlink() # delete extra fp16 activation TFLite files
1104
-
1105
1082
  # Add TFLite metadata
1106
1083
  for file in f.rglob("*.tflite"):
1107
- f.unlink() if "quant_with_int16_act.tflite" in str(f) else self._add_tflite_metadata(file)
1084
+ file.unlink() if "quant_with_int16_act.tflite" in str(file) else self._add_tflite_metadata(file)
1108
1085
 
1109
1086
  return str(f), keras_model # or keras_model = tf.saved_model.load(f, tags=None, options=None)
1110
1087
 
1111
1088
  @try_export
1112
1089
  def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")):
1113
1090
  """Export YOLO model to TensorFlow GraphDef *.pb format https://github.com/leimao/Frozen-Graph-TensorFlow."""
1114
- import tensorflow as tf
1115
- from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
1116
-
1117
- LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...")
1118
1091
  f = self.file.with_suffix(".pb")
1119
-
1120
- m = tf.function(lambda x: keras_model(x)) # full model
1121
- m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
1122
- frozen_func = convert_variables_to_constants_v2(m)
1123
- frozen_func.graph.as_graph_def()
1124
- tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
1092
+ keras2pb(keras_model, f, prefix)
1125
1093
  return f
1126
1094
 
1127
1095
  @try_export
@@ -1189,22 +1157,11 @@ class Exporter:
1189
1157
  "sudo apt-get install edgetpu-compiler",
1190
1158
  ):
1191
1159
  subprocess.run(c if is_sudo_available() else c.replace("sudo ", ""), shell=True, check=True)
1192
- ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().rsplit(maxsplit=1)[-1]
1193
1160
 
1161
+ ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().rsplit(maxsplit=1)[-1]
1194
1162
  LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...")
1163
+ tflite2edgetpu(tflite_file=tflite_model, output_dir=tflite_model.parent, prefix=prefix)
1195
1164
  f = str(tflite_model).replace(".tflite", "_edgetpu.tflite") # Edge TPU model
1196
-
1197
- cmd = (
1198
- "edgetpu_compiler "
1199
- f'--out_dir "{Path(f).parent}" '
1200
- "--show_operations "
1201
- "--search_delegate "
1202
- "--delegate_search_step 30 "
1203
- "--timeout_sec 180 "
1204
- f'"{tflite_model}"'
1205
- )
1206
- LOGGER.info(f"{prefix} running '{cmd}'")
1207
- subprocess.run(cmd, shell=True)
1208
1165
  self._add_tflite_metadata(f)
1209
1166
  return f
1210
1167
 
@@ -1212,31 +1169,10 @@ class Exporter:
1212
1169
  def export_tfjs(self, prefix=colorstr("TensorFlow.js:")):
1213
1170
  """Export YOLO model to TensorFlow.js format."""
1214
1171
  check_requirements("tensorflowjs")
1215
- import tensorflow as tf
1216
- import tensorflowjs as tfjs
1217
1172
 
1218
- LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...")
1219
1173
  f = str(self.file).replace(self.file.suffix, "_web_model") # js dir
1220
1174
  f_pb = str(self.file.with_suffix(".pb")) # *.pb path
1221
-
1222
- gd = tf.Graph().as_graph_def() # TF GraphDef
1223
- with open(f_pb, "rb") as file:
1224
- gd.ParseFromString(file.read())
1225
- outputs = ",".join(gd_outputs(gd))
1226
- LOGGER.info(f"\n{prefix} output node names: {outputs}")
1227
-
1228
- quantization = "--quantize_float16" if self.args.half else "--quantize_uint8" if self.args.int8 else ""
1229
- with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_: # exporter can not handle spaces in path
1230
- cmd = (
1231
- "tensorflowjs_converter "
1232
- f'--input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" "{f_}"'
1233
- )
1234
- LOGGER.info(f"{prefix} running '{cmd}'")
1235
- subprocess.run(cmd, shell=True)
1236
-
1237
- if " " in f:
1238
- LOGGER.warning(f"{prefix} your model may not work correctly with spaces in path '{f}'.")
1239
-
1175
+ pb2tfjs(pb_file=f_pb, output_dir=f, half=self.args.half, int8=self.args.int8, prefix=prefix)
1240
1176
  # Add metadata
1241
1177
  YAML.save(Path(f) / "metadata.yaml", self.metadata) # add metadata.yaml
1242
1178
  return f
@@ -1420,8 +1356,7 @@ class IOSDetectModel(torch.nn.Module):
1420
1356
  """Wrap an Ultralytics YOLO model for Apple iOS CoreML export."""
1421
1357
 
1422
1358
  def __init__(self, model, im, mlprogram=True):
1423
- """
1424
- Initialize the IOSDetectModel class with a YOLO model and example image.
1359
+ """Initialize the IOSDetectModel class with a YOLO model and example image.
1425
1360
 
1426
1361
  Args:
1427
1362
  model (torch.nn.Module): The YOLO model to wrap.
@@ -1455,8 +1390,7 @@ class NMSModel(torch.nn.Module):
1455
1390
  """Model wrapper with embedded NMS for Detect, Segment, Pose and OBB."""
1456
1391
 
1457
1392
  def __init__(self, model, args):
1458
- """
1459
- Initialize the NMSModel.
1393
+ """Initialize the NMSModel.
1460
1394
 
1461
1395
  Args:
1462
1396
  model (torch.nn.Module): The model to wrap with NMS postprocessing.
@@ -1469,15 +1403,14 @@ class NMSModel(torch.nn.Module):
1469
1403
  self.is_tf = self.args.format in frozenset({"saved_model", "tflite", "tfjs"})
1470
1404
 
1471
1405
  def forward(self, x):
1472
- """
1473
- Perform inference with NMS post-processing. Supports Detect, Segment, OBB and Pose.
1406
+ """Perform inference with NMS post-processing. Supports Detect, Segment, OBB and Pose.
1474
1407
 
1475
1408
  Args:
1476
1409
  x (torch.Tensor): The preprocessed tensor with shape (N, 3, H, W).
1477
1410
 
1478
1411
  Returns:
1479
- (torch.Tensor): List of detections, each an (N, max_det, 4 + 2 + extra_shape) Tensor where N is the
1480
- number of detections after NMS.
1412
+ (torch.Tensor): List of detections, each an (N, max_det, 4 + 2 + extra_shape) Tensor where N is the number
1413
+ of detections after NMS.
1481
1414
  """
1482
1415
  from functools import partial
1483
1416