ultralytics 8.1.29__py3-none-any.whl → 8.3.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247)
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +36 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +110 -40
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +569 -242
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +190 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +526 -66
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +225 -77
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +160 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +40 -34
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +83 -55
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +302 -102
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.62.dist-info/METADATA +370 -0
  235. ultralytics-8.3.62.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.29.dist-info/METADATA +0 -373
  244. ultralytics-8.1.29.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.29.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
ultralytics/nn/tasks.py CHANGED
@@ -1,6 +1,9 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  import contextlib
4
+ import pickle
5
+ import re
6
+ import types
4
7
  from copy import deepcopy
5
8
  from pathlib import Path
6
9
 
@@ -11,18 +14,28 @@ from ultralytics.nn.modules import (
11
14
  AIFI,
12
15
  C1,
13
16
  C2,
17
+ C2PSA,
14
18
  C3,
15
19
  C3TR,
20
+ ELAN1,
16
21
  OBB,
22
+ PSA,
17
23
  SPP,
24
+ SPPELAN,
18
25
  SPPF,
26
+ AConv,
27
+ ADown,
19
28
  Bottleneck,
20
29
  BottleneckCSP,
21
30
  C2f,
22
31
  C2fAttn,
23
- ImagePoolingAttn,
32
+ C2fCIB,
33
+ C2fPSA,
24
34
  C3Ghost,
35
+ C3k2,
25
36
  C3x,
37
+ CBFuse,
38
+ CBLinear,
26
39
  Classify,
27
40
  Concat,
28
41
  Conv,
@@ -36,30 +49,38 @@ from ultralytics.nn.modules import (
36
49
  GhostConv,
37
50
  HGBlock,
38
51
  HGStem,
52
+ ImagePoolingAttn,
53
+ Index,
39
54
  Pose,
40
55
  RepC3,
41
56
  RepConv,
57
+ RepNCSPELAN4,
58
+ RepVGGDW,
42
59
  ResNetLayer,
43
60
  RTDETRDecoder,
61
+ SCDown,
44
62
  Segment,
63
+ TorchVision,
45
64
  WorldDetect,
46
- RepNCSPELAN4,
47
- ADown,
48
- SPPELAN,
49
- CBFuse,
50
- CBLinear,
51
- Silence,
65
+ v10Detect,
52
66
  )
53
67
  from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
54
68
  from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
55
- from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss
69
+ from ultralytics.utils.loss import (
70
+ E2EDetectLoss,
71
+ v8ClassificationLoss,
72
+ v8DetectionLoss,
73
+ v8OBBLoss,
74
+ v8PoseLoss,
75
+ v8SegmentationLoss,
76
+ )
77
+ from ultralytics.utils.ops import make_divisible
56
78
  from ultralytics.utils.plotting import feature_visualization
57
79
  from ultralytics.utils.torch_utils import (
58
80
  fuse_conv_and_bn,
59
81
  fuse_deconv_and_bn,
60
82
  initialize_weights,
61
83
  intersect_dicts,
62
- make_divisible,
63
84
  model_info,
64
85
  scale_img,
65
86
  time_sync,
@@ -76,13 +97,17 @@ class BaseModel(nn.Module):
76
97
 
77
98
  def forward(self, x, *args, **kwargs):
78
99
  """
79
- Forward pass of the model on a single scale. Wrapper for `_forward_once` method.
100
+ Perform forward pass of the model for either training or inference.
101
+
102
+ If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.
80
103
 
81
104
  Args:
82
- x (torch.Tensor | dict): The input image tensor or a dict including image tensor and gt labels.
105
+ x (torch.Tensor | dict): Input tensor for inference, or dict with image tensor and labels for training.
106
+ *args (Any): Variable length argument list.
107
+ **kwargs (Any): Arbitrary keyword arguments.
83
108
 
84
109
  Returns:
85
- (torch.Tensor): The output of the network.
110
+ (torch.Tensor): Loss if x is a dict (training), or network predictions (inference).
86
111
  """
87
112
  if isinstance(x, dict): # for cases of training and validating while training.
88
113
  return self.loss(x, *args, **kwargs)
@@ -138,8 +163,8 @@ class BaseModel(nn.Module):
138
163
  def _predict_augment(self, x):
139
164
  """Perform augmentations on input image x and return augmented inference."""
140
165
  LOGGER.warning(
141
- f"WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. "
142
- f"Reverting to single-scale inference instead."
166
+ f"WARNING ⚠️ {self.__class__.__name__} does not support 'augment=True' prediction. "
167
+ f"Reverting to single-scale prediction."
143
168
  )
144
169
  return self._predict_once(x)
145
170
 
@@ -157,7 +182,7 @@ class BaseModel(nn.Module):
157
182
  None
158
183
  """
159
184
  c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix
160
- flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # FLOPs
185
+ flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs
161
186
  t = time_sync()
162
187
  for _ in range(10):
163
188
  m(x.copy() if c else x)
@@ -191,6 +216,9 @@ class BaseModel(nn.Module):
191
216
  if isinstance(m, RepConv):
192
217
  m.fuse_convs()
193
218
  m.forward = m.forward_fuse # update forward
219
+ if isinstance(m, RepVGGDW):
220
+ m.fuse()
221
+ m.forward = m.forward_fuse
194
222
  self.info(verbose=verbose)
195
223
 
196
224
  return self
@@ -260,7 +288,7 @@ class BaseModel(nn.Module):
260
288
  batch (dict): Batch to compute loss on
261
289
  preds (torch.Tensor | List[torch.Tensor]): Predictions.
262
290
  """
263
- if not hasattr(self, "criterion"):
291
+ if getattr(self, "criterion", None) is None:
264
292
  self.criterion = self.init_criterion()
265
293
 
266
294
  preds = self.forward(batch["img"]) if preds is None else preds
@@ -278,6 +306,12 @@ class DetectionModel(BaseModel):
278
306
  """Initialize the YOLOv8 detection model with the given config and parameters."""
279
307
  super().__init__()
280
308
  self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
309
+ if self.yaml["backbone"][0][2] == "Silence":
310
+ LOGGER.warning(
311
+ "WARNING ⚠️ YOLOv9 `Silence` module is deprecated in favor of nn.Identity. "
312
+ "Please delete local *.pt file and re-download the latest model checkpoint."
313
+ )
314
+ self.yaml["backbone"][0][2] = "nn.Identity"
281
315
 
282
316
  # Define model
283
317
  ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels
@@ -287,14 +321,21 @@ class DetectionModel(BaseModel):
287
321
  self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
288
322
  self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict
289
323
  self.inplace = self.yaml.get("inplace", True)
324
+ self.end2end = getattr(self.model[-1], "end2end", False)
290
325
 
291
326
  # Build strides
292
327
  m = self.model[-1] # Detect()
293
328
  if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
294
329
  s = 256 # 2x min stride
295
330
  m.inplace = self.inplace
296
- forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
297
- m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
331
+
332
+ def _forward(x):
333
+ """Performs a forward pass through the model, handling different Detect subclass types accordingly."""
334
+ if self.end2end:
335
+ return self.forward(x)["one2many"]
336
+ return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
337
+
338
+ m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))]) # forward
298
339
  self.stride = m.stride
