dgenerate-ultralytics-headless 8.3.253__py3-none-any.whl → 8.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. {dgenerate_ultralytics_headless-8.3.253.dist-info → dgenerate_ultralytics_headless-8.4.3.dist-info}/METADATA +41 -49
  2. {dgenerate_ultralytics_headless-8.3.253.dist-info → dgenerate_ultralytics_headless-8.4.3.dist-info}/RECORD +85 -74
  3. tests/__init__.py +2 -2
  4. tests/conftest.py +1 -1
  5. tests/test_cuda.py +8 -2
  6. tests/test_engine.py +8 -8
  7. tests/test_exports.py +11 -4
  8. tests/test_integrations.py +9 -9
  9. tests/test_python.py +14 -14
  10. tests/test_solutions.py +3 -3
  11. ultralytics/__init__.py +1 -1
  12. ultralytics/cfg/__init__.py +25 -27
  13. ultralytics/cfg/default.yaml +3 -1
  14. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  15. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  16. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  17. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  18. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  19. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  20. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  21. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  22. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  23. ultralytics/data/annotator.py +2 -2
  24. ultralytics/data/augment.py +7 -0
  25. ultralytics/data/converter.py +57 -38
  26. ultralytics/data/dataset.py +1 -1
  27. ultralytics/engine/exporter.py +31 -26
  28. ultralytics/engine/model.py +34 -34
  29. ultralytics/engine/predictor.py +17 -17
  30. ultralytics/engine/results.py +14 -12
  31. ultralytics/engine/trainer.py +59 -29
  32. ultralytics/engine/tuner.py +19 -11
  33. ultralytics/engine/validator.py +16 -16
  34. ultralytics/models/fastsam/predict.py +1 -1
  35. ultralytics/models/yolo/classify/predict.py +1 -1
  36. ultralytics/models/yolo/classify/train.py +1 -1
  37. ultralytics/models/yolo/classify/val.py +1 -1
  38. ultralytics/models/yolo/detect/predict.py +2 -2
  39. ultralytics/models/yolo/detect/train.py +4 -3
  40. ultralytics/models/yolo/detect/val.py +7 -1
  41. ultralytics/models/yolo/model.py +8 -8
  42. ultralytics/models/yolo/obb/predict.py +2 -2
  43. ultralytics/models/yolo/obb/train.py +3 -3
  44. ultralytics/models/yolo/obb/val.py +1 -1
  45. ultralytics/models/yolo/pose/predict.py +1 -1
  46. ultralytics/models/yolo/pose/train.py +3 -1
  47. ultralytics/models/yolo/pose/val.py +1 -1
  48. ultralytics/models/yolo/segment/predict.py +3 -3
  49. ultralytics/models/yolo/segment/train.py +4 -4
  50. ultralytics/models/yolo/segment/val.py +4 -2
  51. ultralytics/models/yolo/yoloe/train.py +6 -1
  52. ultralytics/models/yolo/yoloe/train_seg.py +6 -1
  53. ultralytics/nn/autobackend.py +5 -5
  54. ultralytics/nn/modules/__init__.py +8 -0
  55. ultralytics/nn/modules/block.py +128 -8
  56. ultralytics/nn/modules/head.py +788 -203
  57. ultralytics/nn/tasks.py +86 -41
  58. ultralytics/nn/text_model.py +5 -2
  59. ultralytics/optim/__init__.py +5 -0
  60. ultralytics/optim/muon.py +338 -0
  61. ultralytics/solutions/ai_gym.py +3 -3
  62. ultralytics/solutions/config.py +1 -1
  63. ultralytics/solutions/heatmap.py +1 -1
  64. ultralytics/solutions/instance_segmentation.py +2 -2
  65. ultralytics/solutions/parking_management.py +1 -1
  66. ultralytics/solutions/solutions.py +2 -2
  67. ultralytics/trackers/track.py +1 -1
  68. ultralytics/utils/__init__.py +8 -8
  69. ultralytics/utils/benchmarks.py +23 -23
  70. ultralytics/utils/callbacks/platform.py +11 -7
  71. ultralytics/utils/checks.py +6 -6
  72. ultralytics/utils/downloads.py +5 -3
  73. ultralytics/utils/export/engine.py +19 -10
  74. ultralytics/utils/export/imx.py +19 -13
  75. ultralytics/utils/export/tensorflow.py +21 -21
  76. ultralytics/utils/files.py +2 -2
  77. ultralytics/utils/loss.py +587 -203
  78. ultralytics/utils/metrics.py +1 -0
  79. ultralytics/utils/ops.py +11 -2
  80. ultralytics/utils/tal.py +98 -19
  81. ultralytics/utils/tuner.py +2 -2
  82. {dgenerate_ultralytics_headless-8.3.253.dist-info → dgenerate_ultralytics_headless-8.4.3.dist-info}/WHEEL +0 -0
  83. {dgenerate_ultralytics_headless-8.3.253.dist-info → dgenerate_ultralytics_headless-8.4.3.dist-info}/entry_points.txt +0 -0
  84. {dgenerate_ultralytics_headless-8.3.253.dist-info → dgenerate_ultralytics_headless-8.4.3.dist-info}/licenses/LICENSE +0 -0
  85. {dgenerate_ultralytics_headless-8.3.253.dist-info → dgenerate_ultralytics_headless-8.4.3.dist-info}/top_level.txt +0 -0
@@ -15,7 +15,7 @@ from ultralytics.utils import NOT_MACOS14
  from ultralytics.utils.tal import dist2bbox, dist2rbox, make_anchors
  from ultralytics.utils.torch_utils import TORCH_1_11, fuse_conv_and_bn, smart_inference_mode

- from .block import DFL, SAVPE, BNContrastiveHead, ContrastiveHead, Proto, Residual, SwiGLUFFN
+ from .block import DFL, SAVPE, BNContrastiveHead, ContrastiveHead, Proto, Proto26, RealNVP, Residual, SwiGLUFFN
  from .conv import Conv, DWConv
  from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
  from .utils import bias_init_with_prob, linear_init
@@ -68,7 +68,6 @@ class Detect(nn.Module):
  dynamic = False # force grid reconstruction
  export = False # export mode
  format = None # export format
- end2end = False # end2end
  max_det = 300 # max_det
  shape = None
  anchors = torch.empty(0) # init
@@ -76,17 +75,19 @@ class Detect(nn.Module):
  legacy = False # backward compatibility for v3/v5/v8/v9 models
  xyxy = False # xyxy or xywh output

- def __init__(self, nc: int = 80, ch: tuple = ()):
+ def __init__(self, nc: int = 80, reg_max=16, end2end=False, ch: tuple = ()):
  """Initialize the YOLO detection layer with specified number of classes and channels.

  Args:
  nc (int): Number of classes.
+ reg_max (int): Maximum number of DFL channels.
+ end2end (bool): Whether to use end-to-end NMS-free detection.
  ch (tuple): Tuple of channel sizes from backbone feature maps.
  """
  super().__init__()
  self.nc = nc # number of classes
  self.nl = len(ch) # number of detection layers
- self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
+ self.reg_max = reg_max # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
  self.no = nc + self.reg_max * 4 # number of outputs per anchor
  self.stride = torch.zeros(self.nl) # strides computed during build
  c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels
@@ -107,77 +108,88 @@ class Detect(nn.Module):
  )
  self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

- if self.end2end:
+ if end2end:
  self.one2one_cv2 = copy.deepcopy(self.cv2)
  self.one2one_cv3 = copy.deepcopy(self.cv3)

- def forward(self, x: list[torch.Tensor]) -> list[torch.Tensor] | tuple:
- """Concatenate and return predicted bounding boxes and class probabilities."""
+ @property
+ def one2many(self):
+ """Returns the one-to-many head components, here for v3/v5/v8/v9/11 backward compatibility."""
+ return dict(box_head=self.cv2, cls_head=self.cv3)
+
+ @property
+ def one2one(self):
+ """Returns the one-to-one head components."""
+ return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3)
+
+ @property
+ def end2end(self):
+ """Checks if the model has one2one for v3/v5/v8/v9/11 backward compatibility."""
+ return hasattr(self, "one2one")
+
+ def forward_head(
+ self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
+ ) -> dict[str, torch.Tensor]:
+ """Concatenates and returns predicted bounding boxes and class probabilities."""
+ if box_head is None or cls_head is None: # for fused inference
+ return dict()
+ bs = x[0].shape[0] # batch size
+ boxes = torch.cat([box_head[i](x[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)], dim=-1)
+ scores = torch.cat([cls_head[i](x[i]).view(bs, self.nc, -1) for i in range(self.nl)], dim=-1)
+ return dict(boxes=boxes, scores=scores, feats=x)
+
+ def forward(
+ self, x: list[torch.Tensor]
+ ) -> dict[str, torch.Tensor] | torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
+ """Concatenates and returns predicted bounding boxes and class probabilities."""
+ preds = self.forward_head(x, **self.one2many)
  if self.end2end:
- return self.forward_end2end(x)
-
- for i in range(self.nl):
- x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
- if self.training: # Training path
- return x
- y = self._inference(x)
- return y if self.export else (y, x)
-
- def forward_end2end(self, x: list[torch.Tensor]) -> dict | tuple:
- """Perform forward pass of the v10Detect module.
-
- Args:
- x (list[torch.Tensor]): Input feature maps from different levels.
-
- Returns:
- outputs (dict | tuple): Training mode returns dict with one2many and one2one outputs. Inference mode returns
- processed detections or tuple with detections and raw outputs.
- """
- x_detach = [xi.detach() for xi in x]
- one2one = [
- torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
- ]
- for i in range(self.nl):
- x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
- if self.training: # Training path
- return {"one2many": x, "one2one": one2one}
-
- y = self._inference(one2one)
- y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
- return y if self.export else (y, {"one2many": x, "one2one": one2one})
+ x_detach = [xi.detach() for xi in x]
+ one2one = self.forward_head(x_detach, **self.one2one)
+ preds = {"one2many": preds, "one2one": one2one}
+ if self.training:
+ return preds
+ y = self._inference(preds["one2one"] if self.end2end else preds)
+ if self.end2end:
+ y = self.postprocess(y.permute(0, 2, 1))
+ return y if self.export else (y, preds)

- def _inference(self, x: list[torch.Tensor]) -> torch.Tensor:
+ def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
  """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.

  Args:
- x (list[torch.Tensor]): List of feature maps from different detection layers.
+ x (dict[str, torch.Tensor]): Dict of predicted boxes, scores, and feature maps from the detection layers.

  Returns:
  (torch.Tensor): Concatenated tensor of decoded bounding boxes and class probabilities.
  """
  # Inference path
- shape = x[0].shape # BCHW
- x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
+ dbox = self._get_decode_boxes(x)
+ return torch.cat((dbox, x["scores"].sigmoid()), 1)
+
+ def _get_decode_boxes(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+ """Get decoded boxes based on anchors and strides."""
+ shape = x["feats"][0].shape # BCHW
  if self.dynamic or self.shape != shape:
- self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
+ self.anchors, self.strides = (a.transpose(0, 1) for a in make_anchors(x["feats"], self.stride, 0.5))
  self.shape = shape

- box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
- dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
- return torch.cat((dbox, cls.sigmoid()), 1)
+ dbox = self.decode_bboxes(self.dfl(x["boxes"]), self.anchors.unsqueeze(0)) * self.strides
+ return dbox

  def bias_init(self):
  """Initialize Detect() biases, WARNING: requires stride availability."""
- m = self # self.model[-1] # Detect() module
- # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
- # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
- for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
- a[-1].bias.data[:] = 1.0 # box
- b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
+ for i, (a, b) in enumerate(zip(self.one2many["box_head"], self.one2many["cls_head"])): # from
+ a[-1].bias.data[:] = 2.0 # box
+ b[-1].bias.data[: self.nc] = math.log(
+ 5 / self.nc / (640 / self.stride[i]) ** 2
+ ) # cls (.01 objects, 80 classes, 640 img)
  if self.end2end:
- for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride): # from
- a[-1].bias.data[:] = 1.0 # box
- b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
+ for i, (a, b) in enumerate(zip(self.one2one["box_head"], self.one2one["cls_head"])): # from
+ a[-1].bias.data[:] = 2.0 # box
+ b[-1].bias.data[: self.nc] = math.log(
+ 5 / self.nc / (640 / self.stride[i]) ** 2
+ ) # cls (.01 objects, 80 classes, 640 img)

  def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor, xywh: bool = True) -> torch.Tensor:
  """Decode bounding boxes from predictions."""
@@ -188,28 +200,45 @@ class Detect(nn.Module):
  dim=1,
  )

- @staticmethod
- def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80) -> torch.Tensor:
- """Post-process YOLO model predictions.
+ def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
+ """Post-processes YOLO model predictions.

  Args:
  preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension
  format [x, y, w, h, class_probs].
- max_det (int): Maximum detections per image.
- nc (int, optional): Number of classes.

  Returns:
  (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last
  dimension format [x, y, w, h, max_class_prob, class_index].
  """
- batch_size, anchors, _ = preds.shape # i.e. shape(16,8400,84)
- boxes, scores = preds.split([4, nc], dim=-1)
- index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1)
- boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4))
- scores = scores.gather(dim=1, index=index.repeat(1, 1, nc))
- scores, index = scores.flatten(1).topk(min(max_det, anchors))
- i = torch.arange(batch_size)[..., None] # batch indices
- return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)
+ boxes, scores = preds.split([4, self.nc], dim=-1)
+ scores, conf, idx = self.get_topk_index(scores, self.max_det)
+ boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
+ return torch.cat([boxes, scores, conf], dim=-1)
+
+ def get_topk_index(self, scores: torch.Tensor, max_det: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Get top-k indices from scores.
+
+ Args:
+ scores (torch.Tensor): Scores tensor with shape (batch_size, num_anchors, num_classes).
+ max_det (int): Maximum detections per image.
+
+ Returns:
+ (torch.Tensor, torch.Tensor, torch.Tensor): Top scores, class indices, and filtered indices.
+ """
+ batch_size, anchors, nc = scores.shape # i.e. shape(16,8400,84)
+ # Use max_det directly during export for TensorRT compatibility (requires k to be constant),
+ # otherwise use min(max_det, anchors) for safety with small inputs during Python inference
+ k = max_det if self.export else min(max_det, anchors)
+ ori_index = scores.max(dim=-1)[0].topk(k)[1].unsqueeze(-1)
+ scores = scores.gather(dim=1, index=ori_index.repeat(1, 1, nc))
+ scores, index = scores.flatten(1).topk(k)
+ idx = ori_index[torch.arange(batch_size)[..., None], index // nc] # original index
+ return scores[..., None], (index % nc)[..., None].float(), idx
+
+ def fuse(self) -> None:
+ """Remove the one2many head for inference optimization."""
+ self.cv2 = self.cv3 = None


  class Segment(Detect):
@@ -233,33 +262,146 @@ class Segment(Detect):
  >>> outputs = segment(x)
  """

- def __init__(self, nc: int = 80, nm: int = 32, npr: int = 256, ch: tuple = ()):
+ def __init__(self, nc: int = 80, nm: int = 32, npr: int = 256, reg_max=16, end2end=False, ch: tuple = ()):
  """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.

  Args:
  nc (int): Number of classes.
  nm (int): Number of masks.
  npr (int): Number of protos.
+ reg_max (int): Maximum number of DFL channels.
+ end2end (bool): Whether to use end-to-end NMS-free detection.
  ch (tuple): Tuple of channel sizes from backbone feature maps.
  """
- super().__init__(nc, ch)
+ super().__init__(nc, reg_max, end2end, ch)
  self.nm = nm # number of masks
  self.npr = npr # number of protos
  self.proto = Proto(ch[0], self.npr, self.nm) # protos

  c4 = max(ch[0] // 4, self.nm)
  self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
+ if end2end:
+ self.one2one_cv4 = copy.deepcopy(self.cv4)
+
+ @property
+ def one2many(self):
+ """Returns the one-to-many head components, here for backward compatibility."""
+ return dict(box_head=self.cv2, cls_head=self.cv3, mask_head=self.cv4)
+
+ @property
+ def one2one(self):
+ """Returns the one-to-one head components."""
+ return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3, mask_head=self.one2one_cv4)

- def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor]:
+ def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor] | dict[str, torch.Tensor]:
  """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
- p = self.proto(x[0]) # mask protos
- bs = p.shape[0] # batch size
+ outputs = super().forward(x)
+ preds = outputs[1] if isinstance(outputs, tuple) else outputs
+ proto = self.proto(x[0]) # mask protos
+ if isinstance(preds, dict): # training and validating during training
+ if self.end2end:
+ preds["one2many"]["proto"] = proto
+ preds["one2one"]["proto"] = proto.detach()
+ else:
+ preds["proto"] = proto
+ if self.training:
+ return preds
+ return (outputs, proto) if self.export else ((outputs[0], proto), preds)
+
+ def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+ """Decode predicted bounding boxes and class probabilities, concatenated with mask coefficients."""
+ preds = super()._inference(x)
+ return torch.cat([preds, x["mask_coefficient"]], dim=1)
+
+ def forward_head(
+ self, x: list[torch.Tensor], box_head: torch.nn.Module, cls_head: torch.nn.Module, mask_head: torch.nn.Module
+ ) -> torch.Tensor:
+ """Concatenates and returns predicted bounding boxes, class probabilities, and mask coefficients."""
+ preds = super().forward_head(x, box_head, cls_head)
+ if mask_head is not None:
+ bs = x[0].shape[0] # batch size
+ preds["mask_coefficient"] = torch.cat([mask_head[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)
+ return preds
+
+ def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
+ """Post-process YOLO model predictions.
+
+ Args:
+ preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc + nm) with last dimension
+ format [x, y, w, h, class_probs, mask_coefficient].
+
+ Returns:
+ (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6 + nm) and last
+ dimension format [x, y, w, h, max_class_prob, class_index, mask_coefficient].
+ """
+ boxes, scores, mask_coefficient = preds.split([4, self.nc, self.nm], dim=-1)
+ scores, conf, idx = self.get_topk_index(scores, self.max_det)
+ boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
+ mask_coefficient = mask_coefficient.gather(dim=1, index=idx.repeat(1, 1, self.nm))
+ return torch.cat([boxes, scores, conf, mask_coefficient], dim=-1)
+
+ def fuse(self) -> None:
+ """Remove the one2many head for inference optimization."""
+ self.cv2 = self.cv3 = self.cv4 = None
+
+
+ class Segment26(Segment):
+ """YOLO26 Segment head for segmentation models.
+
+ This class extends the Detect head to include mask prediction capabilities for instance segmentation tasks.
+
+ Attributes:
+ nm (int): Number of masks.
+ npr (int): Number of protos.
+ proto (Proto): Prototype generation module.
+ cv4 (nn.ModuleList): Convolution layers for mask coefficients.
+
+ Methods:
+ forward: Return model outputs and mask coefficients.
+
+ Examples:
+ Create a segmentation head
+ >>> segment = Segment26(nc=80, nm=32, npr=256, ch=(256, 512, 1024))
+ >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
+ >>> outputs = segment(x)
+ """

- mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
- x = Detect.forward(self, x)
+ def __init__(self, nc: int = 80, nm: int = 32, npr: int = 256, reg_max=16, end2end=False, ch: tuple = ()):
+ """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.
+
+ Args:
+ nc (int): Number of classes.
+ nm (int): Number of masks.
+ npr (int): Number of protos.
+ reg_max (int): Maximum number of DFL channels.
+ end2end (bool): Whether to use end-to-end NMS-free detection.
+ ch (tuple): Tuple of channel sizes from backbone feature maps.
+ """
+ super().__init__(nc, nm, npr, reg_max, end2end, ch)
+ self.proto = Proto26(ch, self.npr, self.nm, nc) # protos
+
+ def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor] | dict[str, torch.Tensor]:
+ """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
+ outputs = Detect.forward(self, x)
+ preds = outputs[1] if isinstance(outputs, tuple) else outputs
+ proto = self.proto(x) # mask protos
+ if isinstance(preds, dict): # training and validating during training
+ if self.end2end:
+ preds["one2many"]["proto"] = proto
+ preds["one2one"]["proto"] = (
+ tuple(p.detach() for p in proto) if isinstance(proto, tuple) else proto.detach()
+ )
+ else:
+ preds["proto"] = proto
  if self.training:
- return x, mc, p
- return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
+ return preds
+ return (outputs, proto) if self.export else ((outputs[0], proto), preds)
+
+ def fuse(self) -> None:
+ """Remove the one2many head and extra part of proto module for inference optimization."""
+ super().fuse()
+ if hasattr(self.proto, "fuse"):
+ self.proto.fuse()


  class OBB(Detect):
@@ -283,38 +425,114 @@ class OBB(Detect):
  >>> outputs = obb(x)
  """

- def __init__(self, nc: int = 80, ne: int = 1, ch: tuple = ()):
+ def __init__(self, nc: int = 80, ne: int = 1, reg_max=16, end2end=False, ch: tuple = ()):
  """Initialize OBB with number of classes `nc` and layer channels `ch`.

  Args:
  nc (int): Number of classes.
  ne (int): Number of extra parameters.
+ reg_max (int): Maximum number of DFL channels.
+ end2end (bool): Whether to use end-to-end NMS-free detection.
  ch (tuple): Tuple of channel sizes from backbone feature maps.
  """
- super().__init__(nc, ch)
+ super().__init__(nc, reg_max, end2end, ch)
  self.ne = ne # number of extra parameters

  c4 = max(ch[0] // 4, self.ne)
  self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)
-
- def forward(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
- """Concatenate and return predicted bounding boxes and class probabilities."""
- bs = x[0].shape[0] # batch size
- angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2) # OBB theta logits
- # NOTE: set `angle` as an attribute so that `decode_bboxes` could use it.
- angle = (angle.sigmoid() - 0.25) * math.pi # [-pi/4, 3pi/4]
- # angle = angle.sigmoid() * math.pi / 2 # [0, pi/2]
- if not self.training:
- self.angle = angle
- x = Detect.forward(self, x)
- if self.training:
- return x, angle
- return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))
+ if end2end:
+ self.one2one_cv4 = copy.deepcopy(self.cv4)
+
+ @property
+ def one2many(self):
+ """Returns the one-to-many head components, here for backward compatibility."""
+ return dict(box_head=self.cv2, cls_head=self.cv3, angle_head=self.cv4)
+
+ @property
+ def one2one(self):
+ """Returns the one-to-one head components."""
+ return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3, angle_head=self.one2one_cv4)
+
+ def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+ """Decode predicted bounding boxes and class probabilities, concatenated with rotation angles."""
+ # For decode_bboxes convenience
+ self.angle = x["angle"] # TODO: need to test obb
+ preds = super()._inference(x)
+ return torch.cat([preds, x["angle"]], dim=1)
+
+ def forward_head(
+ self, x: list[torch.Tensor], box_head: torch.nn.Module, cls_head: torch.nn.Module, angle_head: torch.nn.Module
+ ) -> torch.Tensor:
+ """Concatenates and returns predicted bounding boxes, class probabilities, and angles."""
+ preds = super().forward_head(x, box_head, cls_head)
+ if angle_head is not None:
+ bs = x[0].shape[0] # batch size
+ angle = torch.cat(
+ [angle_head[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2
+ ) # OBB theta logits
+ angle = (angle.sigmoid() - 0.25) * math.pi # [-pi/4, 3pi/4]
+ preds["angle"] = angle
+ return preds

  def decode_bboxes(self, bboxes: torch.Tensor, anchors: torch.Tensor) -> torch.Tensor:
  """Decode rotated bounding boxes."""
  return dist2rbox(bboxes, self.angle, anchors, dim=1)

+ def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
+ """Post-process YOLO model predictions.
+
+ Args:
+ preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc + ne) with last dimension
+ format [x, y, w, h, class_probs, angle].
+
+ Returns:
+ (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 7) and last
+ dimension format [x, y, w, h, max_class_prob, class_index, angle].
+ """
+ boxes, scores, angle = preds.split([4, self.nc, self.ne], dim=-1)
+ scores, conf, idx = self.get_topk_index(scores, self.max_det)
+ boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
+ angle = angle.gather(dim=1, index=idx.repeat(1, 1, self.ne))
+ return torch.cat([boxes, scores, conf, angle], dim=-1)
+
+ def fuse(self) -> None:
+ """Remove the one2many head for inference optimization."""
+ self.cv2 = self.cv3 = self.cv4 = None
+
+
+ class OBB26(OBB):
+ """YOLO26 OBB detection head for rotated bounding box models. This class extends the OBB head with modified angle
+ processing that outputs raw angle predictions without sigmoid transformation, compared to the original
+ OBB class.
+
+ Attributes:
+ ne (int): Number of extra parameters.
+ cv4 (nn.ModuleList): Convolution layers for angle prediction.
+ angle (torch.Tensor): Predicted rotation angles.
+
+ Methods:
+ forward_head: Concatenate and return predicted bounding boxes, class probabilities, and raw angles.
+
+ Examples:
+ Create an OBB26 detection head
+ >>> obb26 = OBB26(nc=80, ne=1, ch=(256, 512, 1024))
+ >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
+ >>> outputs = obb26(x)
+ """
+
+ def forward_head(
+ self, x: list[torch.Tensor], box_head: torch.nn.Module, cls_head: torch.nn.Module, angle_head: torch.nn.Module
+ ) -> torch.Tensor:
+ """Concatenates and returns predicted bounding boxes, class probabilities, and raw angles."""
+ preds = Detect.forward_head(self, x, box_head, cls_head)
+ if angle_head is not None:
+ bs = x[0].shape[0] # batch size
+ angle = torch.cat(
+ [angle_head[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2
+ ) # OBB theta logits (raw output without sigmoid transformation)
+ preds["angle"] = angle
+ return preds
+

  class Pose(Detect):
  """YOLO Pose head for keypoints models.
@@ -337,36 +555,76 @@ class Pose(Detect):
  >>> outputs = pose(x)
  """

- def __init__(self, nc: int = 80, kpt_shape: tuple = (17, 3), ch: tuple = ()):
+ def __init__(self, nc: int = 80, kpt_shape: tuple = (17, 3), reg_max=16, end2end=False, ch: tuple = ()):
  """Initialize YOLO network with default parameters and Convolutional Layers.

  Args:
  nc (int): Number of classes.
  kpt_shape (tuple): Number of keypoints, number of dims (2 for x,y or 3 for x,y,visible).
+ reg_max (int): Maximum number of DFL channels.
+ end2end (bool): Whether to use end-to-end NMS-free detection.
  ch (tuple): Tuple of channel sizes from backbone feature maps.
  """
- super().__init__(nc, ch)
+ super().__init__(nc, reg_max, end2end, ch)
  self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
  self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total

  c4 = max(ch[0] // 4, self.nk)
  self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)
+ if end2end:
+ self.one2one_cv4 = copy.deepcopy(self.cv4)
+
+ @property
+ def one2many(self):
+ """Returns the one-to-many head components, here for backward compatibility."""
+ return dict(box_head=self.cv2, cls_head=self.cv3, pose_head=self.cv4)
+
+ @property
+ def one2one(self):
+ """Returns the one-to-one head components."""
+ return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3, pose_head=self.one2one_cv4)
+
+ def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+ """Decode predicted bounding boxes and class probabilities, concatenated with keypoints."""
+ preds = super()._inference(x)
+ return torch.cat([preds, self.kpts_decode(x["kpts"])], dim=1)
+
+ def forward_head(
+ self, x: list[torch.Tensor], box_head: torch.nn.Module, cls_head: torch.nn.Module, pose_head: torch.nn.Module
+ ) -> torch.Tensor:
+ """Concatenates and returns predicted bounding boxes, class probabilities, and keypoints."""
+ preds = super().forward_head(x, box_head, cls_head)
+ if pose_head is not None:
+ bs = x[0].shape[0] # batch size
+ preds["kpts"] = torch.cat([pose_head[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], 2)
+ return preds
+
+ def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
+ """Post-process YOLO model predictions.

- def forward(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
- """Perform forward pass through YOLO model and return predictions."""
- bs = x[0].shape[0] # batch size
- kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
- x = Detect.forward(self, x)
- if self.training:
- return x, kpt
- pred_kpt = self.kpts_decode(bs, kpt)
- return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))
+ Args:
+ preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc + nk) with last dimension
+ format [x, y, w, h, class_probs, keypoints].
+
+ Returns:
+ (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6 + self.nk) and
+ last dimension format [x, y, w, h, max_class_prob, class_index, keypoints].
+ """
+ boxes, scores, kpts = preds.split([4, self.nc, self.nk], dim=-1)
+ scores, conf, idx = self.get_topk_index(scores, self.max_det)
+ boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
+ kpts = kpts.gather(dim=1, index=idx.repeat(1, 1, self.nk))
+ return torch.cat([boxes, scores, conf, kpts], dim=-1)

- def kpts_decode(self, bs: int, kpts: torch.Tensor) -> torch.Tensor:
+ def fuse(self) -> None:
+ """Remove the one2many head for inference optimization."""
+ self.cv2 = self.cv3 = self.cv4 = None
+
+ def kpts_decode(self, kpts: torch.Tensor) -> torch.Tensor:
  """Decode keypoints from predictions."""
  ndim = self.kpt_shape[1]
+ bs = kpts.shape[0]
  if self.export:
- # NCNN fix
  y = kpts.view(bs, *self.kpt_shape, -1)
  a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
  if ndim == 3:
@@ -384,6 +642,123 @@ class Pose(Detect):
  return y


+ class Pose26(Pose):
+ """YOLO26 Pose head for keypoints models.
+
+ This class extends the Detect head to include keypoint prediction capabilities for pose estimation tasks.
+
+ Attributes:
+ kpt_shape (tuple): Number of keypoints and dimensions (2 for x,y or 3 for x,y,visible).
+ nk (int): Total number of keypoint values.
+ cv4 (nn.ModuleList): Convolution layers for keypoint prediction.
+
+ Methods:
+ forward: Perform forward pass through YOLO model and return predictions.
+ kpts_decode: Decode keypoints from predictions.
+
+ Examples:
+ Create a pose detection head
+ >>> pose = Pose26(nc=80, kpt_shape=(17, 3), ch=(256, 512, 1024))
+ >>> x = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
+ >>> outputs = pose(x)
+ """
+
+ def __init__(self, nc: int = 80, kpt_shape: tuple = (17, 3), reg_max=16, end2end=False, ch: tuple = ()):
+ """Initialize YOLO network with default parameters and Convolutional Layers.
+
+ Args:
+ nc (int): Number of classes.
+ kpt_shape (tuple): Number of keypoints, number of dims (2 for x,y or 3 for x,y,visible).
+ reg_max (int): Maximum number of DFL channels.
+ end2end (bool): Whether to use end-to-end NMS-free detection.
+ ch (tuple): Tuple of channel sizes from backbone feature maps.
+ """
+ super().__init__(nc, kpt_shape, reg_max, end2end, ch)
+ self.flow_model = RealNVP()
+
+ c4 = max(ch[0] // 4, kpt_shape[0] * (kpt_shape[1] + 2))
+ self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3)) for x in ch)
+
+ self.cv4_kpts = nn.ModuleList(nn.Conv2d(c4, self.nk, 1) for _ in ch)
+ self.nk_sigma = kpt_shape[0] * 2 # sigma_x, sigma_y for each keypoint
+ self.cv4_sigma = nn.ModuleList(nn.Conv2d(c4, self.nk_sigma, 1) for _ in ch)
+
+ if end2end:
+ self.one2one_cv4 = copy.deepcopy(self.cv4)
+ self.one2one_cv4_kpts = copy.deepcopy(self.cv4_kpts)
+ self.one2one_cv4_sigma = copy.deepcopy(self.cv4_sigma)
+
+ @property
+ def one2many(self):
+ """Returns the one-to-many head components, here for backward compatibility."""
+ return dict(
+ box_head=self.cv2,
+ cls_head=self.cv3,
+ pose_head=self.cv4,
+ kpts_head=self.cv4_kpts,
+ kpts_sigma_head=self.cv4_sigma,
+ )
+
+ @property
+ def one2one(self):
+ """Returns the one-to-one head components."""
+ return dict(
+ box_head=self.one2one_cv2,
+ cls_head=self.one2one_cv3,
+ pose_head=self.one2one_cv4,
+ kpts_head=self.one2one_cv4_kpts,
+ kpts_sigma_head=self.one2one_cv4_sigma,
+ )
+
+ def forward_head(
+ self,
+ x: list[torch.Tensor],
+ box_head: torch.nn.Module,
+ cls_head: torch.nn.Module,
+ pose_head: torch.nn.Module,
+ kpts_head: torch.nn.Module,
+ kpts_sigma_head: torch.nn.Module,
+ ) -> torch.Tensor:
+ """Concatenates and returns predicted bounding boxes, class probabilities, and keypoints."""
+ preds = Detect.forward_head(self, x, box_head, cls_head)
+ if pose_head is not None:
+ bs = x[0].shape[0] # batch size
+ features = [pose_head[i](x[i]) for i in range(self.nl)]
+ preds["kpts"] = torch.cat([kpts_head[i](features[i]).view(bs, self.nk, -1) for i in range(self.nl)], 2)
+ if self.training:
+ preds["kpts_sigma"] = torch.cat(
+ [kpts_sigma_head[i](features[i]).view(bs, self.nk_sigma, -1) for i in range(self.nl)], 2
+ )
+ return preds
+
+ def fuse(self) -> None:
+ """Remove the one2many head for inference optimization."""
+ super().fuse()
+ self.cv4_kpts = self.cv4_sigma = self.flow_model = self.one2one_cv4_sigma = None
+
+ def kpts_decode(self, kpts: torch.Tensor) -> torch.Tensor:
+ """Decode keypoints from predictions."""
+ ndim = self.kpt_shape[1]
+ bs = kpts.shape[0]
+ if self.export:
+ y = kpts.view(bs, *self.kpt_shape, -1)
+ # NCNN fix
+ a = (y[:, :, :2] + self.anchors) * self.strides
+ if ndim == 3:
+ a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
+ return a.view(bs, self.nk, -1)
+ else:
+ y = kpts.clone()
+ if ndim == 3:
+ if NOT_MACOS14:
+ y[:, 2::ndim].sigmoid_()
+ else: # Apple macOS14 MPS bug https://github.com/ultralytics/ultralytics/pull/21878
+ y[:, 2::ndim] = y[:, 2::ndim].sigmoid()
+ y[:, 0::ndim] = (y[:, 0::ndim] + self.anchors[0]) * self.strides
+ y[:, 1::ndim] = (y[:, 1::ndim] + self.anchors[1]) * self.strides
+ return y
+
+
  class Classify(nn.Module):
  """YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2).

@@ -459,29 +834,44 @@ class WorldDetect(Detect):
  >>> outputs = world_detect(x, text)
  """

- def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: tuple = ()):
+ def __init__(
+ self,
+ nc: int = 80,
+ embed: int = 512,
+ with_bn: bool = False,
+ reg_max: int = 16,
+ end2end: bool = False,
+ ch: tuple = (),
+ ):
  """Initialize YOLO detection layer with nc classes and layer channels ch.

  Args:
  nc (int): Number of classes.
  embed (int): Embedding dimension.
  with_bn (bool): Whether to use batch normalization in contrastive head.
+ reg_max (int): Maximum number of DFL channels.
+ end2end (bool): Whether to use end-to-end NMS-free detection.
  ch (tuple): Tuple of channel sizes from backbone feature maps.
  """
- super().__init__(nc, ch)
+ super().__init__(nc, reg_max=reg_max, end2end=end2end, ch=ch)
  c3 = max(ch[0], min(self.nc, 100))
  self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)
  self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)

- def forward(self, x: list[torch.Tensor], text: torch.Tensor) -> list[torch.Tensor] | tuple:
+ def forward(self, x: list[torch.Tensor], text: torch.Tensor) -> dict[str, torch.Tensor] | tuple:
  """Concatenate and return predicted bounding boxes and class probabilities."""
+ feats = [xi.clone() for xi in x] # save original features for anchor generation
  for i in range(self.nl):
  x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), text)), 1)
- if self.training:
- return x
  self.no = self.nc + self.reg_max * 4 # self.nc could be changed when inference with different texts
- y = self._inference(x)
- return y if self.export else (y, x)
+ bs = x[0].shape[0]
+ x_cat = torch.cat([xi.view(bs, self.no, -1) for xi in x], 2)
+ boxes, scores = x_cat.split((self.reg_max * 4, self.nc), 1)
+ preds = dict(boxes=boxes, scores=scores, feats=feats)
+ if self.training:
+ return preds
+ y = self._inference(preds)
+ return y if self.export else (y, preds)

  def bias_init(self):
  """Initialize Detect() biases, WARNING: requires stride availability."""
@@ -548,12 +938,14 @@ class LRPCHead(nn.Module):
  mask = pf_score.sigmoid() > conf
  cls_feat = cls_feat.flatten(2).transpose(-1, -2)
  cls_feat = self.vocab(cls_feat[:, mask] if conf else cls_feat * mask.unsqueeze(-1).int())
- return (self.loc(loc_feat), cls_feat.transpose(-1, -2)), mask
+ return self.loc(loc_feat), cls_feat.transpose(-1, -2), mask
  else:
  cls_feat = self.vocab(cls_feat)
  loc_feat = self.loc(loc_feat)
- return (loc_feat, cls_feat.flatten(2)), torch.ones(
- cls_feat.shape[2] * cls_feat.shape[3], device=cls_feat.device, dtype=torch.bool
+ return (
+ loc_feat,
+ cls_feat.flatten(2),
+ torch.ones(cls_feat.shape[2] * cls_feat.shape[3], device=cls_feat.device, dtype=torch.bool),
  )


@@ -589,16 +981,20 @@ class YOLOEDetect(Detect):

  is_fused = False

- def __init__(self, nc: int = 80, embed: int = 512, with_bn: bool = False, ch: tuple = ()):
+ def __init__(
+ self, nc: int = 80, embed: int = 512, with_bn: bool = False, reg_max=16, end2end=False, ch: tuple = ()
+ ):
  """Initialize YOLO detection layer with nc classes and layer channels ch.

  Args:
  nc (int): Number of classes.
  embed (int): Embedding dimension.
  with_bn (bool): Whether to use batch normalization in contrastive head.
+ reg_max (int): Maximum number of DFL channels.
+ end2end (bool): Whether to use end-to-end NMS-free detection.
  ch (tuple): Tuple of channel sizes from backbone feature maps.
  """
- super().__init__(nc, ch)
+ super().__init__(nc, reg_max, end2end, ch)
  c3 = max(ch[0], min(self.nc, 100))
  assert c3 <= embed
  assert with_bn
@@ -614,29 +1010,43 @@ class YOLOEDetect(Detect):
  for x in ch
  )
  )
-
  self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)
+ if end2end:
+ self.one2one_cv3 = copy.deepcopy(self.cv3) # overwrite with new cv3
+ self.one2one_cv4 = copy.deepcopy(self.cv4)

  self.reprta = Residual(SwiGLUFFN(embed, embed))
  self.savpe = SAVPE(ch, c3, embed)
  self.embed = embed

  @smart_inference_mode()
- def fuse(self, txt_feats: torch.Tensor):
+ def fuse(self, txt_feats: torch.Tensor = None):
  """Fuse text features with model weights for efficient inference."""
+ if txt_feats is None: # means eliminate one2many branch
+ self.cv2 = self.cv3 = self.cv4 = None
+ return
  if self.is_fused:
  return

  assert not self.training
  txt_feats = txt_feats.to(torch.float32).squeeze(0)
- for cls_head, bn_head in zip(self.cv3, self.cv4):
- assert isinstance(cls_head, nn.Sequential)
- assert isinstance(bn_head, BNContrastiveHead)
- conv = cls_head[-1]
+ self._fuse_tp(txt_feats, self.cv3, self.cv4)
+ if self.end2end:
+ self._fuse_tp(txt_feats, self.one2one_cv3, self.one2one_cv4)
+ del self.reprta
+ self.reprta = nn.Identity()
+ self.is_fused = True
+
+ def _fuse_tp(self, txt_feats: torch.Tensor, cls_head: torch.nn.Module, bn_head: torch.nn.Module) -> None:
+ """Fuse text prompt embeddings with model weights for efficient inference."""
+ for cls_h, bn_h in zip(cls_head, bn_head):
+ assert isinstance(cls_h, nn.Sequential)
+ assert isinstance(bn_h, BNContrastiveHead)
+ conv = cls_h[-1]
  assert isinstance(conv, nn.Conv2d)
- logit_scale = bn_head.logit_scale
- bias = bn_head.bias
- norm = bn_head.norm
+ logit_scale = bn_h.logit_scale
+ bias = bn_h.bias
+ norm = bn_h.norm

  t = txt_feats * logit_scale.exp()
  conv: nn.Conv2d = fuse_conv_and_bn(conv, norm)
@@ -660,13 +1070,9 @@ class YOLOEDetect(Detect):

  conv.weight.data.copy_(w.unsqueeze(-1).unsqueeze(-1))
  conv.bias.data.copy_(b1 + b2)
- cls_head[-1] = conv
-
- bn_head.fuse()
+ cls_h[-1] = conv

- del self.reprta
- self.reprta = nn.Identity()
- self.is_fused = True
+ bn_h.fuse()

  def get_tpe(self, tpe: torch.Tensor | None) -> torch.Tensor | None:
  """Get text prompt embeddings with normalization."""
@@ -681,66 +1087,82 @@ class YOLOEDetect(Detect):
  assert vpe.ndim == 3 # (B, N, D)
  return vpe

- def forward_lrpc(self, x: list[torch.Tensor], return_mask: bool = False) -> torch.Tensor | tuple:
+ def forward(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
+ """Process features with class prompt embeddings to generate detections."""
+ if hasattr(self, "lrpc"): # for prompt-free inference
+ return self.forward_lrpc(x[:3])
+ return super().forward(x)
+
+ def forward_lrpc(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
  """Process features with fused text embeddings to generate detections for prompt-free model."""
- masks = []
- assert self.is_fused, "Prompt-free inference requires model to be fused!"
+ boxes, scores, index = [], [], []
+ bs = x[0].shape[0]
+ cv2 = self.cv2 if not self.end2end else self.one2one_cv2
+ cv3 = self.cv3 if not self.end2end else self.one2one_cv3
  for i in range(self.nl):
- cls_feat = self.cv3[i](x[i])
- loc_feat = self.cv2[i](x[i])
+ cls_feat = cv3[i](x[i])
+ loc_feat = cv2[i](x[i])
  assert isinstance(self.lrpc[i], LRPCHead)
- x[i], mask = self.lrpc[i](
- cls_feat, loc_feat, 0 if self.export and not self.dynamic else getattr(self, "conf", 0.001)
+ box, score, idx = self.lrpc[i](
+ cls_feat,
+ loc_feat,
+ 0 if self.export and not self.dynamic else getattr(self, "conf", 0.001),
  )
- masks.append(mask)
- shape = x[0][0].shape
- if self.dynamic or self.shape != shape:
- self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors([b[0] for b in x], self.stride, 0.5))
- self.shape = shape
- box = torch.cat([xi[0].view(shape[0], self.reg_max * 4, -1) for xi in x], 2)
- cls = torch.cat([xi[1] for xi in x], 2)
-
- if self.export and self.format in {"tflite", "edgetpu"}:
- # Precompute normalization factor to increase numerical stability
- # See https://github.com/ultralytics/ultralytics/issues/7371
- grid_h = shape[2]
- grid_w = shape[3]
- grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
- norm = self.strides / (self.stride[0] * grid_size)
- dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
- else:
- dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
-
- mask = torch.cat(masks)
- y = torch.cat((dbox if self.export and not self.dynamic else dbox[..., mask], cls.sigmoid()), 1)
-
- if return_mask:
- return (y, mask) if self.export else ((y, x), mask)
- else:
- return y if self.export else (y, x)
-
- def forward(self, x: list[torch.Tensor], cls_pe: torch.Tensor, return_mask: bool = False) -> torch.Tensor | tuple:
- """Process features with class prompt embeddings to generate detections."""
- if hasattr(self, "lrpc"): # for prompt-free inference
- return self.forward_lrpc(x, return_mask)
- for i in range(self.nl):
- x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), cls_pe)), 1)
- if self.training:
- return x
+ boxes.append(box.view(bs, self.reg_max * 4, -1))
+ scores.append(score)
+ index.append(idx)
+ preds = dict(boxes=torch.cat(boxes, 2), scores=torch.cat(scores, 2), feats=x, index=torch.cat(index))
+ y = self._inference(preds)
+ if self.end2end:
+ y = self.postprocess(y.permute(0, 2, 1))
+ return y if self.export else (y, preds)
+
+ def _get_decode_boxes(self, x):
+ """Decode predicted bounding boxes for inference."""
+ dbox = super()._get_decode_boxes(x)
+ if hasattr(self, "lrpc"):
+ dbox = dbox if self.export and not self.dynamic else dbox[..., x["index"]]
+ return dbox
+
+ @property
+ def one2many(self):
+ """Returns the one-to-many head components, here for v3/v5/v8/v9/11 backward compatibility."""
+ return dict(box_head=self.cv2, cls_head=self.cv3, contrastive_head=self.cv4)
+
+ @property
+ def one2one(self):
+ """Returns the one-to-one head components."""
+ return dict(box_head=self.one2one_cv2, cls_head=self.one2one_cv3, contrastive_head=self.one2one_cv4)
+
+ def forward_head(self, x, box_head, cls_head, contrastive_head):
+ """Concatenates and returns predicted bounding boxes, class probabilities, and text embeddings."""
+ assert len(x) == 4, f"Expected 4 features including 3 feature maps and 1 text embeddings, but got {len(x)}."
+ if box_head is None or cls_head is None: # for fused inference
+ return dict()
+ bs = x[0].shape[0] # batch size
+ boxes = torch.cat([box_head[i](x[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)], dim=-1)
+ self.nc = x[-1].shape[1]
+ scores = torch.cat(
+ [contrastive_head[i](cls_head[i](x[i]), x[-1]).reshape(bs, self.nc, -1) for i in range(self.nl)], dim=-1
+ )
  self.no = self.nc + self.reg_max * 4 # self.nc could be changed when inference with different texts
- y = self._inference(x)
- return y if self.export else (y, x)
+ return dict(boxes=boxes, scores=scores, feats=x[:3])

  def bias_init(self):
- """Initialize biases for detection heads."""
- m = self # self.model[-1] # Detect() module
- # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
- # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
- for a, b, c, s in zip(m.cv2, m.cv3, m.cv4, m.stride): # from
- a[-1].bias.data[:] = 1.0 # box
- # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
+ """Initialize Detect() biases, WARNING: requires stride availability."""
+ for i, (a, b, c) in enumerate(
+ zip(self.one2many["box_head"], self.one2many["cls_head"], self.one2many["contrastive_head"])
+ ):
+ a[-1].bias.data[:] = 2.0 # box
  b[-1].bias.data[:] = 0.0
- c.bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)
+ c.bias.data[:] = math.log(5 / self.nc / (640 / self.stride[i]) ** 2)
+ if self.end2end:
+ for i, (a, b, c) in enumerate(
+ zip(self.one2one["box_head"], self.one2one["cls_head"], self.one2one["contrastive_head"])
+ ):
+ a[-1].bias.data[:] = 2.0 # box
+ b[-1].bias.data[:] = 0.0
+ c.bias.data[:] = math.log(5 / self.nc / (640 / self.stride[i]) ** 2)


  class YOLOESegment(YOLOEDetect):
@@ -767,7 +1189,15 @@ class YOLOESegment(YOLOEDetect):
  """

  def __init__(
- self, nc: int = 80, nm: int = 32, npr: int = 256, embed: int = 512, with_bn: bool = False, ch: tuple = ()
+ self,
+ nc: int = 80,
+ nm: int = 32,
+ npr: int = 256,
+ embed: int = 512,
+ with_bn: bool = False,
+ reg_max=16,
+ end2end=False,
+ ch: tuple = (),
  ):
  """Initialize YOLOESegment with class count, mask parameters, and embedding dimensions.

@@ -777,36 +1207,191 @@ class YOLOESegment(YOLOEDetect):
  npr (int): Number of protos.
  embed (int): Embedding dimension.
  with_bn (bool): Whether to use batch normalization in contrastive head.
+ reg_max (int): Maximum number of DFL channels.
+ end2end (bool): Whether to use end-to-end NMS-free detection.
  ch (tuple): Tuple of channel sizes from backbone feature maps.
  """
- super().__init__(nc, embed, with_bn, ch)
+ super().__init__(nc, embed, with_bn, reg_max, end2end, ch)
  self.nm = nm
  self.npr = npr
  self.proto = Proto(ch[0], self.npr, self.nm)

  c5 = max(ch[0] // 4, self.nm)
  self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch)
+ if end2end:
+ self.one2one_cv5 = copy.deepcopy(self.cv5)
+
+ @property
+ def one2many(self):
+ """Returns the one-to-many head components, here for v3/v5/v8/v9/11 backward compatibility."""
+ return dict(box_head=self.cv2, cls_head=self.cv3, mask_head=self.cv5, contrastive_head=self.cv4)
+
+ @property
+ def one2one(self):
+ """Returns the one-to-one head components."""
+ return dict(
+ box_head=self.one2one_cv2,
+ cls_head=self.one2one_cv3,
+ mask_head=self.one2one_cv5,
+ contrastive_head=self.one2one_cv4,
+ )
+
+ def forward_lrpc(self, x: list[torch.Tensor]) -> torch.Tensor | tuple:
+ """Process features with fused text embeddings to generate detections for prompt-free model."""
+ boxes, scores, index = [], [], []
+ bs = x[0].shape[0]
+ cv2 = self.cv2 if not self.end2end else self.one2one_cv2
+ cv3 = self.cv3 if not self.end2end else self.one2one_cv3
+ cv5 = self.cv5 if not self.end2end else self.one2one_cv5
+ for i in range(self.nl):
+ cls_feat = cv3[i](x[i])
+ loc_feat = cv2[i](x[i])
+ assert isinstance(self.lrpc[i], LRPCHead)
+ box, score, idx = self.lrpc[i](
+ cls_feat,
+ loc_feat,
+ 0 if self.export and not self.dynamic else getattr(self, "conf", 0.001),
+ )
+ boxes.append(box.view(bs, self.reg_max * 4, -1))
+ scores.append(score)
+ index.append(idx)
+ mc = torch.cat([cv5[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)
+ index = torch.cat(index)
+ preds = dict(
+ boxes=torch.cat(boxes, 2),
+ scores=torch.cat(scores, 2),
+ feats=x,
+ index=index,
+ mask_coefficient=mc * index.int() if self.export and not self.dynamic else mc[..., index],
+ )
+ y = self._inference(preds)
+ if self.end2end:
+ y = self.postprocess(y.permute(0, 2, 1))
+ return y if self.export else (y, preds)

- def forward(self, x: list[torch.Tensor], text: torch.Tensor) -> tuple | torch.Tensor:
+ def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor] | dict[str, torch.Tensor]:
  """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
- p = self.proto(x[0]) # mask protos
- bs = p.shape[0] # batch size
+ outputs = super().forward(x)
+ preds = outputs[1] if isinstance(outputs, tuple) else outputs
+ proto = self.proto(x[0]) # mask protos
+ if isinstance(preds, dict): # training and validating during training
+ if self.end2end:
+ preds["one2many"]["proto"] = proto
+ preds["one2one"]["proto"] = proto.detach()
+ else:
+ preds["proto"] = proto
+ if self.training:
+ return preds
+ return (outputs, proto) if self.export else ((outputs[0], proto), preds)

- mc = torch.cat([self.cv5[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
- has_lrpc = hasattr(self, "lrpc")
+ def _inference(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+ """Decode predicted bounding boxes and class probabilities, concatenated with mask coefficients."""
+ preds = super()._inference(x)
+ return torch.cat([preds, x["mask_coefficient"]], dim=1)

- if not has_lrpc:
- x = YOLOEDetect.forward(self, x, text)
- else:
- x, mask = YOLOEDetect.forward(self, x, text, return_mask=True)
+ def forward_head(
+ self,
+ x: list[torch.Tensor],
+ box_head: torch.nn.Module,
+ cls_head: torch.nn.Module,
+ mask_head: torch.nn.Module,
+ contrastive_head: torch.nn.Module,
+ ) -> torch.Tensor:
+ """Concatenates and returns predicted bounding boxes, class probabilities, and mask coefficients."""
+ preds = super().forward_head(x, box_head, cls_head, contrastive_head)
+ if mask_head is not None:
+ bs = x[0].shape[0] # batch size
+ preds["mask_coefficient"] = torch.cat([mask_head[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)
+ return preds
+
+ def postprocess(self, preds: torch.Tensor) -> torch.Tensor:
+ """Post-process YOLO model predictions.

- if self.training:
- return x, mc, p
+ Args:
+ preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc + nm) with last dimension
+ format [x, y, w, h, class_probs, mask_coefficient].
+
+ Returns:
+ (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6 + nm) and last
+ dimension format [x, y, w, h, max_class_prob, class_index, mask_coefficient].
+ """
+ boxes, scores, mask_coefficient = preds.split([4, self.nc, self.nm], dim=-1)
+ scores, conf, idx = self.get_topk_index(scores, self.max_det)
+ boxes = boxes.gather(dim=1, index=idx.repeat(1, 1, 4))
+ mask_coefficient = mask_coefficient.gather(dim=1, index=idx.repeat(1, 1, self.nm))
+ return torch.cat([boxes, scores, conf, mask_coefficient], dim=-1)
+
+ def fuse(self, txt_feats: torch.Tensor = None):
+ """Fuse text features with model weights for efficient inference."""
+ super().fuse(txt_feats)
+ if txt_feats is None: # means eliminate one2many branch
+ self.cv5 = None
+ if hasattr(self.proto, "fuse"):
+ self.proto.fuse()
+ return
+
+
+ class YOLOESegment26(YOLOESegment):
+ """YOLOE-style segmentation head module using Proto26 for mask generation.

- if has_lrpc:
- mc = (mc * mask.int()) if self.export and not self.dynamic else mc[..., mask]
+ This class extends the YOLOEDetect functionality to include segmentation capabilities by integrating a prototype
+ generation module and convolutional layers to predict mask coefficients.

- return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
+ Args:
+ nc (int): Number of classes. Defaults to 80.
+ nm (int): Number of masks. Defaults to 32.
+ npr (int): Number of prototype channels. Defaults to 256.
+ embed (int): Embedding dimensionality. Defaults to 512.
+ with_bn (bool): Whether to use Batch Normalization. Defaults to False.
+ reg_max (int): Maximum regression value for bounding boxes. Defaults to 16.
+ end2end (bool): Whether to use end-to-end detection mode. Defaults to False.
+ ch (tuple[int, ...]): Input channels for each scale.
+
+ Attributes:
+ nm (int): Number of segmentation masks.
+ npr (int): Number of prototype channels.
+ proto (Proto26): Prototype generation module for segmentation.
+ cv5 (nn.ModuleList): Convolutional layers for generating mask coefficients from features.
+ one2one_cv5 (nn.ModuleList, optional): Deep copy of cv5 for end-to-end detection branches.
+ """
+
+ def __init__(
+ self,
+ nc: int = 80,
+ nm: int = 32,
+ npr: int = 256,
+ embed: int = 512,
+ with_bn: bool = False,
+ reg_max=16,
+ end2end=False,
+ ch: tuple = (),
+ ):
+ """Initialize YOLOESegment26 with class count, mask parameters, and embedding dimensions."""
+ YOLOEDetect.__init__(self, nc, embed, with_bn, reg_max, end2end, ch)
+ self.nm = nm
+ self.npr = npr
+ self.proto = Proto26(ch, self.npr, self.nm, nc) # protos
+
+ c5 = max(ch[0] // 4, self.nm)
+ self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch)
+ if end2end:
+ self.one2one_cv5 = copy.deepcopy(self.cv5)
+
+ def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor] | dict[str, torch.Tensor]:
+ """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
+ outputs = YOLOEDetect.forward(self, x)
+ preds = outputs[1] if isinstance(outputs, tuple) else outputs
+ proto = self.proto([xi.detach() for xi in x], return_semseg=False) # mask protos
+
+ if isinstance(preds, dict): # training and validating during training
+ if self.end2end and not hasattr(self, "lrpc"): # not prompt-free
+ preds["one2many"]["proto"] = proto
+ preds["one2one"]["proto"] = proto.detach()
+ else:
+ preds["proto"] = proto
+ if self.training:
+ return preds
+ return (outputs, proto) if self.export else ((outputs[0], proto), preds)


  class RTDETRDecoder(nn.Module):
@@ -1165,7 +1750,7 @@ class v10Detect(Detect):
  nc (int): Number of classes.
  ch (tuple): Tuple of channel sizes from backbone feature maps.
  """
- super().__init__(nc, ch)
+ super().__init__(nc, end2end=True, ch=ch)
  c3 = max(ch[0], min(self.nc, 100)) # channels
  # Light cls head
  self.cv3 = nn.ModuleList(
@@ -1180,4 +1765,4 @@ class v10Detect(Detect):

  def fuse(self):
  """Remove the one2many head for inference optimization."""
- self.cv2 = self.cv3 = nn.ModuleList([nn.Identity()] * self.nl)
+ self.cv2 = self.cv3 = None
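
The recurring pattern in this release — `reg_max` and `end2end` moving into the head constructors, dict-based `forward_head` outputs, and an instance-level `postprocess` — can be exercised directly. The following is a minimal sketch, assuming the import path `ultralytics.nn.modules.head` and the constructor signature shown in the diff above are unchanged in 8.4.3; the manual `stride` assignment stands in for what the model builder normally does.

# Minimal sketch of the new end-to-end Detect head API (assumed import path).
import torch
from ultralytics.nn.modules.head import Detect

ch = (256, 512, 1024)  # backbone feature-map channels for P3/P4/P5
head = Detect(nc=80, reg_max=16, end2end=True, ch=ch)  # end2end=True builds the one2one_* branches
head.stride = torch.tensor([8.0, 16.0, 32.0])  # normally computed during model build
head.eval()  # training mode would return the raw {"one2many", "one2one"} dict instead

x = [torch.randn(1, c, s, s) for c, s in zip(ch, (80, 40, 20))]
y, preds = head(x)  # end2end inference: NMS-free top-k detections plus the raw prediction dict
print(y.shape)  # torch.Size([1, 300, 6]) -> max_det rows of [x, y, w, h, conf, cls]
print(preds["one2one"]["boxes"].shape)  # torch.Size([1, 64, 8400]) -> 4 * reg_max DFL logits per anchor

Note the shapes follow the `postprocess` and `forward_head` docstrings above: the one-to-one branch is decoded, permuted to (batch, anchors, 4 + nc), and reduced to `max_det` candidates by `get_topk_index`, so no NMS step is required downstream.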