ultralytics 8.0.227__tar.gz → 8.0.229__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ultralytics might be problematic.

Files changed (189)
  1. {ultralytics-8.0.227/ultralytics.egg-info → ultralytics-8.0.229}/PKG-INFO +3 -1
  2. {ultralytics-8.0.227 → ultralytics-8.0.229}/requirements.txt +2 -0
  3. {ultralytics-8.0.227 → ultralytics-8.0.229}/setup.py +2 -0
  4. {ultralytics-8.0.227 → ultralytics-8.0.229}/tests/test_cuda.py +1 -0
  5. {ultralytics-8.0.227 → ultralytics-8.0.229}/tests/test_python.py +10 -0
  6. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/__init__.py +1 -1
  7. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/__init__.py +1 -1
  8. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/default.yaml +2 -0
  9. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/data/build.py +1 -1
  10. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/engine/exporter.py +12 -7
  11. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/engine/model.py +19 -1
  12. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/engine/predictor.py +4 -1
  13. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/engine/trainer.py +46 -26
  14. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/nn/autobackend.py +3 -2
  15. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/nn/tasks.py +17 -6
  16. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/solutions/ai_gym.py +1 -1
  17. ultralytics-8.0.229/ultralytics/solutions/heatmap.py +265 -0
  18. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/solutions/object_counter.py +84 -26
  19. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/plotting.py +92 -22
  20. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/torch_utils.py +1 -1
  21. {ultralytics-8.0.227 → ultralytics-8.0.229/ultralytics.egg-info}/PKG-INFO +3 -1
  22. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics.egg-info/requires.txt +2 -0
  23. ultralytics-8.0.227/ultralytics/solutions/heatmap.py +0 -181
  24. {ultralytics-8.0.227 → ultralytics-8.0.229}/CONTRIBUTING.md +0 -0
  25. {ultralytics-8.0.227 → ultralytics-8.0.229}/LICENSE +0 -0
  26. {ultralytics-8.0.227 → ultralytics-8.0.229}/MANIFEST.in +0 -0
  27. {ultralytics-8.0.227 → ultralytics-8.0.229}/README.md +0 -0
  28. {ultralytics-8.0.227 → ultralytics-8.0.229}/README.zh-CN.md +0 -0
  29. {ultralytics-8.0.227 → ultralytics-8.0.229}/setup.cfg +0 -0
  30. {ultralytics-8.0.227 → ultralytics-8.0.229}/tests/conftest.py +0 -0
  31. {ultralytics-8.0.227 → ultralytics-8.0.229}/tests/test_cli.py +0 -0
  32. {ultralytics-8.0.227 → ultralytics-8.0.229}/tests/test_engine.py +0 -0
  33. {ultralytics-8.0.227 → ultralytics-8.0.229}/tests/test_integrations.py +0 -0
  34. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/assets/bus.jpg +0 -0
  35. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/assets/zidane.jpg +0 -0
  36. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/datasets/Argoverse.yaml +0 -0
  37. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/datasets/DOTAv2.yaml +0 -0
  38. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/datasets/GlobalWheat2020.yaml +0 -0
  39. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/datasets/ImageNet.yaml +0 -0
  40. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/datasets/Objects365.yaml +0 -0
  41. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/datasets/SKU-110K.yaml +0 -0
  42. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/datasets/VOC.yaml +0 -0
  43. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/datasets/VisDrone.yaml +0 -0
  44. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/datasets/coco-pose.yaml +0 -0
  45. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/datasets/coco.yaml +0 -0
  46. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/datasets/coco128-seg.yaml +0 -0
  47. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/datasets/coco128.yaml +0 -0
  48. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/datasets/coco8-pose.yaml +0 -0
  49. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/datasets/coco8-seg.yaml +0 -0
  50. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/datasets/coco8.yaml +0 -0
  51. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/datasets/open-images-v7.yaml +0 -0
  52. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/datasets/tiger-pose.yaml +0 -0
  53. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/datasets/xView.yaml +0 -0
  54. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +0 -0
  55. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +0 -0
  56. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +0 -0
  57. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +0 -0
  58. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/v3/yolov3-spp.yaml +0 -0
  59. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/v3/yolov3-tiny.yaml +0 -0
  60. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/v3/yolov3.yaml +0 -0
  61. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/v5/yolov5-p6.yaml +0 -0
  62. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/v5/yolov5.yaml +0 -0
  63. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/v6/yolov6.yaml +0 -0
  64. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/v8/yolov8-cls.yaml +0 -0
  65. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +0 -0
  66. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +0 -0
  67. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/v8/yolov8-ghost.yaml +0 -0
  68. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/v8/yolov8-p2.yaml +0 -0
  69. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/v8/yolov8-p6.yaml +0 -0
  70. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +0 -0
  71. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/v8/yolov8-pose.yaml +0 -0
  72. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +0 -0
  73. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +0 -0
  74. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/v8/yolov8-seg.yaml +0 -0
  75. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/models/v8/yolov8.yaml +0 -0
  76. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/trackers/botsort.yaml +0 -0
  77. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/trackers/bytetrack.yaml +0 -0
  78. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/data/__init__.py +0 -0
  79. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/data/annotator.py +0 -0
  80. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/data/augment.py +0 -0
  81. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/data/base.py +0 -0
  82. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/data/converter.py +0 -0
  83. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/data/dataset.py +0 -0
  84. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/data/loaders.py +0 -0
  85. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/data/utils.py +0 -0
  86. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/engine/__init__.py +0 -0
  87. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/engine/results.py +0 -0
  88. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/engine/tuner.py +0 -0
  89. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/engine/validator.py +0 -0
  90. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/hub/__init__.py +0 -0
  91. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/hub/auth.py +0 -0
  92. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/hub/session.py +0 -0
  93. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/hub/utils.py +0 -0
  94. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/__init__.py +0 -0
  95. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/fastsam/__init__.py +0 -0
  96. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/fastsam/model.py +0 -0
  97. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/fastsam/predict.py +0 -0
  98. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/fastsam/prompt.py +0 -0
  99. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/fastsam/utils.py +0 -0
  100. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/fastsam/val.py +0 -0
  101. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/nas/__init__.py +0 -0
  102. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/nas/model.py +0 -0
  103. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/nas/predict.py +0 -0
  104. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/nas/val.py +0 -0
  105. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/rtdetr/__init__.py +0 -0
  106. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/rtdetr/model.py +0 -0
  107. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/rtdetr/predict.py +0 -0
  108. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/rtdetr/train.py +0 -0
  109. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/rtdetr/val.py +0 -0
  110. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/sam/__init__.py +0 -0
  111. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/sam/amg.py +0 -0
  112. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/sam/build.py +0 -0
  113. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/sam/model.py +0 -0
  114. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/sam/modules/__init__.py +0 -0
  115. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/sam/modules/decoders.py +0 -0
  116. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/sam/modules/encoders.py +0 -0
  117. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/sam/modules/sam.py +0 -0
  118. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/sam/modules/tiny_encoder.py +0 -0
  119. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/sam/modules/transformer.py +0 -0
  120. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/sam/predict.py +0 -0
  121. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/utils/__init__.py +0 -0
  122. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/utils/loss.py +0 -0
  123. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/utils/ops.py +0 -0
  124. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/yolo/__init__.py +0 -0
  125. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/yolo/classify/__init__.py +0 -0
  126. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/yolo/classify/predict.py +0 -0
  127. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/yolo/classify/train.py +0 -0
  128. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/yolo/classify/val.py +0 -0
  129. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/yolo/detect/__init__.py +0 -0
  130. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/yolo/detect/predict.py +0 -0
  131. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/yolo/detect/train.py +0 -0
  132. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/yolo/detect/val.py +0 -0
  133. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/yolo/model.py +0 -0
  134. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/yolo/pose/__init__.py +0 -0
  135. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/yolo/pose/predict.py +0 -0
  136. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/yolo/pose/train.py +0 -0
  137. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/yolo/pose/val.py +0 -0
  138. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/yolo/segment/__init__.py +0 -0
  139. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/yolo/segment/predict.py +0 -0
  140. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/yolo/segment/train.py +0 -0
  141. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/models/yolo/segment/val.py +0 -0
  142. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/nn/__init__.py +0 -0
  143. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/nn/modules/__init__.py +0 -0
  144. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/nn/modules/block.py +0 -0
  145. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/nn/modules/conv.py +0 -0
  146. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/nn/modules/head.py +0 -0
  147. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/nn/modules/transformer.py +0 -0
  148. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/nn/modules/utils.py +0 -0
  149. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/solutions/__init__.py +0 -0
  150. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/trackers/__init__.py +0 -0
  151. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/trackers/basetrack.py +0 -0
  152. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/trackers/bot_sort.py +0 -0
  153. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/trackers/byte_tracker.py +0 -0
  154. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/trackers/track.py +0 -0
  155. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/trackers/utils/__init__.py +0 -0
  156. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/trackers/utils/gmc.py +0 -0
  157. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/trackers/utils/kalman_filter.py +0 -0
  158. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/trackers/utils/matching.py +0 -0
  159. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/__init__.py +0 -0
  160. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/autobatch.py +0 -0
  161. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/benchmarks.py +0 -0
  162. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/callbacks/__init__.py +0 -0
  163. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/callbacks/base.py +0 -0
  164. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/callbacks/clearml.py +0 -0
  165. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/callbacks/comet.py +0 -0
  166. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/callbacks/dvc.py +0 -0
  167. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/callbacks/hub.py +0 -0
  168. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/callbacks/mlflow.py +0 -0
  169. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/callbacks/neptune.py +0 -0
  170. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/callbacks/raytune.py +0 -0
  171. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/callbacks/tensorboard.py +0 -0
  172. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/callbacks/wb.py +0 -0
  173. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/checks.py +0 -0
  174. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/dist.py +0 -0
  175. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/downloads.py +0 -0
  176. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/errors.py +0 -0
  177. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/files.py +0 -0
  178. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/instance.py +0 -0
  179. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/loss.py +0 -0
  180. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/metrics.py +0 -0
  181. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/ops.py +0 -0
  182. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/patches.py +0 -0
  183. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/tal.py +0 -0
  184. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/triton.py +0 -0
  185. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/utils/tuner.py +0 -0
  186. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics.egg-info/SOURCES.txt +0 -0
  187. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics.egg-info/dependency_links.txt +0 -0
  188. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics.egg-info/entry_points.txt +0 -0
  189. {ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics.egg-info/top_level.txt +0 -0
{ultralytics-8.0.227/ultralytics.egg-info → ultralytics-8.0.229}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ultralytics
-Version: 8.0.227
+Version: 8.0.229
 Summary: Ultralytics YOLOv8 for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification.
 Home-page: https://github.com/ultralytics/ultralytics
 Author: Ultralytics

@@ -60,6 +60,8 @@ Provides-Extra: export
 Requires-Dist: coremltools>=7.0; extra == "export"
 Requires-Dist: openvino-dev>=2023.0; extra == "export"
 Requires-Dist: tensorflow<=2.13.1; extra == "export"
+Requires-Dist: jax<=0.4.21; extra == "export"
+Requires-Dist: jaxlib<=0.4.21; extra == "export"
 Requires-Dist: tensorflowjs; extra == "export"

 <div align="center">
{ultralytics-8.0.227 → ultralytics-8.0.229}/requirements.txt

@@ -32,6 +32,8 @@ seaborn>=0.11.0
 # scikit-learn==0.19.2 # CoreML quantization
 # tensorflow>=2.4.1,<=2.13.1 # TF exports (-cpu, -aarch64, -macos)
 # tflite-support
+# jax<=0.4.21 # tensorflowjs bug https://github.com/google/jax/issues/18978
+# jaxlib<=0.4.21 # tensorflowjs bug https://github.com/google/jax/issues/18978
 # tensorflowjs>=3.9.0 # TF.js export
 # openvino-dev>=2023.0 # OpenVINO export
{ultralytics-8.0.227 → ultralytics-8.0.229}/setup.py

@@ -81,6 +81,8 @@ setup(
             'coremltools>=7.0',
             'openvino-dev>=2023.0',
             'tensorflow<=2.13.1',  # TF bug https://github.com/ultralytics/ultralytics/issues/5161
+            'jax<=0.4.21',  # tensorflowjs bug https://github.com/google/jax/issues/18978
+            'jaxlib<=0.4.21',  # tensorflowjs bug https://github.com/google/jax/issues/18978
             'tensorflowjs',  # automatically installs tensorflow
         ], },
     classifiers=[
{ultralytics-8.0.227 → ultralytics-8.0.229}/tests/test_cuda.py

@@ -61,6 +61,7 @@ def test_autobatch():
     check_train_batch_size(YOLO(MODEL).model.cuda(), imgsz=128, amp=True)


+@pytest.mark.slow
 @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
 def test_utils_benchmarks():
     """Profile YOLO models for performance benchmarks."""
{ultralytics-8.0.227 → ultralytics-8.0.229}/tests/test_python.py

@@ -511,3 +511,13 @@ def test_model_tune():
     """Tune YOLO model for performance."""
     YOLO('yolov8n-pose.pt').tune(data='coco8-pose.yaml', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
     YOLO('yolov8n-cls.pt').tune(data='imagenet10', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
+
+
+def test_model_embeddings():
+    """Test YOLO model embeddings."""
+    model_detect = YOLO(MODEL)
+    model_segment = YOLO(WEIGHTS_DIR / 'yolov8n-seg.pt')
+
+    for batch in [SOURCE], [SOURCE, SOURCE]:  # test batch size 1 and 2
+        assert len(model_detect.embed(source=batch, imgsz=32)) == len(batch)
+        assert len(model_segment.embed(source=batch, imgsz=32)) == len(batch)
{ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/__init__.py

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = '8.0.227'
+__version__ = '8.0.229'

 from ultralytics.models import RTDETR, SAM, YOLO
 from ultralytics.models.fastsam import FastSAM
{ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/__init__.py

@@ -63,7 +63,7 @@ CLI_HELP_MSG = \
     """

 # Define keys for arg type checks
-CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'
+CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear', 'time'
 CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr',
                      'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud',
                      'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou', 'fraction')  # fraction floats 0.0 - 1.0
{ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/cfg/default.yaml

@@ -8,6 +8,7 @@ mode: train # (str) YOLO mode, i.e. train, val, predict, export, track, benchmark
 model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml
 data: # (str, optional) path to data file, i.e. coco128.yaml
 epochs: 100 # (int) number of epochs to train for
+time: # (float, optional) number of hours to train for, overrides epochs if supplied
 patience: 50 # (int) epochs to wait for no observable improvement for early stopping of training
 batch: 16 # (int) number of images per batch (-1 for AutoBatch)
 imgsz: 640 # (int | list) input images size as int for train and val modes, or list[w,h] for predict and export modes

@@ -60,6 +61,7 @@ augment: False # (bool) apply image augmentation to prediction sources
 agnostic_nms: False # (bool) class-agnostic NMS
 classes: # (int | list[int], optional) filter results by class, i.e. classes=0, or classes=[0,2,3]
 retina_masks: False # (bool) use high-resolution segmentation masks
+embed: # (list[int], optional) return feature vectors/embeddings from given layers

 # Visualize settings ---------------------------------------------------------------------------------------------------
 show: False # (bool) show predicted images and videos if environment allows
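The new time argument above caps training by wall-clock hours instead of epoch count. A minimal usage sketch, assuming the standard YOLO.train() override mechanism (dataset name and values are illustrative only):

    from ultralytics import YOLO

    # 'time' overrides 'epochs' when supplied, per the new default.yaml entry.
    model = YOLO('yolov8n.pt')
    model.train(data='coco128.yaml', epochs=300, time=0.5)  # stop after roughly half an hour of training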
{ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/data/build.py

@@ -100,7 +100,7 @@ def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
     """Return an InfiniteDataLoader or DataLoader for training or validation set."""
     batch = min(batch, len(dataset))
     nd = torch.cuda.device_count()  # number of CUDA devices
-    nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers])  # number of workers
+    nw = min([os.cpu_count() // max(nd, 1), batch, workers])  # number of workers
     sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
     generator = torch.Generator()
     generator.manual_seed(6148914691236517205 + RANK)
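The worker-count change above means a dataloader built with batch size 1 no longer forces zero workers. A quick illustration of the old and new formulas with assumed example values:

    import os

    nd, workers = 1, 8  # assumed CUDA device count and requested workers
    for batch in (1, 16):
        new_nw = min([os.cpu_count() // max(nd, 1), batch, workers])                      # 8.0.229
        old_nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers])  # 8.0.227
        print(f'batch={batch}: new={new_nw}, old={old_nw}')  # batch=1 now gets 1 worker instead of 0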
{ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/engine/exporter.py

@@ -459,11 +459,14 @@ class Exporter:
                 f'{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from '
                 'https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory '
                 f'or in {ROOT}. See PNNX repo for full installation instructions.')
-            _, assets = get_github_assets(repo='pnnx/pnnx', retry=True)
-            system = 'macos' if MACOS else 'ubuntu' if LINUX else 'windows'  # operating system
-            asset = [x for x in assets if system in x][0] if assets else \
-                f'https://github.com/pnnx/pnnx/releases/download/20230816/pnnx-20230816-{system}.zip'  # fallback
-            asset = attempt_download_asset(asset, repo='pnnx/pnnx', release='latest')
+            system = ['macos'] if MACOS else ['windows'] if WINDOWS else ['ubuntu', 'linux']  # operating system
+            try:
+                _, assets = get_github_assets(repo='pnnx/pnnx', retry=True)
+                url = [x for x in assets if any(s in x for s in system)][0]
+            except Exception as e:
+                url = f'https://github.com/pnnx/pnnx/releases/download/20231127/pnnx-20231127-{system[0]}.zip'
+                LOGGER.warning(f'{prefix} WARNING ⚠️ PNNX GitHub assets not found: {e}, using default {url}')
+            asset = attempt_download_asset(url, repo='pnnx/pnnx', release='latest')
             if check_is_path_safe(Path.cwd(), asset):  # avoid path traversal security vulnerability
                 unzip_dir = Path(asset).with_suffix('')
                 (unzip_dir / name).rename(pnnx)  # move binary to ROOT

@@ -781,7 +784,8 @@ class Exporter:
     @try_export
     def export_tfjs(self, prefix=colorstr('TensorFlow.js:')):
         """YOLOv8 TensorFlow.js export."""
-        check_requirements('tensorflowjs')
+        # JAX bug requiring install constraints in https://github.com/google/jax/issues/18978
+        check_requirements(['jax<=0.4.21', 'jaxlib<=0.4.21', 'tensorflowjs'])
         import tensorflow as tf
         import tensorflowjs as tfjs  # noqa

@@ -795,8 +799,9 @@ class Exporter:
         outputs = ','.join(gd_outputs(gd))
         LOGGER.info(f'\n{prefix} output node names: {outputs}')

+        quantization = '--quantize_float16' if self.args.half else '--quantize_uint8' if self.args.int8 else ''
         with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_:  # exporter can not handle spaces in path
-            cmd = f'tensorflowjs_converter --input_format=tf_frozen_model --output_node_names={outputs} "{fpb_}" "{f_}"'
+            cmd = f'tensorflowjs_converter --input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" "{f_}"'
             LOGGER.info(f"{prefix} running '{cmd}'")
             subprocess.run(cmd, shell=True)
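The new quantization string simply forwards the existing half/int8 export flags to tensorflowjs_converter. A hedged sketch of how that surfaces in the public export API (checkpoint path is illustrative):

    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')
    # half=True should now yield a float16-quantized TF.js model (--quantize_float16);
    # int8=True would map to --quantize_uint8 instead.
    model.export(format='tfjs', half=True)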
 
{ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/engine/model.py

@@ -94,7 +94,7 @@ class Model(nn.Module):
             self._load(model, task)

     def __call__(self, source=None, stream=False, **kwargs):
-        """Calls the 'predict' function with given arguments to perform object detection."""
+        """Calls the predict() method with given arguments to perform object detection."""
         return self.predict(source, stream, **kwargs)

     @staticmethod

@@ -201,6 +201,24 @@ class Model(nn.Module):
         self._check_is_pytorch_model()
         self.model.fuse()

+    def embed(self, source=None, stream=False, **kwargs):
+        """
+        Calls the predict() method and returns image embeddings.
+
+        Args:
+            source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
+                Accepts all source types accepted by the YOLO model.
+            stream (bool): Whether to stream the predictions or not. Defaults to False.
+            **kwargs : Additional keyword arguments passed to the predictor.
+                Check the 'configuration' section in the documentation for all available options.
+
+        Returns:
+            (List[torch.Tensor]): A list of image embeddings.
+        """
+        if not kwargs.get('embed'):
+            kwargs['embed'] = [len(self.model.model) - 2]  # embed second-to-last layer if no indices passed
+        return self.predict(source, stream, **kwargs)
+
     def predict(self, source=None, stream=False, predictor=None, **kwargs):
         """
         Perform prediction using the YOLO model.
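Based on the new Model.embed() method and the test added in test_python.py, usage looks roughly like the following; the layer indices and source URL are illustrative, and the embedding width depends on the model and layers chosen:

    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')

    # Default: one pooled 1-D embedding per image, taken from the second-to-last layer.
    vectors = model.embed(source='https://ultralytics.com/images/bus.jpg', imgsz=32)
    print(len(vectors), vectors[0].shape)

    # Specific layers can also be requested via the new 'embed' argument (indices are model-dependent).
    vectors = model.embed(source='https://ultralytics.com/images/bus.jpg', embed=[15, 18, 21])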
{ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/engine/predictor.py

@@ -134,7 +134,7 @@ class BasePredictor:
         """Runs inference on a given image using the specified model and arguments."""
         visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
                                    mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
-        return self.model(im, augment=self.args.augment, visualize=visualize)
+        return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)

     def pre_transform(self, im):
         """

@@ -263,6 +263,9 @@ class BasePredictor:
             # Inference
             with profilers[1]:
                 preds = self.inference(im, *args, **kwargs)
+                if self.args.embed:
+                    yield from [preds] if isinstance(preds, torch.Tensor) else preds  # yield embedding tensors
+                    continue

             # Postprocess
             with profilers[2]:
{ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/engine/trainer.py

@@ -189,6 +189,14 @@ class BaseTrainer:
         else:
             self._do_train(world_size)

+    def _setup_scheduler(self):
+        """Initialize training learning rate scheduler."""
+        if self.args.cos_lr:
+            self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
+        else:
+            self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf  # linear
+        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
+
     def _setup_ddp(self, world_size):
         """Initializes and sets the DistributedDataParallel parameters for training."""
         torch.cuda.set_device(RANK)

@@ -269,11 +277,7 @@ class BaseTrainer:
                                               decay=weight_decay,
                                               iterations=iterations)
         # Scheduler
-        if self.args.cos_lr:
-            self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
-        else:
-            self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
-        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
+        self._setup_scheduler()
         self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
         self.resume_training(ckpt)
         self.scheduler.last_epoch = self.start_epoch - 1  # do not move

@@ -285,17 +289,18 @@
             self._setup_ddp(world_size)
         self._setup_train(world_size)

-        self.epoch_time = None
-        self.epoch_time_start = time.time()
-        self.train_time_start = time.time()
         nb = len(self.train_loader)  # number of batches
         nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
         last_opt_step = -1
+        self.epoch_time = None
+        self.epoch_time_start = time.time()
+        self.train_time_start = time.time()
         self.run_callbacks('on_train_start')
         LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
                     f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
                     f"Logging results to {colorstr('bold', self.save_dir)}\n"
-                    f'Starting training for {self.epochs} epochs...')
+                    f'Starting training for '
+                    f'{self.args.time} hours...' if self.args.time else f'{self.epochs} epochs...')
         if self.args.close_mosaic:
             base_idx = (self.epochs - self.args.close_mosaic) * nb
             self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])

@@ -323,7 +328,7 @@
                 ni = i + nb * epoch
                 if ni <= nw:
                     xi = [0, nw]  # x interp
-                    self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())
+                    self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
                     for j, x in enumerate(self.optimizer.param_groups):
                         # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                         x['lr'] = np.interp(

@@ -348,6 +353,16 @@
                     self.optimizer_step()
                     last_opt_step = ni

+                    # Timed stopping
+                    if self.args.time:
+                        self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
+                        if RANK != -1:  # if DDP training
+                            broadcast_list = [self.stop if RANK == 0 else None]
+                            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
+                            self.stop = broadcast_list[0]
+                        if self.stop:  # training time exceeded
+                            break
+
                 # Log
                 mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                 loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1

@@ -363,31 +378,37 @@
                 self.run_callbacks('on_train_batch_end')

             self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
-
-            with warnings.catch_warnings():
-                warnings.simplefilter('ignore')  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
-                self.scheduler.step()
             self.run_callbacks('on_train_epoch_end')
-
             if RANK in (-1, 0):
-
-                # Validation
+                final_epoch = epoch + 1 == self.epochs
                 self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
-                final_epoch = (epoch + 1 == self.epochs) or self.stopper.possible_stop

-                if self.args.val or final_epoch:
+                # Validation
+                if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
                     self.metrics, self.fitness = self.validate()
                 self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
-                self.stop = self.stopper(epoch + 1, self.fitness)
+                self.stop |= self.stopper(epoch + 1, self.fitness)
+                if self.args.time:
+                    self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)

                 # Save model
-                if self.args.save or (epoch + 1 == self.epochs):
+                if self.args.save or final_epoch:
                     self.save_model()
                     self.run_callbacks('on_model_save')

-            tnow = time.time()
-            self.epoch_time = tnow - self.epoch_time_start
-            self.epoch_time_start = tnow
+            # Scheduler
+            t = time.time()
+            self.epoch_time = t - self.epoch_time_start
+            self.epoch_time_start = t
+            with warnings.catch_warnings():
+                warnings.simplefilter('ignore')  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
+                if self.args.time:
+                    mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
+                    self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
+                    self._setup_scheduler()
+                    self.scheduler.last_epoch = self.epoch  # do not move
+                    self.stop |= epoch >= self.epochs  # stop if exceeded epochs
+                self.scheduler.step()
             self.run_callbacks('on_fit_epoch_end')
             torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors

@@ -395,8 +416,7 @@
             if RANK != -1:  # if DDP training
                 broadcast_list = [self.stop if RANK == 0 else None]
                 dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
-                if RANK != 0:
-                    self.stop = broadcast_list[0]
+                self.stop = broadcast_list[0]
             if self.stop:
                 break  # must break all DDP ranks
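The timed-training logic above re-estimates the total epoch budget every epoch from the mean epoch duration, then rebuilds the scheduler so the learning rate still decays to lrf by the end of the time budget. A small worked example of that arithmetic, with assumed timings:

    import math

    time_hours = 1.0         # args.time (assumed)
    mean_epoch_time = 120.0  # observed seconds per epoch so far (assumed)

    epochs = math.ceil(time_hours * 3600 / mean_epoch_time)
    print(epochs)  # 30 -> scheduler is rebuilt for 30 epochs; training also stops once the hour elapses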
 
{ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/nn/autobackend.py

@@ -333,7 +333,7 @@ class AutoBackend(nn.Module):

         self.__dict__.update(locals())  # assign all variables to self

-    def forward(self, im, augment=False, visualize=False):
+    def forward(self, im, augment=False, visualize=False, embed=None):
         """
         Runs inference on the YOLOv8 MultiBackend model.

@@ -341,6 +341,7 @@
             im (torch.Tensor): The image tensor to perform inference on.
             augment (bool): whether to perform data augmentation during inference, defaults to False
             visualize (bool): whether to visualize the output predictions, defaults to False
+            embed (list, optional): A list of feature vectors/embeddings to return.

         Returns:
             (tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True)

@@ -352,7 +353,7 @@
             im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)

         if self.pt or self.nn_module:  # PyTorch
-            y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
+            y = self.model(im, augment=augment, visualize=visualize, embed=embed)
         elif self.jit:  # TorchScript
             y = self.model(im)
         elif self.dnn:  # ONNX OpenCV DNN
{ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/nn/tasks.py

@@ -41,7 +41,7 @@ class BaseModel(nn.Module):
             return self.loss(x, *args, **kwargs)
         return self.predict(x, *args, **kwargs)

-    def predict(self, x, profile=False, visualize=False, augment=False):
+    def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
         """
         Perform a forward pass through the network.

@@ -50,15 +50,16 @@
             profile (bool): Print the computation time of each layer if True, defaults to False.
             visualize (bool): Save the feature maps of the model if True, defaults to False.
             augment (bool): Augment image during prediction, defaults to False.
+            embed (list, optional): A list of feature vectors/embeddings to return.

         Returns:
             (torch.Tensor): The last output of the model.
         """
         if augment:
             return self._predict_augment(x)
-        return self._predict_once(x, profile, visualize)
+        return self._predict_once(x, profile, visualize, embed)

-    def _predict_once(self, x, profile=False, visualize=False):
+    def _predict_once(self, x, profile=False, visualize=False, embed=None):
         """
         Perform a forward pass through the network.

@@ -66,11 +67,12 @@
             x (torch.Tensor): The input tensor to the model.
             profile (bool): Print the computation time of each layer if True, defaults to False.
             visualize (bool): Save the feature maps of the model if True, defaults to False.
+            embed (list, optional): A list of feature vectors/embeddings to return.

         Returns:
             (torch.Tensor): The last output of the model.
         """
-        y, dt = [], []  # outputs
+        y, dt, embeddings = [], [], []  # outputs
         for m in self.model:
             if m.f != -1:  # if not from previous layer
                 x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers

@@ -80,6 +82,10 @@
             y.append(x if m.i in self.save else None)  # save output
             if visualize:
                 feature_visualization(x, m.type, m.i, save_dir=visualize)
+            if embed and m.i in embed:
+                embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
+                if m.i == max(embed):
+                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
         return x

     def _predict_augment(self, x):

@@ -454,7 +460,7 @@ class RTDETRDetectionModel(DetectionModel):
         return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']],
                                                    device=img.device)

-    def predict(self, x, profile=False, visualize=False, batch=None, augment=False):
+    def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
         """
         Perform a forward pass through the model.

@@ -464,11 +470,12 @@
             visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
             batch (dict, optional): Ground truth data for evaluation. Defaults to None.
             augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
+            embed (list, optional): A list of feature vectors/embeddings to return.

         Returns:
             (torch.Tensor): Model's output tensor.
         """
-        y, dt = [], []  # outputs
+        y, dt, embeddings = [], [], []  # outputs
         for m in self.model[:-1]:  # except the head part
             if m.f != -1:  # if not from previous layer
                 x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers

@@ -478,6 +485,10 @@
             y.append(x if m.i in self.save else None)  # save output
             if visualize:
                 feature_visualization(x, m.type, m.i, save_dir=visualize)
+            if embed and m.i in embed:
+                embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
+                if m.i == max(embed):
+                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
         head = self.model[-1]
         x = head([y[j] for j in head.f], batch)  # head inference
         return x
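The embedding extraction added to _predict_once() amounts to global-average-pooling each requested layer's feature map and concatenating the results per image. A standalone sketch of that pooling pattern with dummy tensors (shapes are arbitrary):

    import torch
    import torch.nn as nn

    # Fake feature maps for a batch of 2 images, e.g. from two layers listed in 'embed'
    feats = [torch.randn(2, 128, 20, 20), torch.randn(2, 256, 10, 10)]

    pooled = [nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1) for x in feats]  # each (2, C)
    per_image = torch.unbind(torch.cat(pooled, 1), dim=0)  # tuple of 2 tensors, each of shape (384,)
    print(len(per_image), per_image[0].shape)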
{ultralytics-8.0.227 → ultralytics-8.0.229}/ultralytics/solutions/ai_gym.py

@@ -62,7 +62,7 @@ class AIGym:

     def start_counting(self, im0, results, frame_count):
         """
-        function used to count the gym steps
+        Function used to count the gym steps
         Args:
             im0 (ndarray): Current frame from the video stream.
             results: Pose estimation data