299
340
  m.bias_init() # only run once
300
341
  else:
@@ -308,6 +349,9 @@ class DetectionModel(BaseModel):
308
349
 
309
350
  def _predict_augment(self, x):
310
351
  """Perform augmentations on input image x and return augmented inference and train outputs."""
352
+ if getattr(self, "end2end", False) or self.__class__.__name__ != "DetectionModel":
353
+ LOGGER.warning("WARNING ⚠️ Model does not support 'augment=True', reverting to single-scale prediction.")
354
+ return self._predict_once(x)
311
355
  img_size = x.shape[-2:] # height, width
312
356
  s = [1, 0.83, 0.67] # scales
313
357
  f = [None, 3, None] # flips (2-ud, 3-lr)
@@ -344,7 +388,7 @@ class DetectionModel(BaseModel):
344
388
 
345
389
  def init_criterion(self):
346
390
  """Initialize the loss criterion for the DetectionModel."""
347
- return v8DetectionLoss(self)
391
+ return E2EDetectLoss(self) if getattr(self, "end2end", False) else v8DetectionLoss(self)
348
392
 
349
393
 
350
394
  class OBBModel(DetectionModel):
@@ -425,11 +469,11 @@ class ClassificationModel(BaseModel):
425
469
  elif isinstance(m, nn.Sequential):
