dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.3.248__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (236)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
  2. dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
  3. tests/__init__.py +5 -7
  4. tests/conftest.py +8 -15
  5. tests/test_cli.py +1 -1
  6. tests/test_cuda.py +5 -8
  7. tests/test_engine.py +1 -1
  8. tests/test_exports.py +57 -12
  9. tests/test_integrations.py +4 -4
  10. tests/test_python.py +84 -53
  11. tests/test_solutions.py +160 -151
  12. ultralytics/__init__.py +1 -1
  13. ultralytics/cfg/__init__.py +56 -62
  14. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  16. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  17. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  18. ultralytics/cfg/datasets/VOC.yaml +15 -16
  19. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  20. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  21. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  22. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  24. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  25. ultralytics/cfg/datasets/dota8.yaml +2 -2
  26. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  27. ultralytics/cfg/datasets/kitti.yaml +27 -0
  28. ultralytics/cfg/datasets/lvis.yaml +5 -5
  29. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  30. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  31. ultralytics/cfg/datasets/xView.yaml +16 -16
  32. ultralytics/cfg/default.yaml +1 -1
  33. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  34. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  35. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  36. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  37. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  38. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  39. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  40. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  41. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  42. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  43. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  44. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  45. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  46. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  47. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  48. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  49. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  50. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  51. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  52. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  53. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  54. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  55. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  56. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  57. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  58. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  59. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  61. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  62. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  63. ultralytics/data/__init__.py +4 -4
  64. ultralytics/data/annotator.py +3 -4
  65. ultralytics/data/augment.py +285 -475
  66. ultralytics/data/base.py +18 -26
  67. ultralytics/data/build.py +147 -25
  68. ultralytics/data/converter.py +36 -46
  69. ultralytics/data/dataset.py +46 -74
  70. ultralytics/data/loaders.py +42 -49
  71. ultralytics/data/split.py +5 -6
  72. ultralytics/data/split_dota.py +8 -15
  73. ultralytics/data/utils.py +34 -43
  74. ultralytics/engine/exporter.py +319 -237
  75. ultralytics/engine/model.py +148 -188
  76. ultralytics/engine/predictor.py +29 -38
  77. ultralytics/engine/results.py +177 -311
  78. ultralytics/engine/trainer.py +83 -59
  79. ultralytics/engine/tuner.py +23 -34
  80. ultralytics/engine/validator.py +39 -22
  81. ultralytics/hub/__init__.py +16 -19
  82. ultralytics/hub/auth.py +6 -12
  83. ultralytics/hub/google/__init__.py +7 -10
  84. ultralytics/hub/session.py +15 -25
  85. ultralytics/hub/utils.py +5 -8
  86. ultralytics/models/__init__.py +1 -1
  87. ultralytics/models/fastsam/__init__.py +1 -1
  88. ultralytics/models/fastsam/model.py +8 -10
  89. ultralytics/models/fastsam/predict.py +17 -29
  90. ultralytics/models/fastsam/utils.py +1 -2
  91. ultralytics/models/fastsam/val.py +5 -7
  92. ultralytics/models/nas/__init__.py +1 -1
  93. ultralytics/models/nas/model.py +5 -8
  94. ultralytics/models/nas/predict.py +7 -9
  95. ultralytics/models/nas/val.py +1 -2
  96. ultralytics/models/rtdetr/__init__.py +1 -1
  97. ultralytics/models/rtdetr/model.py +5 -8
  98. ultralytics/models/rtdetr/predict.py +15 -19
  99. ultralytics/models/rtdetr/train.py +10 -13
  100. ultralytics/models/rtdetr/val.py +21 -23
  101. ultralytics/models/sam/__init__.py +15 -2
  102. ultralytics/models/sam/amg.py +14 -20
  103. ultralytics/models/sam/build.py +26 -19
  104. ultralytics/models/sam/build_sam3.py +377 -0
  105. ultralytics/models/sam/model.py +29 -32
  106. ultralytics/models/sam/modules/blocks.py +83 -144
  107. ultralytics/models/sam/modules/decoders.py +19 -37
  108. ultralytics/models/sam/modules/encoders.py +44 -101
  109. ultralytics/models/sam/modules/memory_attention.py +16 -30
  110. ultralytics/models/sam/modules/sam.py +200 -73
  111. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  112. ultralytics/models/sam/modules/transformer.py +18 -28
  113. ultralytics/models/sam/modules/utils.py +174 -50
  114. ultralytics/models/sam/predict.py +2248 -350
  115. ultralytics/models/sam/sam3/__init__.py +3 -0
  116. ultralytics/models/sam/sam3/decoder.py +546 -0
  117. ultralytics/models/sam/sam3/encoder.py +529 -0
  118. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  119. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  120. ultralytics/models/sam/sam3/model_misc.py +199 -0
  121. ultralytics/models/sam/sam3/necks.py +129 -0
  122. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  123. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  124. ultralytics/models/sam/sam3/vitdet.py +547 -0
  125. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  126. ultralytics/models/utils/loss.py +14 -26
  127. ultralytics/models/utils/ops.py +13 -17
  128. ultralytics/models/yolo/__init__.py +1 -1
  129. ultralytics/models/yolo/classify/predict.py +9 -12
  130. ultralytics/models/yolo/classify/train.py +11 -32
  131. ultralytics/models/yolo/classify/val.py +29 -28
  132. ultralytics/models/yolo/detect/predict.py +7 -10
  133. ultralytics/models/yolo/detect/train.py +11 -20
  134. ultralytics/models/yolo/detect/val.py +70 -58
  135. ultralytics/models/yolo/model.py +36 -53
  136. ultralytics/models/yolo/obb/predict.py +5 -14
  137. ultralytics/models/yolo/obb/train.py +11 -14
  138. ultralytics/models/yolo/obb/val.py +39 -36
  139. ultralytics/models/yolo/pose/__init__.py +1 -1
  140. ultralytics/models/yolo/pose/predict.py +6 -21
  141. ultralytics/models/yolo/pose/train.py +10 -15
  142. ultralytics/models/yolo/pose/val.py +38 -57
  143. ultralytics/models/yolo/segment/predict.py +14 -18
  144. ultralytics/models/yolo/segment/train.py +3 -6
  145. ultralytics/models/yolo/segment/val.py +93 -45
  146. ultralytics/models/yolo/world/train.py +8 -14
  147. ultralytics/models/yolo/world/train_world.py +11 -34
  148. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  149. ultralytics/models/yolo/yoloe/predict.py +16 -23
  150. ultralytics/models/yolo/yoloe/train.py +30 -43
  151. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  152. ultralytics/models/yolo/yoloe/val.py +15 -20
  153. ultralytics/nn/__init__.py +7 -7
  154. ultralytics/nn/autobackend.py +145 -77
  155. ultralytics/nn/modules/__init__.py +60 -60
  156. ultralytics/nn/modules/activation.py +4 -6
  157. ultralytics/nn/modules/block.py +132 -216
  158. ultralytics/nn/modules/conv.py +52 -97
  159. ultralytics/nn/modules/head.py +50 -103
  160. ultralytics/nn/modules/transformer.py +76 -88
  161. ultralytics/nn/modules/utils.py +16 -21
  162. ultralytics/nn/tasks.py +94 -154
  163. ultralytics/nn/text_model.py +40 -67
  164. ultralytics/solutions/__init__.py +12 -12
  165. ultralytics/solutions/ai_gym.py +11 -17
  166. ultralytics/solutions/analytics.py +15 -16
  167. ultralytics/solutions/config.py +5 -6
  168. ultralytics/solutions/distance_calculation.py +10 -13
  169. ultralytics/solutions/heatmap.py +7 -13
  170. ultralytics/solutions/instance_segmentation.py +5 -8
  171. ultralytics/solutions/object_blurrer.py +7 -10
  172. ultralytics/solutions/object_counter.py +12 -19
  173. ultralytics/solutions/object_cropper.py +8 -14
  174. ultralytics/solutions/parking_management.py +33 -31
  175. ultralytics/solutions/queue_management.py +10 -12
  176. ultralytics/solutions/region_counter.py +9 -12
  177. ultralytics/solutions/security_alarm.py +15 -20
  178. ultralytics/solutions/similarity_search.py +10 -15
  179. ultralytics/solutions/solutions.py +75 -74
  180. ultralytics/solutions/speed_estimation.py +7 -10
  181. ultralytics/solutions/streamlit_inference.py +2 -4
  182. ultralytics/solutions/templates/similarity-search.html +7 -18
  183. ultralytics/solutions/trackzone.py +7 -10
  184. ultralytics/solutions/vision_eye.py +5 -8
  185. ultralytics/trackers/__init__.py +1 -1
  186. ultralytics/trackers/basetrack.py +3 -5
  187. ultralytics/trackers/bot_sort.py +10 -27
  188. ultralytics/trackers/byte_tracker.py +14 -30
  189. ultralytics/trackers/track.py +3 -6
  190. ultralytics/trackers/utils/gmc.py +11 -22
  191. ultralytics/trackers/utils/kalman_filter.py +37 -48
  192. ultralytics/trackers/utils/matching.py +12 -15
  193. ultralytics/utils/__init__.py +116 -116
  194. ultralytics/utils/autobatch.py +2 -4
  195. ultralytics/utils/autodevice.py +17 -18
  196. ultralytics/utils/benchmarks.py +32 -46
  197. ultralytics/utils/callbacks/base.py +8 -10
  198. ultralytics/utils/callbacks/clearml.py +5 -13
  199. ultralytics/utils/callbacks/comet.py +32 -46
  200. ultralytics/utils/callbacks/dvc.py +13 -18
  201. ultralytics/utils/callbacks/mlflow.py +4 -5
  202. ultralytics/utils/callbacks/neptune.py +7 -15
  203. ultralytics/utils/callbacks/platform.py +314 -38
  204. ultralytics/utils/callbacks/raytune.py +3 -4
  205. ultralytics/utils/callbacks/tensorboard.py +23 -31
  206. ultralytics/utils/callbacks/wb.py +10 -13
  207. ultralytics/utils/checks.py +99 -76
  208. ultralytics/utils/cpu.py +3 -8
  209. ultralytics/utils/dist.py +8 -12
  210. ultralytics/utils/downloads.py +20 -30
  211. ultralytics/utils/errors.py +6 -14
  212. ultralytics/utils/events.py +2 -4
  213. ultralytics/utils/export/__init__.py +4 -236
  214. ultralytics/utils/export/engine.py +237 -0
  215. ultralytics/utils/export/imx.py +91 -55
  216. ultralytics/utils/export/tensorflow.py +231 -0
  217. ultralytics/utils/files.py +24 -28
  218. ultralytics/utils/git.py +9 -11
  219. ultralytics/utils/instance.py +30 -51
  220. ultralytics/utils/logger.py +212 -114
  221. ultralytics/utils/loss.py +14 -22
  222. ultralytics/utils/metrics.py +126 -155
  223. ultralytics/utils/nms.py +13 -16
  224. ultralytics/utils/ops.py +107 -165
  225. ultralytics/utils/patches.py +33 -21
  226. ultralytics/utils/plotting.py +72 -80
  227. ultralytics/utils/tal.py +25 -39
  228. ultralytics/utils/torch_utils.py +52 -78
  229. ultralytics/utils/tqdm.py +20 -20
  230. ultralytics/utils/triton.py +13 -19
  231. ultralytics/utils/tuner.py +17 -5
  232. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  233. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
  234. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
  235. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
  236. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
--- a/ultralytics/trackers/bot_sort.py
+++ b/ultralytics/trackers/bot_sort.py
@@ -19,8 +19,7 @@ from .utils.kalman_filter import KalmanFilterXYWH
 
 
 class BOTrack(STrack):
-    """
-    An extended version of the STrack class for YOLO, adding object tracking features.
+    """An extended version of the STrack class for YOLO, adding object tracking features.
 
     This class extends the STrack class to include additional functionalities for object tracking, such as feature
     smoothing, Kalman filter prediction, and reactivation of tracks.
@@ -46,9 +45,9 @@ class BOTrack(STrack):
 
     Examples:
         Create a BOTrack instance and update its features
-        >>> bo_track = BOTrack(tlwh=[100, 50, 80, 40], score=0.9, cls=1, feat=np.random.rand(128))
+        >>> bo_track = BOTrack(xywh=np.array([100, 50, 80, 40, 0]), score=0.9, cls=1, feat=np.random.rand(128))
         >>> bo_track.predict()
-        >>> new_track = BOTrack(tlwh=[110, 60, 80, 40], score=0.85, cls=1, feat=np.random.rand(128))
+        >>> new_track = BOTrack(xywh=np.array([110, 60, 80, 40, 0]), score=0.85, cls=1, feat=np.random.rand(128))
         >>> bo_track.update(new_track, frame_id=2)
     """
 
@@ -57,23 +56,15 @@ class BOTrack(STrack):
     def __init__(
         self, xywh: np.ndarray, score: float, cls: int, feat: np.ndarray | None = None, feat_history: int = 50
     ):
-        """
-        Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features.
+        """Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features.
 
         Args:
-            xywh (np.ndarray): Bounding box coordinates in xywh format (center x, center y, width, height).
+            xywh (np.ndarray): Bounding box in `(x, y, w, h, idx)` or `(x, y, w, h, angle, idx)` format, where (x, y) is
+                the center, (w, h) are width and height, and `idx` is the detection index.
             score (float): Confidence score of the detection.
             cls (int): Class ID of the detected object.
             feat (np.ndarray, optional): Feature vector associated with the detection.
             feat_history (int): Maximum length of the feature history deque.
-
-        Examples:
-            Initialize a BOTrack object with bounding box, score, class ID, and feature vector
-            >>> xywh = np.array([100, 150, 60, 50])
-            >>> score = 0.9
-            >>> cls = 1
-            >>> feat = np.random.rand(128)
-            >>> bo_track = BOTrack(xywh, score, cls, feat)
         """
         super().__init__(xywh, score, cls)
 
@@ -154,8 +145,7 @@ class BOTrack(STrack):
 
 
 class BOTSORT(BYTETracker):
-    """
-    An extended version of the BYTETracker class for YOLO, designed for object tracking with ReID and GMC algorithm.
+    """An extended version of the BYTETracker class for YOLO, designed for object tracking with ReID and GMC algorithm.
 
     Attributes:
         proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections.
@@ -177,22 +167,16 @@ class BOTSORT(BYTETracker):
         >>> bot_sort.init_track(dets, scores, cls, img)
         >>> bot_sort.multi_predict(tracks)
 
-    Note:
+    Notes:
         The class is designed to work with a YOLO object detection model and supports ReID only if enabled via args.
     """
 
     def __init__(self, args: Any, frame_rate: int = 30):
-        """
-        Initialize BOTSORT object with ReID module and GMC algorithm.
+        """Initialize BOTSORT object with ReID module and GMC algorithm.
 
         Args:
            args (Any): Parsed command-line arguments containing tracking parameters.
            frame_rate (int): Frame rate of the video being processed.
-
-        Examples:
-            Initialize BOTSORT with command-line arguments and a specified frame rate:
-            >>> args = parse_args()
-            >>> bot_sort = BOTSORT(args, frame_rate=30)
         """
         super().__init__(args, frame_rate)
         self.gmc = GMC(method=args.gmc_method)
@@ -253,8 +237,7 @@ class ReID:
     """YOLO model as encoder for re-identification."""
 
     def __init__(self, model: str):
-        """
-        Initialize encoder for re-identification.
+        """Initialize encoder for re-identification.
 
         Args:
            model (str): Path to the YOLO model for re-identification.
--- a/ultralytics/trackers/byte_tracker.py
+++ b/ultralytics/trackers/byte_tracker.py
@@ -14,8 +14,7 @@ from .utils.kalman_filter import KalmanFilterXYAH
 
 
 class STrack(BaseTrack):
-    """
-    Single object tracking representation that uses Kalman filtering for state estimation.
+    """Single object tracking representation that uses Kalman filtering for state estimation.
 
     This class is responsible for storing all the information regarding individual tracklets and performs state updates
     and predictions based on Kalman filter.
@@ -54,20 +53,13 @@ class STrack(BaseTrack):
     shared_kalman = KalmanFilterXYAH()
 
     def __init__(self, xywh: list[float], score: float, cls: Any):
-        """
-        Initialize a new STrack instance.
+        """Initialize a new STrack instance.
 
         Args:
-            xywh (list[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where
-                (x, y) is the center, (w, h) are width and height, [a] is optional aspect ratio, and idx is the id.
+            xywh (list[float]): Bounding box in `(x, y, w, h, idx)` or `(x, y, w, h, angle, idx)` format, where (x, y)
+                is the center, (w, h) are width and height, and `idx` is the detection index.
             score (float): Confidence score of the detection.
             cls (Any): Class label for the detected object.
-
-        Examples:
-            >>> xywh = [100.0, 150.0, 50.0, 75.0, 1]
-            >>> score = 0.9
-            >>> cls = "person"
-            >>> track = STrack(xywh, score, cls)
         """
         super().__init__()
         # xywh+idx or xywha+idx
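
For reference, a minimal construction under the new argument format (a sketch with illustrative values; the trailing element is the detection index):

    from ultralytics.trackers.byte_tracker import STrack

    xywh = [100.0, 150.0, 50.0, 75.0, 0]  # (x, y, w, h, idx): center, size, detection index
    track = STrack(xywh, score=0.9, cls="person")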
@@ -154,8 +146,7 @@ class STrack(BaseTrack):
         self.idx = new_track.idx
 
     def update(self, new_track: STrack, frame_id: int):
-        """
-        Update the state of a matched track.
+        """Update the state of a matched track.
 
         Args:
             new_track (STrack): The new track containing updated information.
@@ -230,7 +221,7 @@ class STrack(BaseTrack):
     def result(self) -> list[float]:
         """Get the current tracking results in the appropriate bounding box format."""
         coords = self.xyxy if self.angle is None else self.xywha
-        return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]
+        return [*coords.tolist(), self.track_id, self.score, self.cls, self.idx]
 
     def __repr__(self) -> str:
         """Return a string representation of the STrack object including start frame, end frame, and track ID."""
@@ -238,12 +229,11 @@
 
 
 class BYTETracker:
-    """
-    BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
+    """BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
 
-    This class encapsulates the functionality for initializing, updating, and managing the tracks for detected objects in a
-    video sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for
-    predicting the new object locations, and performs data association.
+    This class encapsulates the functionality for initializing, updating, and managing the tracks for detected objects
+    in a video sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman
+    filtering for predicting the new object locations, and performs data association.
 
     Attributes:
         tracked_stracks (list[STrack]): List of successfully activated tracks.
@@ -274,17 +264,11 @@ class BYTETracker:
     """
 
     def __init__(self, args, frame_rate: int = 30):
-        """
-        Initialize a BYTETracker instance for object tracking.
+        """Initialize a BYTETracker instance for object tracking.
 
         Args:
            args (Namespace): Command-line arguments containing tracking parameters.
            frame_rate (int): Frame rate of the video sequence.
-
-        Examples:
-            Initialize BYTETracker with command-line arguments and a frame rate of 30
-            >>> args = Namespace(track_buffer=30)
-            >>> tracker = BYTETracker(args, frame_rate=30)
         """
         self.tracked_stracks = []  # type: list[STrack]
         self.lost_stracks = []  # type: list[STrack]
@@ -354,9 +338,9 @@ class BYTETracker:
         # Step 3: Second association, with low score detection boxes association the untrack to the low score detections
         detections_second = self.init_track(results_second, feats_second)
         r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
-        # TODO
+        # TODO: consider fusing scores or appearance features for second association.
         dists = matching.iou_distance(r_tracked_stracks, detections_second)
-        matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5)
+        matches, u_track, _u_detection_second = matching.linear_assignment(dists, thresh=0.5)
         for itracked, idet in matches:
             track = r_tracked_stracks[itracked]
             det = detections_second[idet]
@@ -405,7 +389,7 @@ class BYTETracker:
         self.tracked_stracks, self.lost_stracks = self.remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
         self.removed_stracks.extend(removed_stracks)
         if len(self.removed_stracks) > 1000:
-            self.removed_stracks = self.removed_stracks[-999:]  # clip remove stracks to 1000 maximum
+            self.removed_stracks = self.removed_stracks[-1000:]  # clip removed stracks to 1000 maximum
 
         return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
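
The `[-999:]` → `[-1000:]` change makes the slice agree with its comment: a negative slice of length k keeps the last k elements, so the old code capped the buffer at 999, not 1000. A quick check (illustrative):

    buf = list(range(1500))
    assert len(buf[-999:]) == 999    # old behavior: one short of the documented cap
    assert len(buf[-1000:]) == 1000  # new behavior: matches "clip to 1000 maximum"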
 
--- a/ultralytics/trackers/track.py
+++ b/ultralytics/trackers/track.py
@@ -16,8 +16,7 @@ TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT}
 
 
 def on_predict_start(predictor: object, persist: bool = False) -> None:
-    """
-    Initialize trackers for object tracking during prediction.
+    """Initialize trackers for object tracking during prediction.
 
     Args:
         predictor (ultralytics.engine.predictor.BasePredictor): The predictor object to initialize trackers for.
@@ -70,8 +69,7 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
 
 
 def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
-    """
-    Postprocess detected boxes and update with object tracking.
+    """Postprocess detected boxes and update with object tracking.
 
     Args:
         predictor (object): The predictor object containing the predictions.
@@ -103,8 +101,7 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
 
 
 def register_tracker(model: object, persist: bool) -> None:
-    """
-    Register tracking callbacks to the model for object tracking during prediction.
+    """Register tracking callbacks to the model for object tracking during prediction.
 
     Args:
         model (object): The model object to register tracking callbacks for.
--- a/ultralytics/trackers/utils/gmc.py
+++ b/ultralytics/trackers/utils/gmc.py
@@ -11,8 +11,7 @@ from ultralytics.utils import LOGGER
 
 
 class GMC:
-    """
-    Generalized Motion Compensation (GMC) class for tracking and object detection in video frames.
+    """Generalized Motion Compensation (GMC) class for tracking and object detection in video frames.
 
     This class provides methods for tracking and detecting objects based on several tracking algorithms including ORB,
     SIFT, ECC, and Sparse Optical Flow. It also supports downscaling of frames for computational efficiency.
@@ -35,24 +34,18 @@ class GMC:
     Examples:
         Create a GMC object and apply it to a frame
         >>> gmc = GMC(method="sparseOptFlow", downscale=2)
-        >>> frame = np.array([[1, 2, 3], [4, 5, 6]])
-        >>> processed_frame = gmc.apply(frame)
-        >>> print(processed_frame)
-        array([[1, 2, 3],
-               [4, 5, 6]])
+        >>> frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
+        >>> warp = gmc.apply(frame)
+        >>> print(warp.shape)
+        (2, 3)
     """
 
     def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None:
-        """
-        Initialize a Generalized Motion Compensation (GMC) object with tracking method and downscale factor.
+        """Initialize a Generalized Motion Compensation (GMC) object with tracking method and downscale factor.
 
         Args:
            method (str): The tracking method to use. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
            downscale (int): Downscale factor for processing frames.
-
-        Examples:
-            Initialize a GMC object with the 'sparseOptFlow' method and a downscale factor of 2
-            >>> gmc = GMC(method="sparseOptFlow", downscale=2)
         """
         super().__init__()
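
The warp returned by `GMC.apply` is a 2×3 affine matrix, per the updated docstring above. A minimal sketch of applying it to box centers with plain NumPy (variable names are illustrative):

    import numpy as np

    H = np.eye(2, 3)  # warp as returned by gmc.apply(frame)
    centers = np.array([[50.0, 80.0], [120.0, 40.0]])  # (x, y) box centers
    homogeneous = np.hstack([centers, np.ones((len(centers), 1))])  # (N, 3)
    compensated = homogeneous @ H.T  # (N, 2) motion-compensated centers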
 
@@ -91,8 +84,7 @@ class GMC:
         self.initializedFirstFrame = False
 
     def apply(self, raw_frame: np.ndarray, detections: list | None = None) -> np.ndarray:
-        """
-        Apply object detection on a raw frame using the specified method.
+        """Estimate a 2×3 motion compensation warp for a frame.
 
         Args:
             raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
@@ -118,8 +110,7 @@ class GMC:
             return np.eye(2, 3)
 
     def apply_ecc(self, raw_frame: np.ndarray) -> np.ndarray:
-        """
-        Apply the ECC (Enhanced Correlation Coefficient) algorithm to a raw frame for motion compensation.
+        """Apply the ECC (Enhanced Correlation Coefficient) algorithm to a raw frame for motion compensation.
 
         Args:
             raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
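
For context, a standalone sketch of the ECC estimation this method wraps, mirroring the `cv2.findTransformECC` call visible in the next hunk (frames and criteria here are illustrative):

    import cv2
    import numpy as np

    prev_gray = np.zeros((240, 320), dtype=np.uint8)  # previous frame, grayscale
    curr_gray = np.zeros((240, 320), dtype=np.uint8)  # current frame, grayscale
    H = np.eye(2, 3, dtype=np.float32)  # initial warp guess
    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 100, 1e-6)
    try:
        _, H = cv2.findTransformECC(prev_gray, curr_gray, H, cv2.MOTION_EUCLIDEAN, criteria, None, 1)
    except cv2.error:
        pass  # keep the identity warp on failure, as the warning path below does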
@@ -153,13 +144,12 @@ class GMC:
         try:
             (_, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
         except Exception as e:
-            LOGGER.warning(f"find transform failed. Set warp as identity {e}")
+            LOGGER.warning(f"findTransformECC failed; using identity warp. {e}")
 
         return H
 
     def apply_features(self, raw_frame: np.ndarray, detections: list | None = None) -> np.ndarray:
-        """
-        Apply feature-based methods like ORB or SIFT to a raw frame.
+        """Apply feature-based methods like ORB or SIFT to a raw frame.
 
         Args:
             raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
@@ -276,8 +266,7 @@ class GMC:
         return H
 
     def apply_sparseoptflow(self, raw_frame: np.ndarray) -> np.ndarray:
-        """
-        Apply Sparse Optical Flow method to a raw frame.
+        """Apply Sparse Optical Flow method to a raw frame.
 
         Args:
             raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
--- a/ultralytics/trackers/utils/kalman_filter.py
+++ b/ultralytics/trackers/utils/kalman_filter.py
@@ -5,11 +5,10 @@ import scipy.linalg
 
 
 class KalmanFilterXYAH:
-    """
-    A KalmanFilterXYAH class for tracking bounding boxes in image space using a Kalman filter.
+    """A KalmanFilterXYAH class for tracking bounding boxes in image space using a Kalman filter.
 
-    Implements a simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space
-    (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect ratio a, height h, and their
+    Implements a simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space (x, y,
+    a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect ratio a, height h, and their
     respective velocities. Object motion follows a constant velocity model, and bounding box location (x, y, a, h) is
     taken as a direct observation of the state space (linear observation model).
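
For readers new to the model: with `ndim, dt = 4, 1.0` (as in the `__init__` hunk below), the constant-velocity motion matrix and linear observation matrix described here can be written as follows (a minimal sketch matching the description, not necessarily the exact source):

    import numpy as np

    ndim, dt = 4, 1.0
    motion_mat = np.eye(2 * ndim)  # 8x8: each position component gains velocity * dt
    for i in range(ndim):
        motion_mat[i, ndim + i] = dt
    update_mat = np.eye(ndim, 2 * ndim)  # 4x8: observe (x, y, a, h) only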
 
@@ -37,17 +36,12 @@ class KalmanFilterXYAH:
     """
 
     def __init__(self):
-        """
-        Initialize Kalman filter model matrices with motion and observation uncertainty weights.
+        """Initialize Kalman filter model matrices with motion and observation uncertainty weights.
 
         The Kalman filter is initialized with an 8-dimensional state space (x, y, a, h, vx, vy, va, vh), where (x, y)
         represents the bounding box center position, 'a' is the aspect ratio, 'h' is the height, and their respective
         velocities are (vx, vy, va, vh). The filter uses a constant velocity model for object motion and a linear
         observation model for bounding box location.
-
-        Examples:
-            Initialize a Kalman filter for tracking:
-            >>> kf = KalmanFilterXYAH()
         """
         ndim, dt = 4, 1.0
 
@@ -62,15 +56,15 @@ class KalmanFilterXYAH:
         self._std_weight_velocity = 1.0 / 160
 
     def initiate(self, measurement: np.ndarray):
-        """
-        Create a track from an unassociated measurement.
+        """Create a track from an unassociated measurement.
 
         Args:
            measurement (np.ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a,
                and height h.
 
         Returns:
-            mean (np.ndarray): Mean vector (8-dimensional) of the new track. Unobserved velocities are initialized to 0 mean.
+            mean (np.ndarray): Mean vector (8-dimensional) of the new track. Unobserved velocities are initialized to 0
+                mean.
             covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.
 
         Examples:
@@ -96,12 +90,12 @@ class KalmanFilterXYAH:
         return mean, covariance
 
     def predict(self, mean: np.ndarray, covariance: np.ndarray):
-        """
-        Run Kalman filter prediction step.
+        """Run Kalman filter prediction step.
 
         Args:
            mean (np.ndarray): The 8-dimensional mean vector of the object state at the previous time step.
-            covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.
+            covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time
+                step.
 
         Returns:
            mean (np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean.
@@ -133,8 +127,7 @@ class KalmanFilterXYAH:
         return mean, covariance
 
     def project(self, mean: np.ndarray, covariance: np.ndarray):
-        """
-        Project state distribution to measurement space.
+        """Project state distribution to measurement space.
 
         Args:
            mean (np.ndarray): The state's mean vector (8 dimensional array).
@@ -163,8 +156,7 @@ class KalmanFilterXYAH:
         return mean, covariance + innovation_cov
 
     def multi_predict(self, mean: np.ndarray, covariance: np.ndarray):
-        """
-        Run Kalman filter prediction step for multiple object states (Vectorized version).
+        """Run Kalman filter prediction step for multiple object states (Vectorized version).
 
         Args:
            mean (np.ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
@@ -175,9 +167,10 @@ class KalmanFilterXYAH:
             covariance (np.ndarray): Covariance matrix of the predicted states with shape (N, 8, 8).
 
         Examples:
+            >>> kf = KalmanFilterXYAH()
             >>> mean = np.random.rand(10, 8)  # 10 object states
             >>> covariance = np.random.rand(10, 8, 8)  # Covariance matrices for 10 object states
-            >>> predicted_mean, predicted_covariance = kalman_filter.multi_predict(mean, covariance)
+            >>> predicted_mean, predicted_covariance = kf.multi_predict(mean, covariance)
         """
         std_pos = [
             self._std_weight_position * mean[:, 3],
@@ -203,8 +196,7 @@ class KalmanFilterXYAH:
         return mean, covariance
 
     def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray):
-        """
-        Run Kalman filter correction step.
+        """Run Kalman filter correction step.
 
         Args:
            mean (np.ndarray): The predicted state's mean vector (8 dimensional).
@@ -243,8 +235,7 @@
         only_position: bool = False,
         metric: str = "maha",
     ) -> np.ndarray:
-        """
-        Compute gating distance between state distribution and measurements.
+        """Compute gating distance between state distribution and measurements.
 
         A suitable distance threshold can be obtained from `chi2inv95`. If `only_position` is False, the chi-square
         distribution has 4 degrees of freedom, otherwise 2.
@@ -252,11 +243,12 @@
         Args:
            mean (np.ndarray): Mean vector over the state distribution (8 dimensional).
            covariance (np.ndarray): Covariance of the state distribution (8x8 dimensional).
-            measurements (np.ndarray): An (N, 4) matrix of N measurements, each in format (x, y, a, h) where (x, y) is the
-                bounding box center position, a the aspect ratio, and h the height.
-            only_position (bool, optional): If True, distance computation is done with respect to box center position only.
-            metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the squared
-                Euclidean distance and 'maha' for the squared Mahalanobis distance.
+            measurements (np.ndarray): An (N, 4) matrix of N measurements, each in format (x, y, a, h) where (x, y) is
+                the bounding box center position, a the aspect ratio, and h the height.
+            only_position (bool, optional): If True, distance computation is done with respect to box center position
+                only.
+            metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the
+                squared Euclidean distance and 'maha' for the squared Mahalanobis distance.
 
         Returns:
            (np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between
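
A standalone sketch of the squared Mahalanobis computation that the 'maha' option above describes (the module imports scipy.linalg; this illustrative version uses a Cholesky factorization and a triangular solve rather than the exact source):

    import numpy as np
    from scipy.linalg import solve_triangular

    def squared_mahalanobis(mean, covariance, measurements):
        # d S^-1 d^T computed stably via S = L L^T
        d = measurements - mean  # (N, 4) residuals
        L = np.linalg.cholesky(covariance)
        z = solve_triangular(L, d.T, lower=True)  # solve L z = d^T
        return np.sum(z * z, axis=0)  # (N,) squared distances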
@@ -287,12 +279,11 @@ class KalmanFilterXYAH:
 
 
 class KalmanFilterXYWH(KalmanFilterXYAH):
-    """
-    A KalmanFilterXYWH class for tracking bounding boxes in image space using a Kalman filter.
+    """A KalmanFilterXYWH class for tracking bounding boxes in image space using a Kalman filter.
 
-    Implements a Kalman filter for tracking bounding boxes with state space (x, y, w, h, vx, vy, vw, vh), where
-    (x, y) is the center position, w is the width, h is the height, and vx, vy, vw, vh are their respective velocities.
-    The object motion follows a constant velocity model, and the bounding box location (x, y, w, h) is taken as a direct
+    Implements a Kalman filter for tracking bounding boxes with state space (x, y, w, h, vx, vy, vw, vh), where (x, y)
+    is the center position, w is the width, h is the height, and vx, vy, vw, vh are their respective velocities. The
+    object motion follows a constant velocity model, and the bounding box location (x, y, w, h) is taken as a direct
     observation of the state space (linear observation model).
 
     Attributes:
@@ -318,14 +309,15 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
     """
 
     def initiate(self, measurement: np.ndarray):
-        """
-        Create track from unassociated measurement.
+        """Create track from unassociated measurement.
 
         Args:
-            measurement (np.ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width, and height.
+            measurement (np.ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width, and
+                height.
 
         Returns:
-            mean (np.ndarray): Mean vector (8 dimensional) of the new track. Unobserved velocities are initialized to 0 mean.
+            mean (np.ndarray): Mean vector (8 dimensional) of the new track. Unobserved velocities are initialized to 0
+                mean.
             covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.
 
         Examples:
@@ -362,12 +354,12 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
         return mean, covariance
 
     def predict(self, mean: np.ndarray, covariance: np.ndarray):
-        """
-        Run Kalman filter prediction step.
+        """Run Kalman filter prediction step.
 
         Args:
            mean (np.ndarray): The 8-dimensional mean vector of the object state at the previous time step.
-            covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.
+            covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time
+                step.
 
         Returns:
            mean (np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean.
@@ -399,8 +391,7 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
         return mean, covariance
 
     def project(self, mean: np.ndarray, covariance: np.ndarray):
-        """
-        Project state distribution to measurement space.
+        """Project state distribution to measurement space.
 
         Args:
            mean (np.ndarray): The state's mean vector (8 dimensional array).
@@ -429,8 +420,7 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
         return mean, covariance + innovation_cov
 
     def multi_predict(self, mean: np.ndarray, covariance: np.ndarray):
-        """
-        Run Kalman filter prediction step (Vectorized version).
+        """Run Kalman filter prediction step (Vectorized version).
 
         Args:
            mean (np.ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
@@ -470,8 +460,7 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
         return mean, covariance
 
     def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray):
-        """
-        Run Kalman filter correction step.
+        """Run Kalman filter correction step.
 
         Args:
            mean (np.ndarray): The predicted state's mean vector (8 dimensional).
--- a/ultralytics/trackers/utils/matching.py
+++ b/ultralytics/trackers/utils/matching.py
@@ -18,8 +18,7 @@ except (ImportError, AssertionError, AttributeError):
 
 
 def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True):
-    """
-    Perform linear assignment using either the scipy or lap.lapjv method.
+    """Perform linear assignment using either the scipy or lap.lapjv method.
 
     Args:
        cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
@@ -27,9 +26,10 @@ def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = Tr
     use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used.
 
     Returns:
-        matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches.
-        unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,).
-        unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M,).
+        matched_indices (list[list[int]] | np.ndarray): Matched indices of shape (K, 2), where K is the number of
+            matches.
+        unmatched_a (np.ndarray): Unmatched indices from the first set, with shape (L,).
+        unmatched_b (np.ndarray): Unmatched indices from the second set, with shape (M,).
 
     Examples:
        >>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
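
A minimal sketch of the scipy fallback path described above, thresholding the globally optimal assignment (data and names are illustrative):

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    cost = np.array([[0.2, 0.9], [0.8, 0.3], [0.95, 0.9]])
    rows, cols = linear_sum_assignment(cost)  # globally optimal pairing
    keep = cost[rows, cols] <= 0.5  # enforce the thresh parameter
    matches = np.stack([rows[keep], cols[keep]], axis=1)  # (K, 2) matched indices
    unmatched_a = np.setdiff1d(np.arange(cost.shape[0]), matches[:, 0])
    unmatched_b = np.setdiff1d(np.arange(cost.shape[1]), matches[:, 1])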
@@ -62,8 +62,7 @@ def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = Tr
 
 
 def iou_distance(atracks: list, btracks: list) -> np.ndarray:
-    """
-    Compute cost based on Intersection over Union (IoU) between tracks.
+    """Compute cost based on Intersection over Union (IoU) between tracks.
 
     Args:
        atracks (list[STrack] | list[np.ndarray]): List of tracks 'a' or bounding boxes.
@@ -78,7 +77,7 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray:
        >>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])]
        >>> cost_matrix = iou_distance(atracks, btracks)
    """
-    if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray):
+    if (atracks and isinstance(atracks[0], np.ndarray)) or (btracks and isinstance(btracks[0], np.ndarray)):
         atlbrs = atracks
         btlbrs = btracks
     else:
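
For reference, a broadcast computation of the pairwise IoU cost this function returns (standalone sketch; the library itself delegates to its own batched IoU helpers):

    import numpy as np

    a = np.array([[0.0, 0.0, 10.0, 10.0]])  # (N, 4) boxes in xyxy format
    b = np.array([[5.0, 5.0, 15.0, 15.0], [25.0, 25.0, 35.0, 35.0]])  # (M, 4)
    tl = np.maximum(a[:, None, :2], b[None, :, :2])  # intersection top-left
    br = np.minimum(a[:, None, 2:], b[None, :, 2:])  # intersection bottom-right
    wh = np.clip(br - tl, 0, None)
    inter = wh[..., 0] * wh[..., 1]
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    cost = 1 - inter / (area_a[:, None] + area_b[None, :] - inter)  # (N, M) IoU distance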
@@ -102,8 +101,7 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray:
 
 
 def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -> np.ndarray:
-    """
-    Compute distance between tracks and detections based on embeddings.
+    """Compute distance between tracks and detections based on embeddings.
 
     Args:
        tracks (list[STrack]): List of tracks, where each track contains embedding features.
@@ -111,8 +109,8 @@ def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -
     metric (str): Metric for distance computation. Supported metrics include 'cosine', 'euclidean', etc.
 
     Returns:
-        (np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks
-            and M is the number of detections.
+        (np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks and M
+            is the number of detections.
 
     Examples:
        Compute the embedding distance between tracks and detections using cosine metric
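
A standalone sketch of the cosine-cost computation performed here (the real function reads features from the track and detection objects; plain arrays are used for illustration):

    import numpy as np
    from scipy.spatial.distance import cdist

    track_feats = np.random.rand(3, 128)  # one smoothed feature per track
    det_feats = np.random.rand(5, 128)  # one feature per detection
    cost = np.maximum(0.0, cdist(track_feats, det_feats, metric="cosine"))  # (3, 5) cost matrix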
@@ -132,8 +130,7 @@ def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -
 
 
 def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
-    """
-    Fuse cost matrix with detection scores to produce a single similarity matrix.
+    """Fuse cost matrix with detection scores to produce a single similarity matrix.
 
     Args:
        cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
@@ -152,6 +149,6 @@ def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
         return cost_matrix
     iou_sim = 1 - cost_matrix
     det_scores = np.array([det.score for det in detections])
-    det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
+    det_scores = det_scores[None].repeat(cost_matrix.shape[0], axis=0)
     fuse_sim = iou_sim * det_scores
     return 1 - fuse_sim  # fuse_cost
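
The `det_scores[None]` rewrite is behavior-preserving: both forms add a leading axis and tile the scores across the track dimension. A quick equivalence check (illustrative):

    import numpy as np

    det_scores = np.array([0.9, 0.6, 0.3])
    a = np.expand_dims(det_scores, axis=0).repeat(4, axis=0)  # old form
    b = det_scores[None].repeat(4, axis=0)  # new form
    assert np.array_equal(a, b)  # both have shape (4, 3)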