ultralytics 8.3.97__py3-none-any.whl → 8.3.99__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. tests/test_python.py +56 -0
  2. ultralytics/__init__.py +3 -2
  3. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  4. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  5. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  6. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  7. ultralytics/data/augment.py +101 -5
  8. ultralytics/data/dataset.py +165 -12
  9. ultralytics/engine/exporter.py +13 -13
  10. ultralytics/engine/trainer.py +16 -7
  11. ultralytics/models/__init__.py +2 -2
  12. ultralytics/models/nas/model.py +1 -0
  13. ultralytics/models/nas/predict.py +4 -24
  14. ultralytics/models/nas/val.py +1 -4
  15. ultralytics/models/yolo/__init__.py +3 -3
  16. ultralytics/models/yolo/detect/val.py +6 -1
  17. ultralytics/models/yolo/model.py +182 -3
  18. ultralytics/models/yolo/segment/val.py +43 -16
  19. ultralytics/models/yolo/yoloe/__init__.py +21 -0
  20. ultralytics/models/yolo/yoloe/predict.py +170 -0
  21. ultralytics/models/yolo/yoloe/train.py +355 -0
  22. ultralytics/models/yolo/yoloe/train_seg.py +141 -0
  23. ultralytics/models/yolo/yoloe/val.py +187 -0
  24. ultralytics/nn/autobackend.py +3 -2
  25. ultralytics/nn/modules/__init__.py +18 -1
  26. ultralytics/nn/modules/block.py +17 -1
  27. ultralytics/nn/modules/head.py +359 -22
  28. ultralytics/nn/tasks.py +276 -10
  29. ultralytics/nn/text_model.py +193 -0
  30. ultralytics/utils/callbacks/comet.py +3 -6
  31. ultralytics/utils/downloads.py +6 -2
  32. ultralytics/utils/instance.py +7 -2
  33. ultralytics/utils/loss.py +67 -6
  34. ultralytics/utils/plotting.py +1 -1
  35. ultralytics/utils/tal.py +1 -1
  36. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/METADATA +69 -67
  37. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/RECORD +41 -31
  38. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/WHEEL +0 -0
  39. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/entry_points.txt +0 -0
  40. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/licenses/LICENSE +0 -0
  41. {ultralytics-8.3.97.dist-info → ultralytics-8.3.99.dist-info}/top_level.txt +0 -0
ultralytics/nn/autobackend.py

@@ -546,7 +546,7 @@ class AutoBackend(nn.Module):
 
         self.__dict__.update(locals())  # assign all variables to self
 
-    def forward(self, im, augment=False, visualize=False, embed=None):
+    def forward(self, im, augment=False, visualize=False, embed=None, **kwargs):
         """
         Runs inference on the YOLOv8 MultiBackend model.
 
@@ -555,6 +555,7 @@ class AutoBackend(nn.Module):
             augment (bool): Whether to perform data augmentation during inference. Defaults to False.
             visualize (bool): Whether to visualize the output predictions. Defaults to False.
            embed (list, optional): A list of feature vectors/embeddings to return.
+            **kwargs (Any): Additional keyword arguments for model configuration.
 
         Returns:
             (torch.Tensor | List[torch.Tensor]): The raw output tensor(s) from the model.
@@ -567,7 +568,7 @@ class AutoBackend(nn.Module):
 
         # PyTorch
         if self.pt or self.nn_module:
-            y = self.model(im, augment=augment, visualize=visualize, embed=embed)
+            y = self.model(im, augment=augment, visualize=visualize, embed=embed, **kwargs)
 
         # TorchScript
         elif self.jit:
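
With this change, any extra keyword arguments passed to AutoBackend.forward() are handed straight through to the wrapped PyTorch module (non-PyTorch backends ignore them). A minimal sketch of the pass-through idea; TinyModel and its txt_feats argument are hypothetical stand-ins, not part of the package:

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Stand-in for the wrapped PyTorch model, accepting one optional extra keyword argument."""

    def forward(self, im, augment=False, visualize=False, embed=None, txt_feats=None):
        # depend on the extra kwarg so the pass-through is observable
        return im.mean() if txt_feats is None else im.mean() + txt_feats.mean()

def forward(model, im, augment=False, visualize=False, embed=None, **kwargs):
    # mirrors the AutoBackend change above: **kwargs flow straight through to the model
    return model(im, augment=augment, visualize=visualize, embed=embed, **kwargs)

out = forward(TinyModel(), torch.zeros(1, 3, 32, 32), txt_feats=torch.ones(2, 512))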
ultralytics/nn/modules/__init__.py

@@ -51,6 +51,7 @@ from .block import (
     HGBlock,
     HGStem,
     ImagePoolingAttn,
+    MaxSigmoidAttnBlock,
     Proto,
     RepC3,
     RepNCSPELAN4,
@@ -75,7 +76,19 @@ from .conv import (
     RepConv,
     SpatialAttention,
 )
-from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment, WorldDetect, v10Detect
+from .head import (
+    OBB,
+    Classify,
+    Detect,
+    LRPCHead,
+    Pose,
+    RTDETRDecoder,
+    Segment,
+    WorldDetect,
+    YOLOEDetect,
+    YOLOESegment,
+    v10Detect,
+)
 from .transformer import (
     AIFI,
     MLP,
@@ -143,8 +156,12 @@ __all__ = (
     "ResNetLayer",
     "OBB",
     "WorldDetect",
+    "YOLOEDetect",
+    "YOLOESegment",
     "v10Detect",
+    "LRPCHead",
     "ImagePoolingAttn",
+    "MaxSigmoidAttnBlock",
     "ContrastiveHead",
     "BNContrastiveHead",
     "RepNCSPELAN4",
ultralytics/nn/modules/block.py

@@ -771,7 +771,7 @@ class ContrastiveHead(nn.Module):
 
 class BNContrastiveHead(nn.Module):
     """
-    Batch Norm Contrastive Head for YOLO-World using batch norm instead of l2-normalization.
+    Batch Norm Contrastive Head using batch norm instead of l2-normalization.
 
     Args:
         embed_dims (int): Embed dimensions of text and image features.
@@ -791,6 +791,21 @@ class BNContrastiveHead(nn.Module):
         # use -1.0 is more stable
         self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))
 
+    def fuse(self):
+        """Fuse the batch normalization layer in the BNContrastiveHead module."""
+        del self.norm
+        del self.bias
+        del self.logit_scale
+        self.forward = self.forward_fuse
+
+    def forward_fuse(self, x, w):
+        """
+        Passes input out unchanged.
+
+        TODO: Update or remove?
+        """
+        return x
+
     def forward(self, x, w):
         """
         Forward function of contrastive learning with batch normalization.
@@ -804,6 +819,7 @@ class BNContrastiveHead(nn.Module):
         """
         x = self.norm(x)
         w = F.normalize(w, dim=-1, p=2)
+
         x = torch.einsum("bchw,bkc->bkhw", x, w)
         return x * self.logit_scale.exp() + self.bias
 
ultralytics/nn/modules/head.py

@@ -6,16 +6,18 @@ import math
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.nn.init import constant_, xavier_uniform_
 
 from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
+from ultralytics.utils.torch_utils import fuse_conv_and_bn, smart_inference_mode
 
 from .block import DFL, BNContrastiveHead, ContrastiveHead, Proto
 from .conv import Conv, DWConv
 from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
 from .utils import bias_init_with_prob, linear_init
 
-__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect"
+__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect", "YOLOEDetect", "YOLOESegment"
 
 
 class Detect(nn.Module):
@@ -78,11 +80,12 @@ class Detect(nn.Module):
         Performs forward pass of the v10Detect module.
 
         Args:
-            x (tensor): Input tensor.
+            x (List[torch.Tensor]): Input feature maps from different levels.
 
         Returns:
-            (dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and one2one detections.
-                If in training mode, returns a dictionary containing the outputs of one2many and one2one detections separately.
+            (dict | tuple): If in training mode, returns a dictionary containing the outputs of both one2many and
+                one2one detections. If not in training mode, returns processed detections or a tuple with
+                processed detections and raw outputs.
         """
         x_detach = [xi.detach() for xi in x]
         one2one = [
@@ -98,7 +101,15 @@ class Detect(nn.Module):
         return y if self.export else (y, {"one2many": x, "one2one": one2one})
 
     def _inference(self, x):
-        """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""
+        """
+        Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.
+
+        Args:
+            x (List[torch.Tensor]): List of feature maps from different detection layers.
+
+        Returns:
+            (torch.Tensor): Concatenated tensor of decoded bounding boxes and class probabilities.
+        """
         # Inference path
         shape = x[0].shape  # BCHW
         x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
@@ -320,19 +331,216 @@ class WorldDetect(Detect):
             x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), text)), 1)
         if self.training:
             return x
+        self.no = self.nc + self.reg_max * 4  # self.nc could be changed when inference with different texts
+        y = self._inference(x)
+        return y if self.export else (y, x)
 
-        # Inference path
-        shape = x[0].shape  # BCHW
-        x_cat = torch.cat([xi.view(shape[0], self.nc + self.reg_max * 4, -1) for xi in x], 2)
-        if self.dynamic or self.shape != shape:
-            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
-            self.shape = shape
+    def bias_init(self):
+        """Initialize Detect() biases, WARNING: requires stride availability."""
+        m = self  # self.model[-1]  # Detect() module
+        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
+        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
+        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
+            a[-1].bias.data[:] = 1.0  # box
+            # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
 
-        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
-            box = x_cat[:, : self.reg_max * 4]
-            cls = x_cat[:, self.reg_max * 4 :]
+
+class SAVPE(nn.Module):
+    """Spatial-Aware Visual Prompt Embedding module for feature enhancement."""
+
+    def __init__(self, ch, c3, embed):
+        """Initialize SAVPE module with channels, intermediate channels, and embedding dimension."""
+        super().__init__()
+        self.cv1 = nn.ModuleList(
+            nn.Sequential(
+                Conv(x, c3, 3), Conv(c3, c3, 3), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity()
+            )
+            for i, x in enumerate(ch)
+        )
+
+        self.cv2 = nn.ModuleList(
+            nn.Sequential(Conv(x, c3, 1), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity())
+            for i, x in enumerate(ch)
+        )
+
+        self.c = 16
+        self.cv3 = nn.Conv2d(3 * c3, embed, 1)
+        self.cv4 = nn.Conv2d(3 * c3, self.c, 3, padding=1)
+        self.cv5 = nn.Conv2d(1, self.c, 3, padding=1)
+        self.cv6 = nn.Sequential(Conv(2 * self.c, self.c, 3), nn.Conv2d(self.c, self.c, 3, padding=1))
+
+    def forward(self, x, vp):
+        """Process input features and visual prompts to generate enhanced embeddings."""
+        y = [self.cv2[i](xi) for i, xi in enumerate(x)]
+        y = self.cv4(torch.cat(y, dim=1))
+
+        x = [self.cv1[i](xi) for i, xi in enumerate(x)]
+        x = self.cv3(torch.cat(x, dim=1))
+
+        B, C, H, W = x.shape
+
+        Q = vp.shape[1]
+
+        x = x.view(B, C, -1)
+
+        y = y.reshape(B, 1, self.c, H, W).expand(-1, Q, -1, -1, -1).reshape(B * Q, self.c, H, W)
+        vp = vp.reshape(B, Q, 1, H, W).reshape(B * Q, 1, H, W)
+
+        y = self.cv6(torch.cat((y, self.cv5(vp)), dim=1))
+
+        y = y.reshape(B, Q, self.c, -1)
+        vp = vp.reshape(B, Q, 1, -1)
+
+        score = y * vp + torch.logical_not(vp) * torch.finfo(y.dtype).min
+
+        score = F.softmax(score, dim=-1, dtype=torch.float).to(score.dtype)
+
+        aggregated = score.transpose(-2, -3) @ x.reshape(B, self.c, C // self.c, -1).transpose(-1, -2)
+
+        return F.normalize(aggregated.transpose(-2, -3).reshape(B, Q, -1), dim=-1, p=2)
+
+
+class LRPCHead(nn.Module):
+    """Lightweight Region Proposal and Classification Head for efficient object detection."""
+
+    def __init__(self, vocab, pf, loc, enabled=True):
+        """Initialize LRPCHead with vocabulary, proposal filter, and localization components."""
+        super().__init__()
+        self.vocab = self.conv2linear(vocab) if enabled else vocab
+        self.pf = pf
+        self.loc = loc
+        self.enabled = enabled
+
+    def conv2linear(self, conv):
+        """Convert a 1x1 convolutional layer to a linear layer."""
+        assert isinstance(conv, nn.Conv2d) and conv.kernel_size == (1, 1)
+        linear = nn.Linear(conv.in_channels, conv.out_channels)
+        linear.weight.data = conv.weight.view(conv.out_channels, -1).data
+        linear.bias.data = conv.bias.data
+        return linear
+
+    def forward(self, cls_feat, loc_feat, conf, max_det):
+        """Process classification and localization features to generate detection proposals."""
+        if self.enabled:
+            pf_score = self.pf(cls_feat)[0, 0].flatten(0)
+            mask = pf_score.sigmoid() > conf
+
+            cls_feat = self.vocab(cls_feat.flatten(2).transpose(-1, -2)[:, mask])
+            return (self.loc(loc_feat), cls_feat.transpose(-1, -2)), mask
         else:
-            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
+            cls_feat = self.vocab(cls_feat)
+            loc_feat = self.loc(loc_feat)
+            return (loc_feat, cls_feat.flatten(2)), torch.ones(
+                cls_feat.shape[2] * cls_feat.shape[3], device=cls_feat.device, dtype=torch.bool
+            )
+
+
+class YOLOEDetect(Detect):
+    """Head for integrating YOLO detection models with semantic understanding from text embeddings."""
+
+    is_fused = False
+
+    def __init__(self, nc=80, embed=512, with_bn=False, ch=()):
+        """Initialize YOLO detection layer with nc classes and layer channels ch."""
+        super().__init__(nc, ch)
+        c3 = max(ch[0], min(self.nc, 100))
+        assert c3 <= embed
+        assert with_bn is True
+        self.cv3 = (
+            nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)
+            if self.legacy
+            else nn.ModuleList(
+                nn.Sequential(
+                    nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
+                    nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
+                    nn.Conv2d(c3, embed, 1),
+                )
+                for x in ch
+            )
+        )
+
+        self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)
+
+        self.reprta = Residual(SwiGLUFFN(embed, embed))
+        self.savpe = SAVPE(ch, c3, embed)
+        self.embed = embed
+
+    @smart_inference_mode()
+    def fuse(self, txt_feats):
+        """Fuse text features with model weights for efficient inference."""
+        if self.is_fused:
+            return
+
+        assert not self.training
+        txt_feats = txt_feats.to(torch.float32).squeeze(0)
+        for cls_head, bn_head in zip(self.cv3, self.cv4):
+            assert isinstance(cls_head, nn.Sequential)
+            assert isinstance(bn_head, BNContrastiveHead)
+            conv = cls_head[-1]
+            assert isinstance(conv, nn.Conv2d)
+            logit_scale = bn_head.logit_scale
+            bias = bn_head.bias
+            norm = bn_head.norm
+
+            t = txt_feats * logit_scale.exp()
+            conv: nn.Conv2d = fuse_conv_and_bn(conv, norm)
+
+            w = conv.weight.data.squeeze(-1).squeeze(-1)
+            b = conv.bias.data
+
+            w = t @ w
+            b1 = (t @ b.reshape(-1).unsqueeze(-1)).squeeze(-1)
+            b2 = torch.ones_like(b1) * bias
+
+            conv = (
+                nn.Conv2d(
+                    conv.in_channels,
+                    w.shape[0],
+                    kernel_size=1,
+                )
+                .requires_grad_(False)
+                .to(conv.weight.device)
+            )
+
+            conv.weight.data.copy_(w.unsqueeze(-1).unsqueeze(-1))
+            conv.bias.data.copy_(b1 + b2)
+            cls_head[-1] = conv
+
+            bn_head.fuse()
+
+        del self.reprta
+        self.reprta = nn.Identity()
+        self.is_fused = True
+
+    def get_tpe(self, tpe):
+        """Get text prompt embeddings with normalization."""
+        return None if tpe is None else F.normalize(self.reprta(tpe), dim=-1, p=2)
+
+    def get_vpe(self, x, vpe):
+        """Get visual prompt embeddings with spatial awareness."""
+        if vpe.shape[1] == 0:  # no visual prompt embeddings
+            return torch.zeros(x[0].shape[0], 0, self.embed, device=x[0].device)
+        if vpe.ndim == 4:  # (B, N, H, W)
+            vpe = self.savpe(x, vpe)
+        assert vpe.ndim == 3  # (B, N, D)
+        return vpe
+
+    def forward_lrpc(self, x, return_mask=False):
+        """Process features with fused text embeddings to generate detections for prompt-free model."""
+        masks = []
+        assert self.is_fused, "Prompt-free inference requires model to be fused!"
+        for i in range(self.nl):
+            cls_feat = self.cv3[i](x[i])
+            loc_feat = self.cv2[i](x[i])
+            assert isinstance(self.lrpc[i], LRPCHead)
+            x[i], mask = self.lrpc[i](cls_feat, loc_feat, self.conf, self.max_det)
+            masks.append(mask)
+        shape = x[0][0].shape
+        if self.dynamic or self.shape != shape:
+            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors([b[0] for b in x], self.stride, 0.5))
+            self.shape = shape
+        box = torch.cat([xi[0].view(shape[0], self.reg_max * 4, -1) for xi in x], 2)
+        cls = torch.cat([xi[1] for xi in x], 2)
 
         if self.export and self.format in {"tflite", "edgetpu"}:
             # Precompute normalization factor to increase numerical stability
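
LRPCHead.conv2linear above relies on a 1x1 convolution being a position-wise linear layer, which is what lets the vocabulary head run only on the locations kept by the proposal-filter mask. A quick sketch of the equivalence it exploits (channel counts are illustrative only):

import torch
import torch.nn as nn

conv = nn.Conv2d(512, 80, kernel_size=1)  # 1x1 "vocabulary" conv
linear = nn.Linear(conv.in_channels, conv.out_channels)
linear.weight.data = conv.weight.view(conv.out_channels, -1).data  # (80, 512)
linear.bias.data = conv.bias.data

x = torch.randn(1, 512, 20, 20)
out_conv = conv(x).flatten(2).transpose(-1, -2)      # (1, 400, 80)
out_linear = linear(x.flatten(2).transpose(-1, -2))  # (1, 400, 80)
assert torch.allclose(out_conv, out_linear, atol=1e-5)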
@@ -345,17 +553,105 @@ class WorldDetect(Detect):
         else:
             dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
 
-        y = torch.cat((dbox, cls.sigmoid()), 1)
+        mask = torch.cat(masks)
+        y = torch.cat((dbox[:, :, mask], cls.sigmoid()), 1)
+
+        if return_mask:
+            return (y, mask) if self.export else ((y, x), mask)
+        else:
+            return y if self.export else (y, x)
+
+    def forward(self, x, cls_pe, return_mask=False):
+        """Process features with class prompt embeddings to generate detections."""
+        if hasattr(self, "lrpc"):  # for prompt-free inference
+            return self.forward_lrpc(x, return_mask)
+        for i in range(self.nl):
+            x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), cls_pe)), 1)
+        if self.training:
+            return x
+        self.no = self.nc + self.reg_max * 4  # self.nc could be changed when inference with different texts
+        y = self._inference(x)
         return y if self.export else (y, x)
 
     def bias_init(self):
-        """Initialize Detect() biases, WARNING: requires stride availability."""
+        """Initialize biases for detection heads."""
         m = self  # self.model[-1]  # Detect() module
         # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
         # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
-        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
+        for a, b, c, s in zip(m.cv2, m.cv3, m.cv4, m.stride):  # from
             a[-1].bias.data[:] = 1.0  # box
             # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
+            b[-1].bias.data[:] = 0.0
+            c.bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)
+
+
+class SwiGLUFFN(nn.Module):
+    """SwiGLU Feed-Forward Network for transformer-based architectures."""
+
+    def __init__(self, gc, ec, e=4) -> None:
+        """Initialize SwiGLU FFN with input dimension, output dimension, and expansion factor."""
+        super().__init__()
+        self.w12 = nn.Linear(gc, e * ec)
+        self.w3 = nn.Linear(e * ec // 2, ec)
+
+    def forward(self, x):
+        """Apply SwiGLU transformation to input features."""
+        x12 = self.w12(x)
+        x1, x2 = x12.chunk(2, dim=-1)
+        hidden = F.silu(x1) * x2
+        return self.w3(hidden)
+
+
+class Residual(nn.Module):
+    """Residual connection wrapper for neural network modules."""
+
+    def __init__(self, m) -> None:
+        """Initialize residual module with the wrapped module."""
+        super().__init__()
+        self.m = m
+        nn.init.zeros_(self.m.w3.bias)
+        # For models with l scale, please change the initialization to
+        # nn.init.constant_(self.m.w3.weight, 1e-6)
+        nn.init.zeros_(self.m.w3.weight)
+
+    def forward(self, x):
+        """Apply residual connection to input features."""
+        return x + self.m(x)
+
+
+class YOLOESegment(YOLOEDetect):
+    """YOLO segmentation head with text embedding capabilities."""
+
+    def __init__(self, nc=80, nm=32, npr=256, embed=512, with_bn=False, ch=()):
+        """Initialize YOLOESegment with class count, mask parameters, and embedding dimensions."""
+        super().__init__(nc, embed, with_bn, ch)
+        self.nm = nm
+        self.npr = npr
+        self.proto = Proto(ch[0], self.npr, self.nm)
+
+        c5 = max(ch[0] // 4, self.nm)
+        self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch)
+
+    def forward(self, x, text):
+        """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
+        p = self.proto(x[0])  # mask protos
+        bs = p.shape[0]  # batch size
+
+        mc = torch.cat([self.cv5[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)  # mask coefficients
+        has_lrpc = hasattr(self, "lrpc")
+
+        if not has_lrpc:
+            x = YOLOEDetect.forward(self, x, text)
+        else:
+            x, mask = YOLOEDetect.forward(self, x, text, return_mask=True)
+
+        if self.training:
+            return x, mc, p
+
+        if has_lrpc:
+            mc = mc[:, :, mask]
+
+        return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
 
 
 class RTDETRDecoder(nn.Module):
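
Note that Residual zero-initializes both the weight and bias of the wrapped SwiGLUFFN's w3 projection, so the text-prompt refinement branch (reprta in YOLOEDetect) starts out as an exact identity and only drifts away from it as training updates it. A small sketch of that property, assuming ultralytics 8.3.99 is installed (dimensions arbitrary):

import torch
from ultralytics.nn.modules.head import Residual, SwiGLUFFN

embed = 512
reprta = Residual(SwiGLUFFN(embed, embed))  # as constructed in YOLOEDetect.__init__
tpe = torch.randn(1, 80, embed)             # text prompt embeddings, shape (B, N, D)
out = reprta(tpe)
assert out.shape == tpe.shape
assert torch.equal(out, tpe)  # identity at initialization: w3 weight and bias are zeroed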
@@ -449,7 +745,17 @@ class RTDETRDecoder(nn.Module):
         self._reset_parameters()
 
     def forward(self, x, batch=None):
-        """Runs the forward pass of the module, returning bounding box and classification scores for the input."""
+        """
+        Runs the forward pass of the module, returning bounding box and classification scores for the input.
+
+        Args:
+            x (List[torch.Tensor]): List of feature maps from the backbone.
+            batch (dict, optional): Batch information for training.
+
+        Returns:
+            (tuple | torch.Tensor): During training, returns a tuple of bounding boxes, scores, and other metadata.
+                During inference, returns a tensor of shape (bs, 300, 4+nc) containing bounding boxes and class scores.
+        """
         from ultralytics.models.utils.ops import get_cdn_group
 
         # Input projection and embedding
@@ -488,7 +794,19 @@ class RTDETRDecoder(nn.Module):
         return y if self.export else (y, x)
 
     def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device="cpu", eps=1e-2):
-        """Generates anchor bounding boxes for given shapes with specific grid size and validates them."""
+        """
+        Generates anchor bounding boxes for given shapes with specific grid size and validates them.
+
+        Args:
+            shapes (list): List of feature map shapes.
+            grid_size (float, optional): Base size of grid cells. Default is 0.05.
+            dtype (torch.dtype, optional): Data type for tensors. Default is torch.float32.
+            device (str, optional): Device to create tensors on. Default is "cpu".
+            eps (float, optional): Small value for numerical stability. Default is 1e-2.
+
+        Returns:
+            (tuple): Tuple containing anchors and valid mask tensors.
+        """
         anchors = []
         for i, (h, w) in enumerate(shapes):
             sy = torch.arange(end=h, dtype=dtype, device=device)
@@ -508,7 +826,15 @@ class RTDETRDecoder(nn.Module):
         return anchors, valid_mask
 
     def _get_encoder_input(self, x):
-        """Processes and returns encoder inputs by getting projection features from input and concatenating them."""
+        """
+        Processes and returns encoder inputs by getting projection features from input and concatenating them.
+
+        Args:
+            x (List[torch.Tensor]): List of feature maps from the backbone.
+
+        Returns:
+            (tuple): Tuple containing processed features and their shapes.
+        """
         # Get projection features
         x = [self.input_proj[i](feat) for i, feat in enumerate(x)]
         # Get encoder inputs
@@ -526,7 +852,18 @@ class RTDETRDecoder(nn.Module):
         return feats, shapes
 
     def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
-        """Generates and prepares the input required for the decoder from the provided features and shapes."""
+        """
+        Generates and prepares the input required for the decoder from the provided features and shapes.
+
+        Args:
+            feats (torch.Tensor): Processed features from encoder.
+            shapes (list): List of feature map shapes.
+            dn_embed (torch.Tensor, optional): Denoising embeddings. Default is None.
+            dn_bbox (torch.Tensor, optional): Denoising bounding boxes. Default is None.
+
+        Returns:
+            (tuple): Tuple containing embeddings, reference bounding boxes, encoded bounding boxes, and scores.
+        """
         bs = feats.shape[0]
         # Prepare input for decoder
         anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)