ultralytics-opencv-headless 8.3.253__py3-none-any.whl → 8.4.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (62)
  1. tests/__init__.py +2 -2
  2. tests/conftest.py +1 -1
  3. tests/test_cuda.py +8 -2
  4. tests/test_engine.py +6 -6
  5. tests/test_exports.py +10 -3
  6. tests/test_integrations.py +9 -9
  7. tests/test_python.py +14 -14
  8. tests/test_solutions.py +3 -3
  9. ultralytics/__init__.py +1 -1
  10. ultralytics/cfg/__init__.py +6 -6
  11. ultralytics/cfg/default.yaml +3 -1
  12. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  13. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  14. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  15. ultralytics/cfg/models/26/yolo26-p6.yaml +60 -0
  16. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  17. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  18. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  19. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  20. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  21. ultralytics/data/augment.py +7 -0
  22. ultralytics/data/dataset.py +1 -1
  23. ultralytics/engine/exporter.py +10 -3
  24. ultralytics/engine/model.py +1 -1
  25. ultralytics/engine/trainer.py +40 -15
  26. ultralytics/engine/tuner.py +15 -7
  27. ultralytics/models/fastsam/predict.py +1 -1
  28. ultralytics/models/yolo/detect/train.py +3 -2
  29. ultralytics/models/yolo/detect/val.py +6 -0
  30. ultralytics/models/yolo/model.py +1 -1
  31. ultralytics/models/yolo/obb/predict.py +1 -1
  32. ultralytics/models/yolo/obb/train.py +1 -1
  33. ultralytics/models/yolo/pose/train.py +1 -1
  34. ultralytics/models/yolo/segment/predict.py +1 -1
  35. ultralytics/models/yolo/segment/train.py +1 -1
  36. ultralytics/models/yolo/segment/val.py +3 -1
  37. ultralytics/models/yolo/yoloe/train.py +6 -1
  38. ultralytics/models/yolo/yoloe/train_seg.py +6 -1
  39. ultralytics/nn/autobackend.py +7 -3
  40. ultralytics/nn/modules/__init__.py +8 -0
  41. ultralytics/nn/modules/block.py +127 -8
  42. ultralytics/nn/modules/head.py +818 -205
  43. ultralytics/nn/tasks.py +74 -29
  44. ultralytics/nn/text_model.py +5 -2
  45. ultralytics/optim/__init__.py +5 -0
  46. ultralytics/optim/muon.py +338 -0
  47. ultralytics/utils/benchmarks.py +1 -0
  48. ultralytics/utils/callbacks/platform.py +9 -7
  49. ultralytics/utils/downloads.py +3 -1
  50. ultralytics/utils/export/engine.py +19 -10
  51. ultralytics/utils/export/imx.py +22 -11
  52. ultralytics/utils/export/tensorflow.py +1 -41
  53. ultralytics/utils/loss.py +584 -203
  54. ultralytics/utils/metrics.py +1 -0
  55. ultralytics/utils/ops.py +11 -2
  56. ultralytics/utils/tal.py +98 -19
  57. {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/METADATA +31 -39
  58. {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/RECORD +62 -51
  59. {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/WHEEL +0 -0
  60. {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/entry_points.txt +0 -0
  61. {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/licenses/LICENSE +0 -0
  62. {ultralytics_opencv_headless-8.3.253.dist-info → ultralytics_opencv_headless-8.4.0.dist-info}/top_level.txt +0 -0
@@ -15,7 +15,7 @@ from ultralytics.utils import NOT_MACOS14
15
15
  from ultralytics.utils.tal import dist2bbox, dist2rbox, make_anchors
16
16
  from ultralytics.utils.torch_utils import TORCH_1_11, fuse_conv_and_bn, smart_inference_mode
17
17
 
18
- from .block import DFL, SAVPE, BNContrastiveHead, ContrastiveHead, Proto, Residual, SwiGLUFFN
18
+ from .block import DFL, SAVPE, BNContrastiveHead, ContrastiveHead, Proto, Proto26, RealNVP, Residual, SwiGLUFFN
19
19
  from .conv import Conv, DWConv
20
20
  from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
21
21
  from .utils import bias_init_with_prob, linear_init
@@ -68,7 +68,6 @@ class Detect(nn.Module):
68
68
  dynamic = False # force grid reconstruction
69
69
  export = False # export mode
70
70
  format = None # export format
71
- end2end = False # end2end
72
71
  max_det = 300 # max_det
73
72
  shape = None
74
73
  anchors = torch.empty(0) # init
@@ -76,17 +75,19 @@ class Detect(nn.Module):
76
75
  legacy = False # backward compatibility for v3/v5/v8/v9 models
77
76
  xyxy = False # xyxy or xywh output
78
77
 
79
- def __init__(self, nc: int = 80, ch: tuple = ()):
78
+ def __init__(self, nc: int = 80, reg_max=16, end2end=False, ch: tuple = ()):
80
79
  """Initialize the YOLO detection layer with specified number of classes and channels.
81
80
 
82
81
  Args:
83
82
  nc (int): Number of classes.
83
+ reg_max (int): Maximum number of DFL channels.
84
+ end2end (bool): Whether to use end-to-end NMS-free detection.
84
85
  ch (tuple): Tuple of channel sizes from backbone feature maps.
85
86
  """
86
87
  super().__init__()
87
88
  self.nc = nc # number of classes
88
89
  self.nl = len(ch) # number of detection layers
89
- self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
90
+ self.reg_max = reg_max # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
90
91
  self.no = nc + self.reg_max * 4 # number of outputs per anchor
91
92
  self.stride = torch.zeros(self.nl) # strides computed during build
92
93
  c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels
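The headline change in the hunk above is that reg_max and end2end become constructor arguments of Detect rather than class attributes, and end2end=True builds the extra one2one_cv2/one2one_cv3 branches at init time. A minimal construction sketch, assuming the installed 8.4.0 wheel matches this diff (channel sizes follow the docstring examples later in the file):

    from ultralytics.nn.modules.head import Detect

    # end2end=True deep-copies cv2/cv3 into the NMS-free one2one branches
    head = Detect(nc=80, reg_max=16, end2end=True, ch=(256, 512, 1024))
    print(head.end2end, head.reg_max, head.no)  # True 16 144 (no = nc + 4 * reg_max)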
@@ -107,77 +108,98 @@ class Detect(nn.Module):
107
108
  )
108
109
  self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
109
110
 
110
- if self.end2end:
111
+ if end2end:
111
112
  self.one2one_cv2 = copy.deepcopy(self.cv2)
112
113
  self.one2one_cv3 = copy.deepcopy(self.cv3)
113
114
 
114
- def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor] | tuple:
115
- """Concatenate and return predicted bounding boxes and class probabilities."""
115
+ @property
116
+ def one2many(self):
117
+ """Returns the one-to-many head components, here for v5/v5/v8/v9/11 backward compatibility."""
118
+ return dict(box_head=self.cv2, cls_head=self.cv3)
119
+
120
+ @property
121
+ def one2one(self):
122
+ """Returns the one-to-one head components."""
123
+ return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3)
124
+
125
+ @property
126
+ def end2end(self):
127
+ """Checks if the model has one2one for v5/v5/v8/v9/11 backward compatibility."""
128
+ return hasattr(self, "one2one")
129
+
130
+ def forward_head(
131
+ self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
132
+ ) -> dict[str, torch.Tensor]:
133
+ """Concatenates and returns predicted bounding boxes and class probabilities."""
134
+ if box_head is None or cls_head is None: # for fused inference
135
+ return dict()
136
+ bs = x[0].shape[0] # batch size
137
+ boxes = torch.cat([box_head[i](x[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)], dim=-1)
138
+ scores = torch.cat([cls_head[i](x[i]).view(bs, self.nc, -1) for i in range(self.nl)], dim=-1)
139
+ return dict(boxes=boxes, scores=scores, feats=x)
140
+
141
+ def forward(
142
+ self, x: list[torch.Tensor]
143
+ ) -> dict[str, torch.Tensor] | torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
144
+ """Concatenates and returns predicted bounding boxes and class probabilities."""
145
+ preds = self.forward_head(x, **self.one2many)
116
146
  if self.end2end:
117
- return self.forward_end2end(x)
118
-
119
- for i in range(self.nl):
120
- x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
121
- if self.training: # Training path
122
- return x
123
- y = self._inference(x)
124
- return y if self.export else (y, x)
125
-
126
- def forward_end2end(self, x: list[torch.Tensor]) -> dict | tuple:
127
- """Perform forward pass of the v10Detect module.
128
-
129
- Args:
130
- x (list[torch.Tensor]): Input feature maps from different levels.
131
-
132
- Returns:
133
- outputs (dict | tuple): Training mode returns dict with one2many and one2one outputs. Inference mode returns
134
- processed detections or tuple with detections and raw outputs.
135
- """
136
- x_detach = [xi.detach() for xi in x]
137
- one2one = [
138
- torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
139
- ]
140
- for i in range(self.nl):
141
- x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
142
- if self.training: # Training path
143
- return {"one2many": x, "one2one": one2one}
144
-
145
- y = self._inference(one2one)
146
- y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
147
- return y if self.export else (y, {"one2many": x, "one2one": one2one})
147
+ x_detach = [xi.detach() for xi in x]
148
+ one2one = self.forward_head(x_detach, **self.one2one)
149
+ preds = {"one2many": preds, "one2one": one2one}
150
+ if self.training:
151
+ return preds
152
+ y = self._inference(preds["one2one"] if self.end2end else preds)
153
+ if self.end2end:
154
+ y = self.postprocess(y.permute(0, 2, 1))
155
+ return y if self.export else (y, preds)
148
156
 
149
- def _inference(self, x: list[torch.Tensor]) -> torch.Tensor:
157
+ def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
150
158
  """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.
151
159
 
152
160
  Args:
153
- x (list[torch.Tensor]): List of feature maps from different detection layers.
161
+ x (dict[str, torch.Tensor]): List of feature maps from different detection layers.
154
162
 
155
163
  Returns:
156
164
  (torch.Tensor): Concatenated tensor of decoded bounding boxes and class probabilities.
157
165
  """
158
166
  # Inference path
159
- shape = x[0].shape # BCHW
160
- x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
161
- if self.dynamic or self.shape != shape:
162
- self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
167
+ dbox = self._get_decode_boxes(x)
168
+ return torch.cat((dbox, x["scores"].sigmoid()), 1)
169
+
170
+ def _get_decode_boxes(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
171
+ """Get decoded boxes based on anchors and strides."""
172
+ shape = x["feats"][0].shape # BCHW
173
+ if self.format != "imx" and (self.dynamic or self.shape != shape):
174
+ self.anchors, self.strides = (a.transpose(0, 1) for a in make_anchors(x["feats"], self.stride, 0.5))
163
175
  self.shape = shape
164
176
 
165
- box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
166
- dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
167
- return torch.cat((dbox, cls.sigmoid()), 1)
177
+ boxes = x["boxes"]
178
+ if self.export and self.format in {"tflite", "edgetpu"}:
179
+ # Precompute normalization factor to increase numerical stability
180
+ # See https://github.com/ultralytics/ultralytics/issues/7371
181
+ grid_h = shape[2]
182
+ grid_w = shape[3]
183
+ grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=boxes.device).reshape(1, 4, 1)
184
+ norm = self.strides / (self.stride[0] * grid_size)
185
+ dbox = self.decode_bboxes(self.dfl(boxes) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
186
+ else:
187
+ dbox = self.decode_bboxes(self.dfl(boxes), self.anchors.unsqueeze(0)) * self.strides
188
+ return dbox
168
189
 
169
190
  def bias_init(self):
170
191
  """Initialize Detect() biases, WARNING: requires stride availability."""
171
- m = self # self.model[-1] # Detect() module
172
- # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
173
- # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
174
- for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
175
- a[-1].bias.data[:] = 1.0 # box
176
- b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
192
+ for i, (a, b) in enumerate(zip(self.one2many["box_head"], self.one2many["cls_head"])): # from
193
+ a[-1].bias.data[:] = 2.0 # box
194
+ b[-1].bias.data[: self.nc] = math.log(
195
+ 5 / self.nc / (640 / self.stride[i]) ** 2
196
+ ) # cls (.01 objects, 80 classes, 640 img)
177
197
  if self.end2end:
178
- for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride): # from
179
- a[-1].bias.data[:] = 1.0 # box
180
- b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
198
+ for i, (a, b) in enumerate(zip(self.one2one["box_head"], self.one2one["cls_head"])): # from
199
+ a[-1].bias.data[:] = 2.0 # box
200
+ b[-1].bias.data[: self.nc] = math.log(
201
+ 5 / self.nc / (640 / self.stride[i]) ** 2
202
+ ) # cls (.01 objects, 80 classes, 640 img)
181
203
 
182
204
  def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor, xywh: bool = True) -> torch.Tensor:
183
205
  """Decode bounding boxes from predictions."""
@@ -188,28 +210,45 @@ class Detect(nn.Module):
188
210
  dim=1,
189
211
  )
190
212
 
191
- @staticmethod
192
- def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80) -> torch.Tensor:
193
- """Post-process YOLO model predictions.
213
+ def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
214
+ """Post-processes YOLO model predictions.
194
215
 
195
216
  Args:
196
217
  preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
197
218
  format [x, y, w, h, class_probs].
198
- max_det (int): Maximum detections per image.
199
- nc (int, optional): Number of classes.
200
219
 
201
220
  Returns:
202
221
  (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
203
222
  dimension format [x, y, w, h, max_class_prob, class_index].
204
223
  """
205
- batch_size, anchors, _ = preds.shape # i.e. shape(16,8400,84)
206
- boxes, scores = preds.split([4, nc], dim=-1)
207
- index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
208
- boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
209
- scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
210
- scores, index = scores.flatten(1).topk(min(max_det, anchors))
211
- i = torch.arange(batch_size)[..., None] # batch indices
212
- return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)
224
+ boxes, scores = preds.split([4, self.nc], dim=-1)
225
+ scores, conf, idx = self.get_topk_index(scores, self.max_det)
226
+ boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
227
+ return torch.cat([boxes, scores, conf], dim=-1)
228
+
229
+ def get_topk_index(self, scores: torch.Tensor, max_det: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
230
+ """Get top-k indices from scores.
231
+
232
+ Args:
233
+ scores (torch.Tensor): Scores tensor with shape (batch_size, num_anchors, num_classes).
234
+ max_det (int): Maximum detections per image.
235
+
236
+ Returns:
237
+ (torch.Tensor, torch.Tensor, torch.Tensor): Top scores, class indices, and filtered indices.
238
+ """
239
+ batch_size, anchors, nc = scores.shape # i.e. shape(16,8400,84)
240
+ # Use max_det directly during export for TensorRT compatibility (requires k to be constant),
241
+ # otherwise use min(max_det, anchors) for safety with small inputs during Python inference
242
+ k = max_det if self.export else min(max_det, anchors)
243
+ ori_index = scores.max(dim=-1)[0].topk(k)[1].unsqueeze(-1)
244
+ scores = scores.gather(dim=1, index=ori_index.repeat(1, 1, nc))
245
+ scores, index = scores.flatten(1).topk(k)
246
+ idx = ori_index[torch.arange(batch_size)[..., None], index // nc] # original index
247
+ return scores[..., None], (index % nc)[..., None].float(), idx
248
+
249
+ def fuse(self) -> None:
250
+ """Remove the one2many head for inference optimization."""
251
+ self.cv2 = self.cv3 = None
213
252
 
214
253
 
215
254
  class Segment(Detect):
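postprocess above is now an instance method (it reads self.nc and self.max_det) built on get_topk_index, and k is pinned to max_det during export because TensorRT requires a constant top-k. A self-contained toy re-implementation of the same two-stage selection, first the best-class score per anchor, then a flat top-k over the surviving anchor/class pairs (shapes are illustrative):

    import torch

    def topk_select(preds: torch.Tensor, nc: int, max_det: int) -> torch.Tensor:
        # preds: (B, A, 4 + nc) in [x, y, w, h, class_probs] layout
        boxes, scores = preds.split([4, nc], dim=-1)
        bs, anchors, _ = scores.shape
        k = min(max_det, anchors)
        idx = scores.max(dim=-1)[0].topk(k)[1].unsqueeze(-1)  # top-k anchors by best class score
        scores = scores.gather(1, idx.repeat(1, 1, nc))       # (B, k, nc)
        scores, flat = scores.flatten(1).topk(k)              # flat top-k over anchor*class pairs
        keep = idx[torch.arange(bs)[..., None], flat // nc]   # map back to original anchor indices
        boxes = boxes.gather(1, keep.repeat(1, 1, 4))
        return torch.cat([boxes, scores[..., None], (flat % nc)[..., None].float()], dim=-1)

    out = topk_select(torch.randn(2, 8400, 84), nc=80, max_det=300)
    print(out.shape)  # torch.Size([2, 300, 6]) -> [x, y, w, h, score, class_index]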
@@ -233,33 +272,146 @@ class Segment(Detect):
233
272
  >>> outputs = segment(x)
234
273
  """
235
274
 
236
- def __init__(self, nc: int = 80, nm: int = 32, npr: int = 256, ch: tuple = ()):
275
+ def __init__(self, nc: int = 80, nm: int = 32, npr: int = 256, reg_max=16, end2end=False, ch: tuple = ()):
237
276
  """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.
238
277
 
239
278
  Args:
240
279
  nc (int): Number of classes.
241
280
  nm (int): Number of masks.
242
281
  npr (int): Number of protos.
282
+ reg_max (int): Maximum number of DFL channels.
283
+ end2end (bool): Whether to use end-to-end NMS-free detection.
243
284
  ch (tuple): Tuple of channel sizes from backbone feature maps.
244
285
  """
245
- super().__init__(nc, ch)
286
+ super().__init__(nc, reg_max, end2end, ch)
246
287
  self.nm = nm # number of masks
247
288
  self.npr = npr # number of protos
248
289
  self.proto = Proto(ch[0], self.npr, self.nm) # protos
249
290
 
250
291
  c4 = max(ch[0] // 4, self.nm)
251
292
  self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
293
+ if end2end:
294
+ self.one2one_cv4 = copy.deepcopy(self.cv4)
295
+
296
+ @property
297
+ def one2many(self):
298
+ """Returns the one-to-many head components, here for backward compatibility."""
299
+ return dict(box_head=self.cv2, cls_head=self.cv3, mask_head=self.cv4)
300
+
301
+ @property
302
+ def one2one(self):
303
+ """Returns the one-to-one head components."""
304
+ return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3, mask_head=self.one2one_cv4)
252
305
 
253
- def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor]:
306
+ def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor] | dict[str, torch.Tensor]:
254
307
  """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
255
- p = self.proto(x[0]) # mask protos
256
- bs = p.shape[0] # batch size
308
+ outputs = super().forward(x)
309
+ preds = outputs[1] if isinstance(outputs, tuple) else outputs
310
+ proto = self.proto(x[0]) # mask protos
311
+ if isinstance(preds, dict): # training and validating during training
312
+ if self.end2end:
313
+ preds["one2many"]["proto"] = proto
314
+ preds["one2one"]["proto"] = proto.detach()
315
+ else:
316
+ preds["proto"] = proto
317
+ if self.training:
318
+ return preds
319
+ return (outputs, proto) if self.export else ((outputs[0], proto), preds)
320
+
321
+ def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
322
+ """Decode predicted bounding boxes and class probabilities, concatenated with mask coefficients."""
323
+ preds = super()._inference(x)
324
+ return torch.cat([preds, x["mask_coefficient"]], dim=1)
325
+
326
+ def forward_head(
327
+ self, x: list[torch.Tensor], box_head: torch.nn.Module, cls_head: torch.nn.Module, mask_head: torch.nn.Module
328
+ ) -> torch.Tensor:
329
+ """Concatenates and returns predicted bounding boxes, class probabilities, and mask coefficients."""
330
+ preds = super().forward_head(x, box_head, cls_head)
331
+ if mask_head is not None:
332
+ bs = x[0].shape[0] # batch size
333
+ preds["mask_coefficient"] = torch.cat([mask_head[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)
334
+ return preds
335
+
336
+ def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
337
+ """Post-process YOLO model predictions.
338
+
339
+ Args:
340
+ preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc + nm) with last dimension
341
+ format [x, y, w, h, class_probs, mask_coefficient].
257
342
 
258
- mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
259
- x = Detect.forward(self, x)
343
+ Returns:
344
+ (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6 + nm) and last
345
+ dimension format [x, y, w, h, max_class_prob, class_index, mask_coefficient].
346
+ """
347
+ boxes, scores, mask_coefficient = preds.split([4, self.nc, self.nm], dim=-1)
348
+ scores, conf, idx = self.get_topk_index(scores, self.max_det)
349
+ boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
350
+ mask_coefficient = mask_coefficient.gather(dim=1, index=idx.repeat(1, 1, self.nm))
351
+ return torch.cat([boxes, scores, conf, mask_coefficient], dim=-1)
352
+
353
+ def fuse(self) -> None:
354
+ """Remove the one2many head for inference optimization."""
355
+ self.cv2 = self.cv3 = self.cv4 = None
356
+
357
+
358
+ class Segment26(Segment):
359
+ """YOLO26 Segment head for segmentation models.
360
+
361
+ This class extends the Detect head to include mask prediction capabilities for instance segmentation tasks.
362
+
363
+ Attributes:
364
+ nm (int): Number of masks.
365
+ npr (int): Number of protos.
366
+ proto (Proto): Prototype generation module.
367
+ cv4 (nn.ModuleList): Convolution layers for mask coefficients.
368
+
369
+ Methods:
370
+ forward: Return model outputs and mask coefficients.
371
+
372
+ Examples:
373
+ Create a segmentation head
374
+ >>> segment = Segment26(nc=80, nm=32, npr=256, ch=(256, 512, 1024))
375
+ >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
376
+ >>> outputs = segment(x)
377
+ """
378
+
379
+ def __init__(self, nc: int = 80, nm: int = 32, npr: int = 256, reg_max=16, end2end=False, ch: tuple = ()):
380
+ """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.
381
+
382
+ Args:
383
+ nc (int): Number of classes.
384
+ nm (int): Number of masks.
385
+ npr (int): Number of protos.
386
+ reg_max (int): Maximum number of DFL channels.
387
+ end2end (bool): Whether to use end-to-end NMS-free detection.
388
+ ch (tuple): Tuple of channel sizes from backbone feature maps.
389
+ """
390
+ super().__init__(nc, nm, npr, reg_max, end2end, ch)
391
+ self.proto = Proto26(ch, self.npr, self.nm, nc) # protos
392
+
393
+ def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor] | dict[str, torch.Tensor]:
394
+ """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
395
+ outputs = Detect.forward(self, x)
396
+ preds = outputs[1] if isinstance(outputs, tuple) else outputs
397
+ proto = self.proto(x) # mask protos
398
+ if isinstance(preds, dict): # training and validating during training
399
+ if self.end2end:
400
+ preds["one2many"]["proto"] = proto
401
+ preds["one2one"]["proto"] = (
402
+ tuple(p.detach() for p in proto) if isinstance(proto, tuple) else proto.detach()
403
+ )
404
+ else:
405
+ preds["proto"] = proto
260
406
  if self.training:
261
- return x, mc, p
262
- return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
407
+ return preds
408
+ return (outputs, proto) if self.export else ((outputs[0], proto), preds)
409
+
410
+ def fuse(self) -> None:
411
+ """Remove the one2many head and extra part of proto module for inference optimization."""
412
+ super().fuse()
413
+ if hasattr(self.proto, "fuse"):
414
+ self.proto.fuse()
263
415
 
264
416
 
265
417
  class OBB(Detect):
@@ -283,38 +435,114 @@ class OBB(Detect):
283
435
  >>> outputs = obb(x)
284
436
  """
285
437
 
286
- def __init__(self, nc: int = 80, ne: int = 1, ch: tuple = ()):
438
+ def __init__(self, nc: int = 80, ne: int = 1, reg_max=16, end2end=False, ch: tuple = ()):
287
439
  """Initialize OBB with number of classes `nc` and layer channels `ch`.
288
440
 
289
441
  Args:
290
442
  nc (int): Number of classes.
291
443
  ne (int): Number of extra parameters.
444
+ reg_max (int): Maximum number of DFL channels.
445
+ end2end (bool): Whether to use end-to-end NMS-free detection.
292
446
  ch (tuple): Tuple of channel sizes from backbone feature maps.
293
447
  """
294
- super().__init__(nc, ch)
448
+ super().__init__(nc, reg_max, end2end, ch)
295
449
  self.ne = ne # number of extra parameters
296
450
 
297
451
  c4 = max(ch[0] // 4, self.ne)
298
452
  self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)
299
-
300
- def forward(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
301
- """Concatenate and return predicted bounding boxes and class probabilities."""
302
- bs = x[0].shape[0] # batch size
303
- angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2) # OBB theta logits
304
- # NOTE: set `angle` as an attribute so that `decode_bboxes` could use it.
305
- angle = (angle.sigmoid() - 0.25) * math.pi # [-pi/4, 3pi/4]
306
- # angle = angle.sigmoid() * math.pi / 2 # [0, pi/2]
307
- if not self.training:
308
- self.angle = angle
309
- x = Detect.forward(self, x)
310
- if self.training:
311
- return x, angle
312
- return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))
453
+ if end2end:
454
+ self.one2one_cv4 = copy.deepcopy(self.cv4)
455
+
456
+ @property
457
+ def one2many(self):
458
+ """Returns the one-to-many head components, here for backward compatibility."""
459
+ return dict(box_head=self.cv2, cls_head=self.cv3, angle_head=self.cv4)
460
+
461
+ @property
462
+ def one2one(self):
463
+ """Returns the one-to-one head components."""
464
+ return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3, angle_head=self.one2one_cv4)
465
+
466
+ def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
467
+ """Decode predicted bounding boxes and class probabilities, concatenated with rotation angles."""
468
+ # For decode_bboxes convenience
469
+ self.angle = x["angle"] # TODO: need to test obb
470
+ preds = super()._inference(x)
471
+ return torch.cat([preds, x["angle"]], dim=1)
472
+
473
+ def forward_head(
474
+ self, x: list[torch.Tensor], box_head: torch.nn.Module, cls_head: torch.nn.Module, angle_head: torch.nn.Module
475
+ ) -> torch.Tensor:
476
+ """Concatenates and returns predicted bounding boxes, class probabilities, and angles."""
477
+ preds = super().forward_head(x, box_head, cls_head)
478
+ if angle_head is not None:
479
+ bs = x[0].shape[0] # batch size
480
+ angle = torch.cat(
481
+ [angle_head[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2
482
+ ) # OBB theta logits
483
+ angle = (angle.sigmoid() - 0.25) * math.pi # [-pi/4, 3pi/4]
484
+ preds["angle"] = angle
485
+ return preds
313
486
 
314
487
  def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor) -> torch.Tensor:
315
488
  """Decode rotated bounding boxes."""
316
489
  return dist2rbox(bboxes, self.angle, anchors, dim=1)
317
490
 
491
+ def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
492
+ """Post-process YOLO model predictions.
493
+
494
+ Args:
495
+ preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc + ne) with last dimension
496
+ format [x, y, w, h, class_probs, angle].
497
+
498
+ Returns:
499
+ (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 7) and last
500
+ dimension format [x, y, w, h, max_class_prob, class_index, angle].
501
+ """
502
+ boxes, scores, angle = preds.split([4, self.nc, self.ne], dim=-1)
503
+ scores, conf, idx = self.get_topk_index(scores, self.max_det)
504
+ boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
505
+ angle = angle.gather(dim=1, index=idx.repeat(1, 1, self.ne))
506
+ return torch.cat([boxes, scores, conf, angle], dim=-1)
507
+
508
+ def fuse(self) -> None:
509
+ """Remove the one2many head for inference optimization."""
510
+ self.cv2 = self.cv3 = self.cv4 = None
511
+
512
+
513
+ class OBB26(OBB):
514
+ """YOLO26 OBB detection head for detection with rotation models. This class extends the OBB head with modified angle
515
+ processing that outputs raw angle predictions without sigmoid transformation, compared to the original
516
+ OBB class.
517
+
518
+ Attributes:
519
+ ne (int): Number of extra parameters.
520
+ cv4 (nn.ModuleList): Convolution layers for angle prediction.
521
+ angle (torch.Tensor): Predicted rotation angles.
522
+
523
+ Methods:
524
+ forward_head: Concatenate and return predicted bounding boxes, class probabilities, and raw angles.
525
+
526
+ Examples:
527
+ Create an OBB26 detection head
528
+ >>> obb26 = OBB26(nc=80, ne=1, ch=(256, 512, 1024))
529
+ >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
530
+ >>> outputs = obb26(x).
531
+ """
532
+
533
+ def forward_head(
534
+ self, x: list[torch.Tensor], box_head: torch.nn.Module, cls_head: torch.nn.Module, angle_head: torch.nn.Module
535
+ ) -> torch.Tensor:
536
+ """Concatenates and returns predicted bounding boxes, class probabilities, and raw angles."""
537
+ preds = Detect.forward_head(self, x, box_head, cls_head)
538
+ if angle_head is not None:
539
+ bs = x[0].shape[0] # batch size
540
+ angle = torch.cat(
541
+ [angle_head[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2
542
+ ) # OBB theta logits (raw output without sigmoid transformation)
543
+ preds["angle"] = angle
544
+ return preds
545
+
318
546
 
319
547
  class Pose(Detect):
320
548
  """YOLO Pose head for keypoints models.
@@ -337,38 +565,85 @@ class Pose(Detect):
337
565
  >>> outputs = pose(x)
338
566
  """
339
567
 
340
- def __init__(self, nc: int = 80, kpt_shape: tuple = (17, 3), ch: tuple = ()):
568
+ def __init__(self, nc: int = 80, kpt_shape: tuple = (17, 3), reg_max=16, end2end=False, ch: tuple = ()):
341
569
  """Initialize YOLO network with default parameters and Convolutional Layers.
342
570
 
343
571
  Args:
344
572
  nc (int): Number of classes.
345
573
  kpt_shape (tuple): Number of keypoints, number of dims (2 for x,y or 3 for x,y,visible).
574
+ reg_max (int): Maximum number of DFL channels.
575
+ end2end (bool): Whether to use end-to-end NMS-free detection.
346
576
  ch (tuple): Tuple of channel sizes from backbone feature maps.
347
577
  """
348
- super().__init__(nc, ch)
578
+ super().__init__(nc, reg_max, end2end, ch)
349
579
  self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
350
580
  self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total
351
581
 
352
582
  c4 = max(ch[0] // 4, self.nk)
353
583
  self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)
584
+ if end2end:
585
+ self.one2one_cv4 = copy.deepcopy(self.cv4)
586
+
587
+ @property
588
+ def one2many(self):
589
+ """Returns the one-to-many head components, here for backward compatibility."""
590
+ return dict(box_head=self.cv2, cls_head=self.cv3, pose_head=self.cv4)
591
+
592
+ @property
593
+ def one2one(self):
594
+ """Returns the one-to-one head components."""
595
+ return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3, pose_head=self.one2one_cv4)
596
+
597
+ def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
598
+ """Decode predicted bounding boxes and class probabilities, concatenated with keypoints."""
599
+ preds = super()._inference(x)
600
+ return torch.cat([preds, self.kpts_decode(x["kpts"])], dim=1)
601
+
602
+ def forward_head(
603
+ self, x: list[torch.Tensor], box_head: torch.nn.Module, cls_head: torch.nn.Module, pose_head: torch.nn.Module
604
+ ) -> torch.Tensor:
605
+ """Concatenates and returns predicted bounding boxes, class probabilities, and keypoints."""
606
+ preds = super().forward_head(x, box_head, cls_head)
607
+ if pose_head is not None:
608
+ bs = x[0].shape[0] # batch size
609
+ preds["kpts"] = torch.cat([pose_head[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], 2)
610
+ return preds
611
+
612
+ def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
613
+ """Post-process YOLO model predictions.
354
614
 
355
- def forward(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
356
- """Perform forward pass through YOLO model and return predictions."""
357
- bs = x[0].shape[0] # batch size
358
- kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
359
- x = Detect.forward(self, x)
360
- if self.training:
361
- return x, kpt
362
- pred_kpt = self.kpts_decode(bs, kpt)
363
- return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))
615
+ Args:
616
+ preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc + nk) with last dimension
617
+ format [x, y, w, h, class_probs, keypoints].
364
618
 
365
- def kpts_decode(self, bs: int, kpts: torch.Tensor) -> torch.Tensor:
619
+ Returns:
620
+ (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6 + self.nk) and
621
+ last dimension format [x, y, w, h, max_class_prob, class_index, keypoints].
622
+ """
623
+ boxes, scores, kpts = preds.split([4, self.nc, self.nk], dim=-1)
624
+ scores, conf, idx = self.get_topk_index(scores, self.max_det)
625
+ boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
626
+ kpts = kpts.gather(dim=1, index=idx.repeat(1, 1, self.nk))
627
+ return torch.cat([boxes, scores, conf, kpts], dim=-1)
628
+
629
+ def fuse(self) -> None:
630
+ """Remove the one2many head for inference optimization."""
631
+ self.cv2 = self.cv3 = self.cv4 = None
632
+
633
+ def kpts_decode(self, kpts: torch.Tensor) -> torch.Tensor:
366
634
  """Decode keypoints from predictions."""
367
635
  ndim = self.kpt_shape[1]
636
+ bs = kpts.shape[0]
368
637
  if self.export:
369
- # NCNN fix
370
638
  y = kpts.view(bs, *self.kpt_shape, -1)
371
- a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
639
+ if self.format in {"tflite", "edgetpu"}:
640
+ # Precompute normalization factor to increase numerical stability
641
+ grid_h, grid_w = self.shape[2], self.shape[3]
642
+ grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)
643
+ norm = self.strides / (self.stride[0] * grid_size)
644
+ a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm
645
+ else:
646
+ a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
372
647
  if ndim == 3:
373
648
  a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
374
649
  return a.view(bs, self.nk, -1)
@@ -384,6 +659,134 @@ class Pose(Detect):
384
659
  return y
385
660
 
386
661
 
662
+ class Pose26(Pose):
663
+ """YOLO26 Pose head for keypoints models.
664
+
665
+ This class extends the Detect head to include keypoint prediction capabilities for pose estimation tasks.
666
+
667
+ Attributes:
668
+ kpt_shape (tuple): Number of keypoints and dimensions (2 for x,y or 3 for x,y,visible).
669
+ nk (int): Total number of keypoint values.
670
+ cv4 (nn.ModuleList): Convolution layers for keypoint prediction.
671
+
672
+ Methods:
673
+ forward: Perform forward pass through YOLO model and return predictions.
674
+ kpts_decode: Decode keypoints from predictions.
675
+
676
+ Examples:
677
+ Create a pose detection head
678
+ >>> pose = Pose(nc=80, kpt_shape=(17, 3), ch=(256, 512, 1024))
679
+ >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
680
+ >>> outputs = pose(x)
681
+ """
682
+
683
+ def __init__(self, nc: int = 80, kpt_shape: tuple = (17, 3), reg_max=16, end2end=False, ch: tuple = ()):
684
+ """Initialize YOLO network with default parameters and Convolutional Layers.
685
+
686
+ Args:
687
+ nc (int): Number of classes.
688
+ kpt_shape (tuple): Number of keypoints, number of dims (2 for x,y or 3 for x,y,visible).
689
+ reg_max (int): Maximum number of DFL channels.
690
+ end2end (bool): Whether to use end-to-end NMS-free detection.
691
+ ch (tuple): Tuple of channel sizes from backbone feature maps.
692
+ """
693
+ super().__init__(nc, kpt_shape, reg_max, end2end, ch)
694
+ self.flow_model = RealNVP()
695
+
696
+ c4 = max(ch[0] // 4, kpt_shape[0] * (kpt_shape[1] + 2))
697
+ self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3)) for x in ch)
698
+
699
+ self.cv4_kpts = nn.ModuleList(nn.Conv2d(c4, self.nk, 1) for _ in ch)
700
+ self.nk_sigma = kpt_shape[0] * 2 # sigma_x, sigma_y for each keypoint
701
+ self.cv4_sigma = nn.ModuleList(nn.Conv2d(c4, self.nk_sigma, 1) for _ in ch)
702
+
703
+ if end2end:
704
+ self.one2one_cv4 = copy.deepcopy(self.cv4)
705
+ self.one2one_cv4_kpts = copy.deepcopy(self.cv4_kpts)
706
+ self.one2one_cv4_sigma = copy.deepcopy(self.cv4_sigma)
707
+
708
+ @property
709
+ def one2many(self):
710
+ """Returns the one-to-many head components, here for backward compatibility."""
711
+ return dict(
712
+ box_head=self.cv2,
713
+ cls_head=self.cv3,
714
+ pose_head=self.cv4,
715
+ kpts_head=self.cv4_kpts,
716
+ kpts_sigma_head=self.cv4_sigma,
717
+ )
718
+
719
+ @property
720
+ def one2one(self):
721
+ """Returns the one-to-one head components."""
722
+ return dict(
723
+ box_head=self.one2one_cv2,
724
+ cls_head=self.one2one_cv3,
725
+ pose_head=self.one2one_cv4,
726
+ kpts_head=self.one2one_cv4_kpts,
727
+ kpts_sigma_head=self.one2one_cv4_sigma,
728
+ )
729
+
730
+ def forward_head(
731
+ self,
732
+ x: list[torch.Tensor],
733
+ box_head: torch.nn.Module,
734
+ cls_head: torch.nn.Module,
735
+ pose_head: torch.nn.Module,
736
+ kpts_head: torch.nn.Module,
737
+ kpts_sigma_head: torch.nn.Module,
738
+ ) -> torch.Tensor:
739
+ """Concatenates and returns predicted bounding boxes, class probabilities, and keypoints."""
740
+ preds = Detect.forward_head(self, x, box_head, cls_head)
741
+ if pose_head is not None:
742
+ bs = x[0].shape[0] # batch size
743
+ features = [pose_head[i](x[i]) for i in range(self.nl)]
744
+ preds["kpts"] = torch.cat([kpts_head[i](features[i]).view(bs, self.nk, -1) for i in range(self.nl)], 2)
745
+ if self.training:
746
+ preds["kpts_sigma"] = torch.cat(
747
+ [kpts_sigma_head[i](features[i]).view(bs, self.nk_sigma, -1) for i in range(self.nl)], 2
748
+ )
749
+ return preds
750
+
751
+ def fuse(self) -> None:
752
+ """Remove the one2many head for inference optimization."""
753
+ super().fuse()
754
+ self.cv4_kpts = self.cv4_sigma = self.flow_model = self.one2one_cv4_sigma = None
755
+
756
+ def kpts_decode(self, kpts: torch.Tensor) -> torch.Tensor:
757
+ """Decode keypoints from predictions."""
758
+ ndim = self.kpt_shape[1]
759
+ bs = kpts.shape[0]
760
+ if self.export:
761
+ if self.format in {
762
+ "tflite",
763
+ "edgetpu",
764
+ }: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
765
+ # Precompute normalization factor to increase numerical stability
766
+ y = kpts.view(bs, *self.kpt_shape, -1)
767
+ grid_h, grid_w = self.shape[2], self.shape[3]
768
+ grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1)
769
+ norm = self.strides / (self.stride[0] * grid_size)
770
+ a = (y[:, :, :2] + self.anchors) * norm
771
+ else:
772
+ # NCNN fix
773
+ y = kpts.view(bs, *self.kpt_shape, -1)
774
+ a = (y[:, :, :2] + self.anchors) * self.strides
775
+ if ndim == 3:
776
+ a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
777
+ return a.view(bs, self.nk, -1)
778
+ else:
779
+ y = kpts.clone()
780
+ if ndim == 3:
781
+ if NOT_MACOS14:
782
+ y[:, 2::ndim].sigmoid_()
783
+ else: # Apple macOS14 MPS bug https://github.com/ultralytics/ultralytics/pull/21878
784
+ y[:, 2::ndim] = y[:, 2::ndim].sigmoid()
785
+ y[:, 0::ndim] = (y[:, 0::ndim] + self.anchors[0]) * self.strides
786
+ y[:, 1::ndim] = (y[:, 1::ndim] + self.anchors[1]) * self.strides
787
+ return y
788
+
789
+
387
790
  class Classify(nn.Module):
388
791
  """YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).
389
792
 
@@ -459,29 +862,44 @@ class WorldDetect(Detect):
459
862
  >>> outputs = world_detect(x, text)
460
863
  """
461
864
 
462
- def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: tuple = ()):
865
+ def __init__(
866
+ self,
867
+ nc: int = 80,
868
+ embed: int = 512,
869
+ with_bn: bool = False,
870
+ reg_max: int = 16,
871
+ end2end: bool = False,
872
+ ch: tuple = (),
873
+ ):
463
874
  """Initialize YOLO detection layer with nc classes and layer channels ch.
464
875
 
465
876
  Args:
466
877
  nc (int): Number of classes.
467
878
  embed (int): Embedding dimension.
468
879
  with_bn (bool): Whether to use batch normalization in contrastive head.
880
+ reg_max (int): Maximum number of DFL channels.
881
+ end2end (bool): Whether to use end-to-end NMS-free detection.
469
882
  ch (tuple): Tuple of channel sizes from backbone feature maps.
470
883
  """
471
- super().__init__(nc, ch)
884
+ super().__init__(nc, reg_max=reg_max, end2end=end2end, ch=ch)
472
885
  c3 = max(ch[0], min(self.nc, 100))
473
886
  self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)
474
887
  self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)
475
888
 
476
- def forward(self, x: list[torch.Tensor], text: torch.Tensor) -> list[torch.Tensor] | tuple:
889
+ def forward(self, x: list[torch.Tensor], text: torch.Tensor) -> dict[str, torch.Tensor] | tuple:
477
890
  """Concatenate and return predicted bounding boxes and class probabilities."""
891
+ feats = [xi.clone() for xi in x] # save original features for anchor generation
478
892
  for i in range(self.nl):
479
893
  x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), text)), 1)
480
- if self.training:
481
- return x
482
894
  self.no = self.nc + self.reg_max * 4 # self.nc could be changed when inference with different texts
483
- y = self._inference(x)
484
- return y if self.export else (y, x)
895
+ bs = x[0].shape[0]
896
+ x_cat = torch.cat([xi.view(bs, self.no, -1) for xi in x], 2)
897
+ boxes, scores = x_cat.split((self.reg_max * 4, self.nc), 1)
898
+ preds = dict(boxes=boxes, scores=scores, feats=feats)
899
+ if self.training:
900
+ return preds
901
+ y = self._inference(preds)
902
+ return y if self.export else (y, preds)
485
903
 
486
904
  def bias_init(self):
487
905
  """Initialize Detect() biases, WARNING: requires stride availability."""
@@ -548,12 +966,14 @@ class LRPCHead(nn.Module):
548
966
  mask = pf_score.sigmoid() > conf
549
967
  cls_feat = cls_feat.flatten(2).transpose(-1, -2)
550
968
  cls_feat = self.vocab(cls_feat[:, mask] if conf else cls_feat * mask.unsqueeze(-1).int())
551
- return (self.loc(loc_feat), cls_feat.transpose(-1, -2)), mask
969
+ return self.loc(loc_feat), cls_feat.transpose(-1, -2), mask
552
970
  else:
553
971
  cls_feat = self.vocab(cls_feat)
554
972
  loc_feat = self.loc(loc_feat)
555
- return (loc_feat, cls_feat.flatten(2)), torch.ones(
556
- cls_feat.shape[2] * cls_feat.shape[3], device=cls_feat.device, dtype=torch.bool
973
+ return (
974
+ loc_feat,
975
+ cls_feat.flatten(2),
976
+ torch.ones(cls_feat.shape[2] * cls_feat.shape[3], device=cls_feat.device, dtype=torch.bool),
557
977
  )
558
978
 
559
979
 
@@ -589,16 +1009,20 @@ class YOLOEDetect(Detect):
589
1009
 
590
1010
  is_fused = False
591
1011
 
592
- def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: tuple = ()):
1012
+ def __init__(
1013
+ self, nc: int = 80, embed: int = 512, with_bn: bool = False, reg_max=16, end2end=False, ch: tuple = ()
1014
+ ):
593
1015
  """Initialize YOLO detection layer with nc classes and layer channels ch.
594
1016
 
595
1017
  Args:
596
1018
  nc (int): Number of classes.
597
1019
  embed (int): Embedding dimension.
598
1020
  with_bn (bool): Whether to use batch normalization in contrastive head.
1021
+ reg_max (int): Maximum number of DFL channels.
1022
+ end2end (bool): Whether to use end-to-end NMS-free detection.
599
1023
  ch (tuple): Tuple of channel sizes from backbone feature maps.
600
1024
  """
601
- super().__init__(nc, ch)
1025
+ super().__init__(nc, reg_max, end2end, ch)
602
1026
  c3 = max(ch[0], min(self.nc, 100))
603
1027
  assert c3 <= embed
604
1028
  assert with_bn
@@ -614,29 +1038,43 @@ class YOLOEDetect(Detect):
614
1038
  for x in ch
615
1039
  )
616
1040
  )
617
-
618
1041
  self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)
1042
+ if end2end:
1043
+ self.one2one_cv3 = copy.deepcopy(self.cv3) # overwrite with new cv3
1044
+ self.one2one_cv4 = copy.deepcopy(self.cv4)
619
1045
 
620
1046
  self.reprta = Residual(SwiGLUFFN(embed, embed))
621
1047
  self.savpe = SAVPE(ch, c3, embed)
622
1048
  self.embed = embed
623
1049
 
624
1050
  @smart_inference_mode()
625
- def fuse(self, txt_feats: torch.Tensor):
1051
+ def fuse(self, txt_feats: torch.Tensor = None):
626
1052
  """Fuse text features with model weights for efficient inference."""
1053
+ if txt_feats is None: # means eliminate one2many branch
1054
+ self.cv2 = self.cv3 = self.cv4 = None
1055
+ return
627
1056
  if self.is_fused:
628
1057
  return
629
1058
 
630
1059
  assert not self.training
631
1060
  txt_feats = txt_feats.to(torch.float32).squeeze(0)
632
- for cls_head, bn_head in zip(self.cv3, self.cv4):
633
- assert isinstance(cls_head, nn.Sequential)
634
- assert isinstance(bn_head, BNContrastiveHead)
635
- conv = cls_head[-1]
1061
+ self._fuse_tp(txt_feats, self.cv3, self.cv4)
1062
+ if self.end2end:
1063
+ self._fuse_tp(txt_feats, self.one2one_cv3, self.one2one_cv4)
1064
+ del self.reprta
1065
+ self.reprta = nn.Identity()
1066
+ self.is_fused = True
1067
+
1068
+ def _fuse_tp(self, txt_feats: torch.Tensor, cls_head: torch.nn.Module, bn_head: torch.nn.Module) -> None:
1069
+ """Fuse text prompt embeddings with model weights for efficient inference."""
1070
+ for cls_h, bn_h in zip(cls_head, bn_head):
1071
+ assert isinstance(cls_h, nn.Sequential)
1072
+ assert isinstance(bn_h, BNContrastiveHead)
1073
+ conv = cls_h[-1]
636
1074
  assert isinstance(conv, nn.Conv2d)
637
- logit_scale = bn_head.logit_scale
638
- bias = bn_head.bias
639
- norm = bn_head.norm
1075
+ logit_scale = bn_h.logit_scale
1076
+ bias = bn_h.bias
1077
+ norm = bn_h.norm
640
1078
 
641
1079
  t = txt_feats * logit_scale.exp()
642
1080
  conv: nn.Conv2d = fuse_conv_and_bn(conv, norm)
@@ -660,13 +1098,9 @@ class YOLOEDetect(Detect):
660
1098
 
661
1099
  conv.weight.data.copy_(w.unsqueeze(-1).unsqueeze(-1))
662
1100
  conv.bias.data.copy_(b1 + b2)
663
- cls_head[-1] = conv
664
-
665
- bn_head.fuse()
1101
+ cls_h[-1] = conv
666
1102
 
667
- del self.reprta
668
- self.reprta = nn.Identity()
669
- self.is_fused = True
1103
+ bn_h.fuse()
670
1104
 
671
1105
  def get_tpe(self, tpe: torch.Tensor | None) -> torch.Tensor | None:
672
1106
  """Get text prompt embeddings with normalization."""
@@ -681,66 +1115,82 @@ class YOLOEDetect(Detect):
681
1115
  assert vpe.ndim == 3 # (B, N, D)
682
1116
  return vpe
683
1117
 
684
- def forward_lrpc(self, x: list[torch.Tensor], return_mask: bool = False) -> torch.Tensor | tuple:
1118
+ def forward(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
1119
+ """Process features with class prompt embeddings to generate detections."""
1120
+ if hasattr(self, "lrpc"): # for prompt-free inference
1121
+ return self.forward_lrpc(x[:3])
1122
+ return super().forward(x)
1123
+
1124
+ def forward_lrpc(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
685
1125
  """Process features with fused text embeddings to generate detections for prompt-free model."""
686
- masks = []
687
- assert self.is_fused, "Prompt-free inference requires model to be fused!"
1126
+ boxes, scores, index = [], [], []
1127
+ bs = x[0].shape[0]
1128
+ cv2 = self.cv2 if not self.end2end else self.one2one_cv2
1129
+ cv3 = self.cv3 if not self.end2end else self.one2one_cv2
688
1130
  for i in range(self.nl):
689
- cls_feat = self.cv3[i](x[i])
690
- loc_feat = self.cv2[i](x[i])
1131
+ cls_feat = cv3[i](x[i])
1132
+ loc_feat = cv2[i](x[i])
691
1133
  assert isinstance(self.lrpc[i], LRPCHead)
692
- x[i], mask = self.lrpc[i](
693
- cls_feat, loc_feat, 0 if self.export and not self.dynamic else getattr(self, "conf", 0.001)
1134
+ box, score, idx = self.lrpc[i](
1135
+ cls_feat,
1136
+ loc_feat,
1137
+ 0 if self.export and not self.dynamic else getattr(self, "conf", 0.001),
694
1138
  )
695
- masks.append(mask)
696
- shape = x[0][0].shape
697
- if self.dynamic or self.shape != shape:
698
- self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors([b[0] for b in x], self.stride, 0.5))
699
- self.shape = shape
700
- box = torch.cat([xi[0].view(shape[0], self.reg_max * 4, -1) for xi in x], 2)
701
- cls = torch.cat([xi[1] for xi in x], 2)
702
-
703
- if self.export and self.format in {"tflite", "edgetpu"}:
704
- # Precompute normalization factor to increase numerical stability
705
- # See https://github.com/ultralytics/ultralytics/issues/7371
706
- grid_h = shape[2]
707
- grid_w = shape[3]
708
- grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
709
- norm = self.strides / (self.stride[0] * grid_size)
710
- dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
711
- else:
712
- dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
713
-
714
- mask = torch.cat(masks)
715
- y = torch.cat((dbox if self.export and not self.dynamic else dbox[..., mask], cls.sigmoid()), 1)
716
-
717
- if return_mask:
718
- return (y, mask) if self.export else ((y, x), mask)
719
- else:
720
- return y if self.export else (y, x)
721
-
722
- def forward(self, x: list[torch.Tensor], cls_pe: torch.Tensor, return_mask: bool = False) -> torch.Tensor | tuple:
723
- """Process features with class prompt embeddings to generate detections."""
724
- if hasattr(self, "lrpc"): # for prompt-free inference
725
- return self.forward_lrpc(x, return_mask)
726
- for i in range(self.nl):
727
- x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), cls_pe)), 1)
728
- if self.training:
729
- return x
1139
+ boxes.append(box.view(bs, self.reg_max * 4, -1))
1140
+ scores.append(score)
1141
+ index.append(idx)
1142
+ preds = dict(boxes=torch.cat(boxes, 2), scores=torch.cat(scores, 2), feats=x, index=torch.cat(index))
1143
+ y = self._inference(preds)
1144
+ if self.end2end:
1145
+ y = self.postprocess(y.permute(0, 2, 1))
1146
+ return y if self.export else (y, preds)
1147
+
1148
+ def _get_decode_boxes(self, x):
1149
+ """Decode predicted bounding boxes for inference."""
1150
+ dbox = super()._get_decode_boxes(x)
1151
+ if hasattr(self, "lrpc"):
1152
+ dbox = dbox if self.export and not self.dynamic else dbox[..., x["index"]]
1153
+ return dbox
1154
+
1155
+ @property
1156
+ def one2many(self):
1157
+ """Returns the one-to-many head components, here for v5/v5/v8/v9/11 backward compatibility."""
1158
+ return dict(box_head=self.cv2, cls_head=self.cv3, contrastive_head=self.cv4)
1159
+
1160
+ @property
1161
+ def one2one(self):
1162
+ """Returns the one-to-one head components."""
1163
+ return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3, contrastive_head=self.one2one_cv4)
1164
+
1165
+ def forward_head(self, x, box_head, cls_head, contrastive_head):
1166
+ """Concatenates and returns predicted bounding boxes, class probabilities, and text embeddings."""
1167
+ assert len(x) == 4, f"Expected 4 features including 3 feature maps and 1 text embeddings, but got {len(x)}."
1168
+ if box_head is None or cls_head is None: # for fused inference
1169
+ return dict()
1170
+ bs = x[0].shape[0] # batch size
1171
+ boxes = torch.cat([box_head[i](x[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)], dim=-1)
1172
+ self.nc = x[-1].shape[1]
1173
+ scores = torch.cat(
1174
+ [contrastive_head[i](cls_head[i](x[i]), x[-1]).reshape(bs, self.nc, -1) for i in range(self.nl)], dim=-1
1175
+ )
730
1176
  self.no = self.nc + self.reg_max * 4 # self.nc could be changed when inference with different texts
731
- y = self._inference(x)
732
- return y if self.export else (y, x)
1177
+ return dict(boxes=boxes, scores=scores, feats=x[:3])
733
1178
 
734
1179
  def bias_init(self):
735
- """Initialize biases for detection heads."""
736
- m = self # self.model[-1] # Detect() module
737
- # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
738
- # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
739
- for a, b, c, s in zip(m.cv2, m.cv3, m.cv4, m.stride): # from
740
- a[-1].bias.data[:] = 1.0 # box
741
- # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
1180
+ """Initialize Detect() biases, WARNING: requires stride availability."""
1181
+ for i, (a, b, c) in enumerate(
1182
+ zip(self.one2many["box_head"], self.one2many["cls_head"], self.one2many["contrastive_head"])
1183
+ ):
1184
+ a[-1].bias.data[:] = 2.0 # box
742
1185
  b[-1].bias.data[:] = 0.0
743
- c.bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)
1186
+ c.bias.data[:] = math.log(5 / self.nc / (640 / self.stride[i]) ** 2)
1187
+ if self.end2end:
1188
+ for i, (a, b, c) in enumerate(
1189
+ zip(self.one2one["box_head"], self.one2one["cls_head"], self.one2one["contrastive_head"])
1190
+ ):
1191
+ a[-1].bias.data[:] = 2.0 # box
1192
+ b[-1].bias.data[:] = 0.0
1193
+ c.bias.data[:] = math.log(5 / self.nc / (640 / self.stride[i]) ** 2)
744
1194
 
745
1195
 
746
1196
  class YOLOESegment(YOLOEDetect):
@@ -767,7 +1217,15 @@ class YOLOESegment(YOLOEDetect):
767
1217
  """
768
1218
 
769
1219
  def __init__(
770
- self, nc: int = 80, nm: int = 32, npr: int = 256, embed: int = 512, with_bn: bool = False, ch: tuple = ()
1220
+ self,
1221
+ nc: int = 80,
1222
+ nm: int = 32,
1223
+ npr: int = 256,
1224
+ embed: int = 512,
1225
+ with_bn: bool = False,
1226
+ reg_max=16,
1227
+ end2end=False,
1228
+ ch: tuple = (),
771
1229
  ):
772
1230
  """Initialize YOLOESegment with class count, mask parameters, and embedding dimensions.
773
1231
 
@@ -777,36 +1235,191 @@ class YOLOESegment(YOLOEDetect):
777
1235
  npr (int): Number of protos.
778
1236
  embed (int): Embedding dimension.
779
1237
  with_bn (bool): Whether to use batch normalization in contrastive head.
1238
+ reg_max (int): Maximum number of DFL channels.
1239
+ end2end (bool): Whether to use end-to-end NMS-free detection.
780
1240
  ch (tuple): Tuple of channel sizes from backbone feature maps.
781
1241
  """
782
- super().__init__(nc, embed, with_bn, ch)
1242
+ super().__init__(nc, embed, with_bn, reg_max, end2end, ch)
783
1243
  self.nm = nm
784
1244
  self.npr = npr
785
1245
  self.proto = Proto(ch[0], self.npr, self.nm)
786
1246
 
787
1247
  c5 = max(ch[0] // 4, self.nm)
788
1248
  self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch)
1249
+ if end2end:
1250
+ self.one2one_cv5 = copy.deepcopy(self.cv5)
1251
+
1252
+ @property
1253
+ def one2many(self):
1254
+ """Returns the one-to-many head components, here for v5/v5/v8/v9/11 backward compatibility."""
1255
+ return dict(box_head=self.cv2, cls_head=self.cv3, mask_head=self.cv5, contrastive_head=self.cv4)
1256
+
1257
+ @property
1258
+ def one2one(self):
1259
+ """Returns the one-to-one head components."""
1260
+ return dict(
1261
+ box_head=self.one2one_cv2,
1262
+ cls_head=self.one2one_cv3,
1263
+ mask_head=self.one2one_cv5,
1264
+ contrastive_head=self.one2one_cv4,
1265
+ )
1266
+
1267
+ def forward_lrpc(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
1268
+ """Process features with fused text embeddings to generate detections for prompt-free model."""
1269
+ boxes, scores, index = [], [], []
1270
+ bs = x[0].shape[0]
1271
+ cv2 = self.cv2 if not self.end2end else self.one2one_cv2
1272
+ cv3 = self.cv3 if not self.end2end else self.one2one_cv3
1273
+ cv5 = self.cv5 if not self.end2end else self.one2one_cv5
1274
+ for i in range(self.nl):
1275
+ cls_feat = cv3[i](x[i])
1276
+ loc_feat = cv2[i](x[i])
1277
+ assert isinstance(self.lrpc[i], LRPCHead)
1278
+ box, score, idx = self.lrpc[i](
1279
+ cls_feat,
1280
+ loc_feat,
1281
+ 0 if self.export and not self.dynamic else getattr(self, "conf", 0.001),
1282
+ )
1283
+ boxes.append(box.view(bs, self.reg_max * 4, -1))
1284
+ scores.append(score)
1285
+ index.append(idx)
1286
+ mc = torch.cat([cv5[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)
1287
+ index = torch.cat(index)
1288
+ preds = dict(
1289
+ boxes=torch.cat(boxes, 2),
1290
+ scores=torch.cat(scores, 2),
1291
+ feats=x,
1292
+ index=index,
1293
+ mask_coefficient=mc * index.int() if self.export and not self.dynamic else mc[..., index],
1294
+ )
1295
+ y = self._inference(preds)
1296
+ if self.end2end:
1297
+ y = self.postprocess(y.permute(0, 2, 1))
1298
+ return y if self.export else (y, preds)
789
1299
 
790
- def forward(self, x: list[torch.Tensor], text: torch.Tensor) -> tuple | torch.Tensor:
1300
+ def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor] | dict[str, torch.Tensor]:
791
1301
  """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
792
- p = self.proto(x[0]) # mask protos
793
- bs = p.shape[0] # batch size
1302
+ outputs = super().forward(x)
1303
+ preds = outputs[1] if isinstance(outputs, tuple) else outputs
1304
+ proto = self.proto(x[0]) # mask protos
1305
+ if isinstance(preds, dict): # training and validating during training
1306
+ if self.end2end:
1307
+ preds["one2many"]["proto"] = proto
1308
+ preds["one2one"]["proto"] = proto.detach()
1309
+ else:
1310
+ preds["proto"] = proto
1311
+ if self.training:
1312
+ return preds
1313
+ return (outputs, proto) if self.export else ((outputs[0], proto), preds)
794
1314
 
795
- mc = torch.cat([self.cv5[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
796
- has_lrpc = hasattr(self, "lrpc")
1315
+ def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
1316
+ """Decode predicted bounding boxes and class probabilities, concatenated with mask coefficients."""
1317
+ preds = super()._inference(x)
1318
+ return torch.cat([preds, x["mask_coefficient"]], dim=1)
797
1319
 
798
- if not has_lrpc:
799
- x = YOLOEDetect.forward(self, x, text)
800
- else:
801
- x, mask = YOLOEDetect.forward(self, x, text, return_mask=True)
1320
+ def forward_head(
1321
+ self,
1322
+ x: list[torch.Tensor],
1323
+ box_head: torch.nn.Module,
1324
+ cls_head: torch.nn.Module,
1325
+ mask_head: torch.nn.Module,
1326
+ contrastive_head: torch.nn.Module,
1327
+ ) -> torch.Tensor:
1328
+ """Concatenates and returns predicted bounding boxes, class probabilities, and mask coefficients."""
1329
+ preds = super().forward_head(x, box_head, cls_head, contrastive_head)
1330
+ if mask_head is not None:
1331
+ bs = x[0].shape[0] # batch size
1332
+ preds["mask_coefficient"] = torch.cat([mask_head[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)
1333
+ return preds
1334
+
1335
+ def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
1336
+ """Post-process YOLO model predictions.
802
1337
 
803
- if self.training:
804
- return x, mc, p
1338
+ Args:
1339
+ preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc + nm) with last dimension
1340
+ format [x, y, w, h, class_probs, mask_coefficient].
1341
+
1342
+ Returns:
1343
+ (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6 + nm) and last
1344
+ dimension format [x, y, w, h, max_class_prob, class_index, mask_coefficient].
1345
+ """
1346
+ boxes, scores, mask_coefficient = preds.split([4, self.nc, self.nm], dim=-1)
1347
+ scores, conf, idx = self.get_topk_index(scores, self.max_det)
1348
+ boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
1349
+ mask_coefficient = mask_coefficient.gather(dim=1, index=idx.repeat(1, 1, self.nm))
1350
+ return torch.cat([boxes, scores, conf, mask_coefficient], dim=-1)
1351
+
1352
+ def fuse(self, txt_feats: torch.Tensor = None):
1353
+ """Fuse text features with model weights for efficient inference."""
1354
+ super().fuse(txt_feats)
1355
+ if txt_feats is None: # means eliminate one2many branch
1356
+ self.cv5 = None
1357
+ if hasattr(self.proto, "fuse"):
1358
+ self.proto.fuse()
1359
+ return
1360
+
1361
+
1362
+ class YOLOESegment26(YOLOESegment):
1363
+ """YOLOE-style segmentation head module using Proto26 for mask generation.
805
1364
 
806
- if has_lrpc:
807
- mc = (mc * mask.int()) if self.export and not self.dynamic else mc[..., mask]
1365
+ This class extends the YOLOEDetect functionality to include segmentation capabilities by integrating a prototype
1366
+ generation module and convolutional layers to predict mask coefficients.
808
1367
 
809
- return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
1368
+ Args:
1369
+ nc (int): Number of classes. Defaults to 80.
1370
+ nm (int): Number of masks. Defaults to 32.
1371
+ npr (int): Number of prototype channels. Defaults to 256.
1372
+ embed (int): Embedding dimensionality. Defaults to 512.
1373
+ with_bn (bool): Whether to use Batch Normalization. Defaults to False.
1374
+ reg_max (int): Maximum regression value for bounding boxes. Defaults to 16.
1375
+ end2end (bool): Whether to use end-to-end detection mode. Defaults to False.
1376
+ ch (tuple[int, ...]): Input channels for each scale.
1377
+
1378
+ Attributes:
1379
+ nm (int): Number of segmentation masks.
1380
+ npr (int): Number of prototype channels.
1381
+ proto (Proto26): Prototype generation module for segmentation.
1382
+ cv5 (nn.ModuleList): Convolutional layers for generating mask coefficients from features.
1383
+ one2one_cv5 (nn.ModuleList, optional): Deep copy of cv5 for end-to-end detection branches.
1384
+ """
1385
+
1386
+ def __init__(
1387
+ self,
1388
+ nc: int = 80,
1389
+ nm: int = 32,
1390
+ npr: int = 256,
1391
+ embed: int = 512,
1392
+ with_bn: bool = False,
1393
+ reg_max=16,
1394
+ end2end=False,
1395
+ ch: tuple = (),
1396
+ ):
1397
+ """Initialize YOLOESegment26 with class count, mask parameters, and embedding dimensions."""
1398
+ YOLOEDetect.__init__(self, nc, embed, with_bn, reg_max, end2end, ch)
1399
+ self.nm = nm
1400
+ self.npr = npr
1401
+ self.proto = Proto26(ch, self.npr, self.nm, nc) # protos
1402
+
1403
+ c5 = max(ch[0] // 4, self.nm)
1404
+ self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch)
1405
+ if end2end:
1406
+ self.one2one_cv5 = copy.deepcopy(self.cv5)
1407
+
1408
+ def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor] | dict[str, torch.Tensor]:
1409
+ """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
1410
+ outputs = YOLOEDetect.forward(self, x)
1411
+ preds = outputs[1] if isinstance(outputs, tuple) else outputs
1412
+ proto = self.proto([xi.detach() for xi in x], return_semseg=False) # mask protos
1413
+
1414
+ if isinstance(preds, dict): # training and validating during training
1415
+ if self.end2end and not hasattr(self, "lrpc"): # not prompt-free
1416
+ preds["one2many"]["proto"] = proto
1417
+ preds["one2one"]["proto"] = proto.detach()
1418
+ else:
1419
+ preds["proto"] = proto
1420
+ if self.training:
1421
+ return preds
1422
+ return (outputs, proto) if self.export else ((outputs[0], proto), preds)
810
1423
 
811
1424
 
812
1425
  class RTDETRDecoder(nn.Module):
@@ -1165,7 +1778,7 @@ class v10Detect(Detect):
1165
1778
  nc (int): Number of classes.
1166
1779
  ch (tuple): Tuple of channel sizes from backbone feature maps.
1167
1780
  """
1168
- super().__init__(nc, ch)
1781
+ super().__init__(nc, end2end=True, ch=ch)
1169
1782
  c3 = max(ch[0], min(self.nc, 100)) # channels
1170
1783
  # Light cls head
1171
1784
  self.cv3 = nn.ModuleList(
@@ -1180,4 +1793,4 @@ class v10Detect(Detect):
1180
1793
 
1181
1794
  def fuse(self):
1182
1795
  """Remove the one2many head for inference optimization."""
1183
- self.cv2 = self.cv3 = nn.ModuleList([nn.Identity()] * self.nl)
1796
+ self.cv2 = self.cv3 = None
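
Putting the reworked head together, a hedged inference sketch against the new API (the manual stride assignment stands in for what the model builder normally computes, and the example assumes the installed 8.4.0 wheel matches the diff above):

    import torch
    from ultralytics.nn.modules.head import Detect

    head = Detect(nc=80, reg_max=16, end2end=True, ch=(256, 512, 1024))
    head.stride = torch.tensor([8.0, 16.0, 32.0])  # illustrative; normally set when the model is built
    head.eval()

    feats = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
    with torch.no_grad():
        y, preds = head(feats)       # NMS-free path: y is already top-k filtered by postprocess()
    print(y.shape)                   # torch.Size([1, 300, 6]): 4 box coords, score, class index

    # In training mode the same call returns the dict consumed by the new loss:
    head.train()
    out = head(feats)
    print(sorted(out))               # ['one2many', 'one2one']

    head.fuse()                      # drops cv2/cv3 (the one2many branch); later forwards use only one2one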