426
470
  types = [type(x) for x in m]
427
471
  if nn.Linear in types:
428
- i = types.index(nn.Linear) # nn.Linear index
472
+ i = len(types) - 1 - types[::-1].index(nn.Linear) # last nn.Linear index
429
473
  if m[i].out_features != nc:
430
474
  m[i] = nn.Linear(m[i].in_features, nc)
431
475
  elif nn.Conv2d in types:
432
- i = types.index(nn.Conv2d) # nn.Conv2d index
476
+ i = len(types) - 1 - types[::-1].index(nn.Conv2d) # last nn.Conv2d index
433
477
  if m[i].out_channels != nc:
434
478
  m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
435
479
 
@@ -560,30 +604,32 @@ class WorldModel(DetectionModel):
560
604
 
561
605
  def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
562
606
  """Initialize YOLOv8 world model with given config and parameters."""
563
- self.txt_feats = torch.randn(1, nc or 80, 512) # placeholder
607
+ self.txt_feats = torch.randn(1, nc or 80, 512) # features placeholder
608
+ self.clip_model = None # CLIP model placeholder
564
609
  super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
565
610
 
566
- def set_classes(self, text):
567
- """Perform a forward pass with optional profiling, visualization, and embedding extraction."""
611
+ def set_classes(self, text, batch=80, cache_clip_model=True):
612
+ """Set classes in advance so that model could do offline-inference without clip model."""
568
613
  try:
569
614
  import clip
570
615
  except ImportError:
571
- check_requirements("git+https://github.com/openai/CLIP.git")
616
+ check_requirements("git+https://github.com/ultralytics/CLIP.git")
572
617
  import clip
573
618
 
574
- model, _ = clip.load("ViT-B/32")
619
+ if (
620
+ not getattr(self, "clip_model", None) and cache_clip_model
621
+ ): # for backwards compatibility of models lacking clip_model attribute
622
+ self.clip_model = clip.load("ViT-B/32")[0]
623
+ model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0]
575
624
  device = next(model.parameters()).device
576
625
  text_token = clip.tokenize(text).to(device)
577
- txt_feats = model.encode_text(text_token).to(dtype=torch.float32)
626
+ txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
627
+ txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
578
628
  txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
579
- self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()
629
+ self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
580
630
  self.model[-1].nc = len(text)
581
631
 
582
- def init_criterion(self):
583
- """Initialize the loss criterion for the model."""
584
- raise NotImplementedError
585
-
586
- def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
632
+ def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
587
633
  """
588
634
  Perform a forward pass through the model.
589
635
 
@@ -591,13 +637,14 @@ class WorldModel(DetectionModel):
591
637
  x (torch.Tensor): The input tensor.
592
638
  profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
593
639
  visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
640
+ txt_feats (torch.Tensor): The text features, use it if it's given. Defaults to None.
594
641
  augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
595
642
  embed (list, optional): A list of feature vectors/embeddings to return.
596
643
 
597
644
  Returns:
598
645
  (torch.Tensor): Model's output tensor.
599
646
  """
600
- txt_feats = self.txt_feats.to(device=x.device, dtype=x.dtype)
647
+ txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype)
601
648
  if len(txt_feats) != len(x):
602
649
  txt_feats = txt_feats.repeat(len(x), 1, 1)
603
650
  ori_txt_feats = txt_feats.clone()
@@ -625,6 +672,21 @@ class WorldModel(DetectionModel):
625
672
  return torch.unbind(torch.cat(embeddings, 1), dim=0)
626
673
  return x
627
674
 
675
+ def loss(self, batch, preds=None):
676
+ """
677
+ Compute loss.
678
+
679
+ Args:
680
+ batch (dict): Batch to compute loss on.
681
+ preds (torch.Tensor | List[torch.Tensor]): Predictions.
682
+ """
683
+ if not hasattr(self, "criterion"):
684
+ self.criterion = self.init_criterion()
685
+
686
+ if preds is None:
687
+ preds = self.forward(batch["img"], txt_feats=batch["txt_feats"])
688
+ return self.criterion(preds, batch)
689
+
628
690
 
629
691
  class Ensemble(nn.ModuleList):
630
692
  """Ensemble of models."""
@@ -646,7 +708,7 @@ class Ensemble(nn.ModuleList):
646
708
 
647
709
 
648
710
  @contextlib.contextmanager
649
- def temporary_modules(modules=None):
711
+ def temporary_modules(modules=None, attributes=None):
650
712
  """
651
713
  Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
652
714
 
@@ -656,11 +718,13 @@ def temporary_modules(modules=None):
656
718
 
657
719
  Args:
658
720
  modules (dict, optional): A dictionary mapping old module paths to new module paths.
721
+ attributes (dict, optional): A dictionary mapping old module attributes to new module attributes.
659
722
 
660
723
  Example:
661
724
  ```python
662
- with temporary_modules({'old.module.path': 'new.module.path'}):
663
- import old.module.path # this will now import new.module.path
725
+ with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}):
726
+ import old.module # this will now import new.module
727
+ from old.module import attribute # this will now import new.module.attribute
664
728
  ```
665
729
 
666
730
  Note:
@@ -668,16 +732,23 @@ def temporary_modules(modules=None):
668
732
  Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
669
733
  applications or libraries. Use this function with caution.
670
734
  """
671
- if not modules:
735
+ if modules is None:
672
736
  modules = {}
673
-
674
- import importlib
737
+ if attributes is None:
738
+ attributes = {}
675
739
  import sys
740
+ from importlib import import_module
676
741
 
677
742
  try:
743
+ # Set attributes in sys.modules under their old name
744
+ for old, new in attributes.items():
745
+ old_module, old_attr = old.rsplit(".", 1)
746
+ new_module, new_attr = new.rsplit(".", 1)
747
+ setattr(import_module(old_module), old_attr, getattr(import_module(new_module), new_attr))
748
+
678
749
  # Set modules in sys.modules under their old name
679
750
  for old, new in modules.items():
680
- sys.modules[old] = importlib.import_module(new)
751
+ sys.modules[old] = import_module(new)
681
752
 
682
753
  yield
683
754
  finally:
@@ -687,17 +758,58 @@ def temporary_modules(modules=None):
687
758
  del sys.modules[old]
688
759
 
689
760
 
690
- def torch_safe_load(weight):
761
+ class SafeClass:
762
+ """A placeholder class to replace unknown classes during unpickling."""
763
+
764
+ def __init__(self, *args, **kwargs):
765
+ """Initialize SafeClass instance, ignoring all arguments."""
766
+ pass
767
+
768
+ def __call__(self, *args, **kwargs):
769
+ """Run SafeClass instance, ignoring all arguments."""
770
+ pass
771
+
772
+
773
+ class SafeUnpickler(pickle.Unpickler):
774
+ """Custom Unpickler that replaces unknown classes with SafeClass."""
775
+
776
+ def find_class(self, module, name):
777
+ """Attempt to find a class, returning SafeClass if not among safe modules."""
778
+ safe_modules = (
779
+ "torch",
780
+ "collections",
781
+ "collections.abc",
782
+ "builtins",
783
+ "math",
784
+ "numpy",
785
+ # Add other modules considered safe
786
+ )
787
+ if module in safe_modules:
788
+ return super().find_class(module, name)
789
+ else:
790
+ return SafeClass
791
+
792
+
793
+ def torch_safe_load(weight, safe_only=False):
691
794
  """
692
- This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised,
693
- it catches the error, logs a warning message, and attempts to install the missing module via the
694
- check_requirements() function. After installation, the function again attempts to load the model using torch.load().
795
+ Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
796
+ error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
797
+ After installation, the function again attempts to load the model using torch.load().
695
798
 
696
799
  Args:
697
800
  weight (str): The file path of the PyTorch model.
801
+ safe_only (bool): If True, replace unknown classes with SafeClass during loading.
802
+
803
+ Example:
804
+ ```python
805
+ from ultralytics.nn.tasks import torch_safe_load
806
+
807
+ ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True)
808
+ ```
698
809
 
699
810
  Returns:
700
- (dict): The loaded PyTorch model.
811
+ ckpt (dict): The loaded model checkpoint.
812
+ file (str): The loaded filename
701
813
  """
702
814
  from ultralytics.utils.downloads import attempt_download_asset
703
815
 
@@ -705,13 +817,26 @@ def torch_safe_load(weight):
705
817
  file = attempt_download_asset(weight) # search online if missing locally
706
818
  try:
707
819
  with temporary_modules(
708
- {
820
+ modules={
709
821
  "ultralytics.yolo.utils": "ultralytics.utils",
710
822
  "ultralytics.yolo.v8": "ultralytics.models.yolo",
711
823
  "ultralytics.yolo.data": "ultralytics.data",
712
- }
713
- ): # for legacy 8.0 Classify and Pose models
714
- ckpt = torch.load(file, map_location="cpu")
824
+ },
825
+ attributes={
826
+ "ultralytics.nn.modules.block.Silence": "torch.nn.Identity", # YOLOv9e
827
+ "ultralytics.nn.tasks.YOLOv10DetectionModel": "ultralytics.nn.tasks.DetectionModel", # YOLOv10
828
+ "ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss", # YOLOv10
829
+ },
830
+ ):
831
+ if safe_only:
832
+ # Load via custom pickle module
833
+ safe_pickle = types.ModuleType("safe_pickle")
834
+ safe_pickle.Unpickler = SafeUnpickler
835
+ safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load()
836
+ with open(file, "rb") as f:
837
+ ckpt = torch.load(f, pickle_module=safe_pickle)
838
+ else:
839
+ ckpt = torch.load(file, map_location="cpu")
715
840
 
716
841
  except ModuleNotFoundError as e: # e.name is missing module name
717
842
  if e.name == "models":
@@ -721,14 +846,14 @@ def torch_safe_load(weight):
721
846
  f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
722
847
  f"YOLOv8 at https://github.com/ultralytics/ultralytics."
723
848
  f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
724
- f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
849
+ f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolov8n.pt'"
725
850
  )
726
851
  ) from e
727
852
  LOGGER.warning(
728
- f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
853
+ f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in Ultralytics requirements."
729
854
  f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
730
855
  f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
731
- f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
856
+ f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolov8n.pt'"
732
857
  )
733
858
  check_requirements(e.name) # install missing module
734
859
  ckpt = torch.load(file, map_location="cpu")
@@ -741,12 +866,11 @@ def torch_safe_load(weight):
741
866
  )
742
867
  ckpt = {"model": ckpt.model}
743
868
 
744
- return ckpt, file # load
869
+ return ckpt, file
745
870
 
746
871
 
747
872
  def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
748
873
  """Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""
749
-
750
874
  ensemble = Ensemble()
751
875
  for w in weights if isinstance(weights, list) else [weights]:
752
876
  ckpt, w = torch_safe_load(w) # load ckpt
@@ -814,6 +938,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
814
938
  import ast
815
939
 
816
940
  # Args
941
+ legacy = True # backward compatibility for v3/v5/v8/v9 models
817
942
  max_channels = float("inf")
818
943
  nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
819
944
  depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
@@ -839,9 +964,8 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
839
964
  if isinstance(a, str):
840
965
  with contextlib.suppress(ValueError):
841
966
  args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
842
-
843
967
  n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
844
- if m in (
968
+ if m in {
845
969
  Classify,
846
970
  Conv,
847
971
  ConvTranspose,
@@ -850,14 +974,19 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
850
974
  GhostBottleneck,
851
975
  SPP,
852
976
  SPPF,
977
+ C2fPSA,
978
+ C2PSA,
853
979
  DWConv,
854
980
  Focus,
855
981
  BottleneckCSP,
856
982
  C1,
857
983
  C2,
858
984
  C2f,
985
+ C3k2,
859
986
  RepNCSPELAN4,
987
+ ELAN1,
860
988
  ADown,
989
+ AConv,
861
990
  SPPELAN,
862
991
  C2fAttn,
863
992
  C3,
@@ -867,7 +996,10 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
867
996
  DWConvTranspose2d,
868
997
  C3x,
869
998
  RepC3,
870
- ):
999
+ PSA,
1000
+ SCDown,
1001
+ C2fCIB,
1002
+ }:
871
1003
  c1, c2 = ch[f], args[0]
872
1004
  if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
873
1005
  c2 = make_divisible(min(c2, max_channels) * width, 8)
@@ -878,12 +1010,31 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
878
1010
  ) # num heads
879
1011
 
880
1012
  args = [c1, c2, *args[1:]]
881
- if m in (BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3):
1013
+ if m in {
1014
+ BottleneckCSP,
1015
+ C1,
1016
+ C2,
1017
+ C2f,
1018
+ C3k2,
1019
+ C2fAttn,
1020
+ C3,
1021
+ C3TR,
1022
+ C3Ghost,
1023
+ C3x,
1024
+ RepC3,
1025
+ C2fPSA,
1026
+ C2fCIB,
1027
+ C2PSA,
1028
+ }:
882
1029
  args.insert(2, n) # number of repeats
883
1030
  n = 1
1031
+ if m is C3k2: # for M/L/X sizes
1032
+ legacy = False
1033
+ if scale in "mlx":
1034
+ args[3] = True
884
1035
  elif m is AIFI:
885
1036
  args = [ch[f], *args]
886
- elif m in (HGStem, HGBlock):
1037
+ elif m in {HGStem, HGBlock}:
887
1038
  c1, cm, c2 = ch[f], args[0], args[1]
888
1039
  args = [c1, cm, c2, *args[2:]]
889
1040
  if m is HGBlock:
@@ -895,13 +1046,15 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
895
1046
  args = [ch[f]]
896
1047
  elif m is Concat:
897
1048
  c2 = sum(ch[x] for x in f)
898
- elif m in (Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn):
1049
+ elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
899
1050
  args.append([ch[x] for x in f])
900
1051
  if m is Segment:
901
1052
  args[2] = make_divisible(min(args[2], max_channels) * width, 8)
1053
+ if m in {Detect, Segment, Pose, OBB}:
1054
+ m.legacy = legacy
902
1055
  elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
903
1056
  args.insert(1, [ch[x] for x in f])
904
- elif m is CBLinear:
1057
+ elif m in {CBLinear, TorchVision, Index}:
905
1058
  c2 = args[0]
906
1059
  c1 = ch[f]
907
1060
  args = [c1, c2, *args[1:]]
@@ -912,10 +1065,10 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
912
1065
 
913
1066
  m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
914
1067
  t = str(m)[8:-2].replace("__main__.", "") # module type
915
- m.np = sum(x.numel() for x in m_.parameters()) # number params
1068
+ m_.np = sum(x.numel() for x in m_.parameters()) # number params
916
1069
  m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
917
1070
  if verbose:
918
- LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}") # print
1071
+ LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f} {t:<45}{str(args):<30}") # print
919
1072
  save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
920
1073
  layers.append(m_)
921
1074
  if i == 0:
@@ -926,8 +1079,6 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
926
1079
 
927
1080
  def yaml_model_load(path):
928
1081
  """Load a YOLOv8 model from a YAML file."""
929
- import re
930
-
931
1082
  path = Path(path)
932
1083
  if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
933
1084
  new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
@@ -954,11 +1105,10 @@ def guess_model_scale(model_path):
954
1105
  Returns:
955
1106
  (str): The size character of the model's scale, which can be n, s, m, l, or x.
956
1107
  """
957
- with contextlib.suppress(AttributeError):
958
- import re
959
-
960
- return re.search(r"yolov\d+([nslmx])", Path(model_path).stem).group(1) # n, s, m, l, or x
961
- return ""
1108
+ try:
1109
+ return re.search(r"yolo[v]?\d+([nslmx])", Path(model_path).stem).group(1) # noqa, returns n, s, m, l, or x
1110
+ except AttributeError:
1111
+ return ""
962
1112
 
963
1113
 
964
1114
  def guess_model_task(model):
@@ -978,9 +1128,9 @@ def guess_model_task(model):
978
1128
  def cfg2task(cfg):
979
1129
  """Guess from YAML dictionary."""
980
1130
  m = cfg["head"][-1][-2].lower() # output module name
981
- if m in ("classify", "classifier", "cls", "fc"):
1131
+ if m in {"classify", "classifier", "cls", "fc"}:
982
1132
  return "classify"
983
- if m == "detect":
1133
+ if "detect" in m:
984
1134
  return "detect"
985
1135
  if m == "segment":
986
1136
  return "segment"
@@ -993,7 +1143,6 @@ def guess_model_task(model):
993
1143
  if isinstance(model, dict):
994
1144
  with contextlib.suppress(Exception):
995
1145
  return cfg2task(model)
996
-
997
1146
  # Guess from PyTorch model
998
1147
  if isinstance(model, nn.Module): # PyTorch model
999
1148
  for x in "model.args", "model.model.args", "model.model.model.args":
@@ -1002,7 +1151,6 @@ def guess_model_task(model):
1002
1151
  for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
1003
1152
  with contextlib.suppress(Exception):
1004
1153
  return cfg2task(eval(x))
1005
-
1006
1154
  for m in model.modules():
1007
1155
  if isinstance(m, Segment):
1008
1156
  return "segment"
@@ -1012,7 +1160,7 @@ def guess_model_task(model):
1012
1160
  return "pose"
1013
1161
  elif isinstance(m, OBB):
1014
1162
  return "obb"
1015
- elif isinstance(m, (Detect, WorldDetect)):
1163
+ elif isinstance(m, (Detect, WorldDetect, v10Detect)):
1016
1164
  return "detect"
1017
1165
 
1018
1166
  # Guess from model filename
@@ -1 +1,30 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from .ai_gym import AIGym
4
+ from .analytics import Analytics
5
+ from .distance_calculation import DistanceCalculation
6
+ from .heatmap import Heatmap
7
+ from .object_counter import ObjectCounter
8
+ from .parking_management import ParkingManagement, ParkingPtsSelection
9
+ from .queue_management import QueueManager
10
+ from .region_counter import RegionCounter
11
+ from .security_alarm import SecurityAlarm
12
+ from .speed_estimation import SpeedEstimator
13
+ from .streamlit_inference import Inference
14
+ from .trackzone import TrackZone
15
+
16
+ __all__ = (
17
+ "AIGym",
18
+ "DistanceCalculation",
19
+ "Heatmap",
20
+ "ObjectCounter",
21
+ "ParkingManagement",
22
+ "ParkingPtsSelection",
23
+ "QueueManager",
24
+ "SpeedEstimator",
25
+ "Analytics",
26
+ "Inference",
27
+ "RegionCounter",
28
+ "TrackZone",
29
+ "SecurityAlarm",
30
+ )