ultralytics 8.3.98__py3-none-any.whl → 8.3.100__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/test_python.py +56 -0
- ultralytics/__init__.py +3 -2
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/data/augment.py +101 -5
- ultralytics/data/dataset.py +165 -12
- ultralytics/engine/exporter.py +5 -4
- ultralytics/engine/trainer.py +16 -7
- ultralytics/models/__init__.py +2 -2
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/detect/val.py +6 -1
- ultralytics/models/yolo/model.py +183 -3
- ultralytics/models/yolo/segment/val.py +43 -16
- ultralytics/models/yolo/yoloe/__init__.py +21 -0
- ultralytics/models/yolo/yoloe/predict.py +170 -0
- ultralytics/models/yolo/yoloe/train.py +355 -0
- ultralytics/models/yolo/yoloe/train_seg.py +141 -0
- ultralytics/models/yolo/yoloe/val.py +187 -0
- ultralytics/nn/autobackend.py +17 -7
- ultralytics/nn/modules/__init__.py +18 -1
- ultralytics/nn/modules/block.py +17 -1
- ultralytics/nn/modules/head.py +359 -22
- ultralytics/nn/tasks.py +276 -10
- ultralytics/nn/text_model.py +193 -0
- ultralytics/utils/benchmarks.py +1 -0
- ultralytics/utils/callbacks/comet.py +3 -6
- ultralytics/utils/downloads.py +6 -2
- ultralytics/utils/loss.py +67 -6
- ultralytics/utils/plotting.py +1 -1
- ultralytics/utils/tal.py +1 -1
- {ultralytics-8.3.98.dist-info → ultralytics-8.3.100.dist-info}/METADATA +10 -10
- {ultralytics-8.3.98.dist-info → ultralytics-8.3.100.dist-info}/RECORD +38 -28
- {ultralytics-8.3.98.dist-info → ultralytics-8.3.100.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.98.dist-info → ultralytics-8.3.100.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.98.dist-info → ultralytics-8.3.100.dist-info}/licenses/LICENSE +0 -0
- {ultralytics-8.3.98.dist-info → ultralytics-8.3.100.dist-info}/top_level.txt +0 -0
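The new `ultralytics/models/yolo/yoloe/*` modules listed above, together with the `YOLOEDetect`/`YOLOESegment` heads added to `head.py` below, introduce open-vocabulary (promptable) detection and segmentation. A minimal usage sketch, assuming the `YOLOE` wrapper class added in this release is exported from the top-level package and that a `yoloe-11s-seg.pt` checkpoint exists — neither of which is confirmed by this diff alone:

```python
# Hedged sketch: the checkpoint name and helper methods are illustrative, not verified from this diff.
from ultralytics import YOLOE

model = YOLOE("yoloe-11s-seg.pt")  # hypothetical pretrained YOLOE segmentation checkpoint

# Text prompts: build text embeddings for the desired vocabulary and set them as the active classes.
names = ["person", "bus"]
model.set_classes(names, model.get_text_pe(names))

results = model.predict("bus.jpg")  # standard Ultralytics predict API
results[0].show()
```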
ultralytics/nn/modules/head.py
CHANGED
@@ -6,16 +6,18 @@ import math
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.nn.init import constant_, xavier_uniform_
 
 from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
+from ultralytics.utils.torch_utils import fuse_conv_and_bn, smart_inference_mode
 
 from .block import DFL, BNContrastiveHead, ContrastiveHead, Proto
 from .conv import Conv, DWConv
 from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
 from .utils import bias_init_with_prob, linear_init
 
-__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect"
+__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect", "YOLOEDetect", "YOLOESegment"
 
 
 class Detect(nn.Module):
@@ -78,11 +80,12 @@ class Detect(nn.Module):
         Performs forward pass of the v10Detect module.
 
         Args:
-            x (
+            x (List[torch.Tensor]): Input feature maps from different levels.
 
         Returns:
-            (dict
-
+            (dict | tuple): If in training mode, returns a dictionary containing the outputs of both one2many and
+                one2one detections. If not in training mode, returns processed detections or a tuple with
+                processed detections and raw outputs.
         """
         x_detach = [xi.detach() for xi in x]
         one2one = [
@@ -98,7 +101,15 @@ class Detect(nn.Module):
         return y if self.export else (y, {"one2many": x, "one2one": one2one})
 
     def _inference(self, x):
-        """
+        """
+        Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.
+
+        Args:
+            x (List[torch.Tensor]): List of feature maps from different detection layers.
+
+        Returns:
+            (torch.Tensor): Concatenated tensor of decoded bounding boxes and class probabilities.
+        """
         # Inference path
         shape = x[0].shape  # BCHW
         x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
@@ -320,19 +331,216 @@ class WorldDetect(Detect):
             x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), text)), 1)
         if self.training:
             return x
+        self.no = self.nc + self.reg_max * 4  # self.nc could be changed when inference with different texts
+        y = self._inference(x)
+        return y if self.export else (y, x)
 
-
-
-
-
-
-
+    def bias_init(self):
+        """Initialize Detect() biases, WARNING: requires stride availability."""
+        m = self  # self.model[-1]  # Detect() module
+        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
+        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
+        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
+            a[-1].bias.data[:] = 1.0  # box
+            # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
 
-
-
-
+
+class SAVPE(nn.Module):
+    """Spatial-Aware Visual Prompt Embedding module for feature enhancement."""
+
+    def __init__(self, ch, c3, embed):
+        """Initialize SAVPE module with channels, intermediate channels, and embedding dimension."""
+        super().__init__()
+        self.cv1 = nn.ModuleList(
+            nn.Sequential(
+                Conv(x, c3, 3), Conv(c3, c3, 3), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity()
+            )
+            for i, x in enumerate(ch)
+        )
+
+        self.cv2 = nn.ModuleList(
+            nn.Sequential(Conv(x, c3, 1), nn.Upsample(scale_factor=i * 2) if i in {1, 2} else nn.Identity())
+            for i, x in enumerate(ch)
+        )
+
+        self.c = 16
+        self.cv3 = nn.Conv2d(3 * c3, embed, 1)
+        self.cv4 = nn.Conv2d(3 * c3, self.c, 3, padding=1)
+        self.cv5 = nn.Conv2d(1, self.c, 3, padding=1)
+        self.cv6 = nn.Sequential(Conv(2 * self.c, self.c, 3), nn.Conv2d(self.c, self.c, 3, padding=1))
+
+    def forward(self, x, vp):
+        """Process input features and visual prompts to generate enhanced embeddings."""
+        y = [self.cv2[i](xi) for i, xi in enumerate(x)]
+        y = self.cv4(torch.cat(y, dim=1))
+
+        x = [self.cv1[i](xi) for i, xi in enumerate(x)]
+        x = self.cv3(torch.cat(x, dim=1))
+
+        B, C, H, W = x.shape
+
+        Q = vp.shape[1]
+
+        x = x.view(B, C, -1)
+
+        y = y.reshape(B, 1, self.c, H, W).expand(-1, Q, -1, -1, -1).reshape(B * Q, self.c, H, W)
+        vp = vp.reshape(B, Q, 1, H, W).reshape(B * Q, 1, H, W)
+
+        y = self.cv6(torch.cat((y, self.cv5(vp)), dim=1))
+
+        y = y.reshape(B, Q, self.c, -1)
+        vp = vp.reshape(B, Q, 1, -1)
+
+        score = y * vp + torch.logical_not(vp) * torch.finfo(y.dtype).min
+
+        score = F.softmax(score, dim=-1, dtype=torch.float).to(score.dtype)
+
+        aggregated = score.transpose(-2, -3) @ x.reshape(B, self.c, C // self.c, -1).transpose(-1, -2)
+
+        return F.normalize(aggregated.transpose(-2, -3).reshape(B, Q, -1), dim=-1, p=2)
+
+
+class LRPCHead(nn.Module):
+    """Lightweight Region Proposal and Classification Head for efficient object detection."""
+
+    def __init__(self, vocab, pf, loc, enabled=True):
+        """Initialize LRPCHead with vocabulary, proposal filter, and localization components."""
+        super().__init__()
+        self.vocab = self.conv2linear(vocab) if enabled else vocab
+        self.pf = pf
+        self.loc = loc
+        self.enabled = enabled
+
+    def conv2linear(self, conv):
+        """Convert a 1x1 convolutional layer to a linear layer."""
+        assert isinstance(conv, nn.Conv2d) and conv.kernel_size == (1, 1)
+        linear = nn.Linear(conv.in_channels, conv.out_channels)
+        linear.weight.data = conv.weight.view(conv.out_channels, -1).data
+        linear.bias.data = conv.bias.data
+        return linear
+
+    def forward(self, cls_feat, loc_feat, conf, max_det):
+        """Process classification and localization features to generate detection proposals."""
+        if self.enabled:
+            pf_score = self.pf(cls_feat)[0, 0].flatten(0)
+            mask = pf_score.sigmoid() > conf
+
+            cls_feat = self.vocab(cls_feat.flatten(2).transpose(-1, -2)[:, mask])
+            return (self.loc(loc_feat), cls_feat.transpose(-1, -2)), mask
         else:
-
+            cls_feat = self.vocab(cls_feat)
+            loc_feat = self.loc(loc_feat)
+            return (loc_feat, cls_feat.flatten(2)), torch.ones(
+                cls_feat.shape[2] * cls_feat.shape[3], device=cls_feat.device, dtype=torch.bool
+            )
+
+
+class YOLOEDetect(Detect):
+    """Head for integrating YOLO detection models with semantic understanding from text embeddings."""
+
+    is_fused = False
+
+    def __init__(self, nc=80, embed=512, with_bn=False, ch=()):
+        """Initialize YOLO detection layer with nc classes and layer channels ch."""
+        super().__init__(nc, ch)
+        c3 = max(ch[0], min(self.nc, 100))
+        assert c3 <= embed
+        assert with_bn is True
+        self.cv3 = (
+            nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)
+            if self.legacy
+            else nn.ModuleList(
+                nn.Sequential(
+                    nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
+                    nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
+                    nn.Conv2d(c3, embed, 1),
+                )
+                for x in ch
+            )
+        )
+
+        self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)
+
+        self.reprta = Residual(SwiGLUFFN(embed, embed))
+        self.savpe = SAVPE(ch, c3, embed)
+        self.embed = embed
+
+    @smart_inference_mode()
+    def fuse(self, txt_feats):
+        """Fuse text features with model weights for efficient inference."""
+        if self.is_fused:
+            return
+
+        assert not self.training
+        txt_feats = txt_feats.to(torch.float32).squeeze(0)
+        for cls_head, bn_head in zip(self.cv3, self.cv4):
+            assert isinstance(cls_head, nn.Sequential)
+            assert isinstance(bn_head, BNContrastiveHead)
+            conv = cls_head[-1]
+            assert isinstance(conv, nn.Conv2d)
+            logit_scale = bn_head.logit_scale
+            bias = bn_head.bias
+            norm = bn_head.norm
+
+            t = txt_feats * logit_scale.exp()
+            conv: nn.Conv2d = fuse_conv_and_bn(conv, norm)
+
+            w = conv.weight.data.squeeze(-1).squeeze(-1)
+            b = conv.bias.data
+
+            w = t @ w
+            b1 = (t @ b.reshape(-1).unsqueeze(-1)).squeeze(-1)
+            b2 = torch.ones_like(b1) * bias
+
+            conv = (
+                nn.Conv2d(
+                    conv.in_channels,
+                    w.shape[0],
+                    kernel_size=1,
+                )
+                .requires_grad_(False)
+                .to(conv.weight.device)
+            )
+
+            conv.weight.data.copy_(w.unsqueeze(-1).unsqueeze(-1))
+            conv.bias.data.copy_(b1 + b2)
+            cls_head[-1] = conv
+
+            bn_head.fuse()
+
+        del self.reprta
+        self.reprta = nn.Identity()
+        self.is_fused = True
+
+    def get_tpe(self, tpe):
+        """Get text prompt embeddings with normalization."""
+        return None if tpe is None else F.normalize(self.reprta(tpe), dim=-1, p=2)
+
+    def get_vpe(self, x, vpe):
+        """Get visual prompt embeddings with spatial awareness."""
+        if vpe.shape[1] == 0:  # no visual prompt embeddings
+            return torch.zeros(x[0].shape[0], 0, self.embed, device=x[0].device)
+        if vpe.ndim == 4:  # (B, N, H, W)
+            vpe = self.savpe(x, vpe)
+        assert vpe.ndim == 3  # (B, N, D)
+        return vpe
+
+    def forward_lrpc(self, x, return_mask=False):
+        """Process features with fused text embeddings to generate detections for prompt-free model."""
+        masks = []
+        assert self.is_fused, "Prompt-free inference requires model to be fused!"
+        for i in range(self.nl):
+            cls_feat = self.cv3[i](x[i])
+            loc_feat = self.cv2[i](x[i])
+            assert isinstance(self.lrpc[i], LRPCHead)
+            x[i], mask = self.lrpc[i](cls_feat, loc_feat, self.conf, self.max_det)
+            masks.append(mask)
+        shape = x[0][0].shape
+        if self.dynamic or self.shape != shape:
+            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors([b[0] for b in x], self.stride, 0.5))
+            self.shape = shape
+        box = torch.cat([xi[0].view(shape[0], self.reg_max * 4, -1) for xi in x], 2)
+        cls = torch.cat([xi[1] for xi in x], 2)
 
         if self.export and self.format in {"tflite", "edgetpu"}:
             # Precompute normalization factor to increase numerical stability
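Note on the hunk above: `LRPCHead.conv2linear` relies on a 1x1 convolution being a per-position linear map over channels, so its weights can be copied directly into an `nn.Linear` applied to flattened spatial positions. A standalone sanity check of that equivalence (plain PyTorch, not part of the package):

```python
import torch
import torch.nn as nn

def conv2linear(conv: nn.Conv2d) -> nn.Linear:
    """Copy the weights of a 1x1 Conv2d into an equivalent nn.Linear (same trick as LRPCHead)."""
    assert conv.kernel_size == (1, 1)
    linear = nn.Linear(conv.in_channels, conv.out_channels)
    linear.weight.data = conv.weight.view(conv.out_channels, -1).data
    linear.bias.data = conv.bias.data
    return linear

conv = nn.Conv2d(64, 128, kernel_size=1)
linear = conv2linear(conv)

x = torch.randn(2, 64, 8, 8)
y_conv = conv(x).flatten(2).transpose(-1, -2)      # (B, H*W, 128) via convolution
y_linear = linear(x.flatten(2).transpose(-1, -2))  # (B, H*W, 128) via linear layer
assert torch.allclose(y_conv, y_linear, atol=1e-5)
```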
@@ -345,17 +553,105 @@ class WorldDetect(Detect):
         else:
             dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
 
-
+        mask = torch.cat(masks)
+        y = torch.cat((dbox[:, :, mask], cls.sigmoid()), 1)
+
+        if return_mask:
+            return (y, mask) if self.export else ((y, x), mask)
+        else:
+            return y if self.export else (y, x)
+
+    def forward(self, x, cls_pe, return_mask=False):
+        """Process features with class prompt embeddings to generate detections."""
+        if hasattr(self, "lrpc"):  # for prompt-free inference
+            return self.forward_lrpc(x, return_mask)
+        for i in range(self.nl):
+            x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), cls_pe)), 1)
+        if self.training:
+            return x
+        self.no = self.nc + self.reg_max * 4  # self.nc could be changed when inference with different texts
+        y = self._inference(x)
         return y if self.export else (y, x)
 
     def bias_init(self):
-        """Initialize
+        """Initialize biases for detection heads."""
         m = self  # self.model[-1]  # Detect() module
         # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
         # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
-        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
+        for a, b, c, s in zip(m.cv2, m.cv3, m.cv4, m.stride):  # from
             a[-1].bias.data[:] = 1.0  # box
             # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
+            b[-1].bias.data[:] = 0.0
+            c.bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)
+
+
+class SwiGLUFFN(nn.Module):
+    """SwiGLU Feed-Forward Network for transformer-based architectures."""
+
+    def __init__(self, gc, ec, e=4) -> None:
+        """Initialize SwiGLU FFN with input dimension, output dimension, and expansion factor."""
+        super().__init__()
+        self.w12 = nn.Linear(gc, e * ec)
+        self.w3 = nn.Linear(e * ec // 2, ec)
+
+    def forward(self, x):
+        """Apply SwiGLU transformation to input features."""
+        x12 = self.w12(x)
+        x1, x2 = x12.chunk(2, dim=-1)
+        hidden = F.silu(x1) * x2
+        return self.w3(hidden)
+
+
+class Residual(nn.Module):
+    """Residual connection wrapper for neural network modules."""
+
+    def __init__(self, m) -> None:
+        """Initialize residual module with the wrapped module."""
+        super().__init__()
+        self.m = m
+        nn.init.zeros_(self.m.w3.bias)
+        # For models with l scale, please change the initialization to
+        # nn.init.constant_(self.m.w3.weight, 1e-6)
+        nn.init.zeros_(self.m.w3.weight)
+
+    def forward(self, x):
+        """Apply residual connection to input features."""
+        return x + self.m(x)
+
+
+class YOLOESegment(YOLOEDetect):
+    """YOLO segmentation head with text embedding capabilities."""
+
+    def __init__(self, nc=80, nm=32, npr=256, embed=512, with_bn=False, ch=()):
+        """Initialize YOLOESegment with class count, mask parameters, and embedding dimensions."""
+        super().__init__(nc, embed, with_bn, ch)
+        self.nm = nm
+        self.npr = npr
+        self.proto = Proto(ch[0], self.npr, self.nm)
+
+        c5 = max(ch[0] // 4, self.nm)
+        self.cv5 = nn.ModuleList(nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nm, 1)) for x in ch)
+
+    def forward(self, x, text):
+        """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
+        p = self.proto(x[0])  # mask protos
+        bs = p.shape[0]  # batch size
+
+        mc = torch.cat([self.cv5[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)  # mask coefficients
+        has_lrpc = hasattr(self, "lrpc")
+
+        if not has_lrpc:
+            x = YOLOEDetect.forward(self, x, text)
+        else:
+            x, mask = YOLOEDetect.forward(self, x, text, return_mask=True)
+
+        if self.training:
+            return x, mc, p
+
+        if has_lrpc:
+            mc = mc[:, :, mask]
+
+        return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
 
 
 class RTDETRDecoder(nn.Module):
@@ -449,7 +745,17 @@ class RTDETRDecoder(nn.Module):
         self._reset_parameters()
 
     def forward(self, x, batch=None):
-        """
+        """
+        Runs the forward pass of the module, returning bounding box and classification scores for the input.
+
+        Args:
+            x (List[torch.Tensor]): List of feature maps from the backbone.
+            batch (dict, optional): Batch information for training.
+
+        Returns:
+            (tuple | torch.Tensor): During training, returns a tuple of bounding boxes, scores, and other metadata.
+                During inference, returns a tensor of shape (bs, 300, 4+nc) containing bounding boxes and class scores.
+        """
         from ultralytics.models.utils.ops import get_cdn_group
 
         # Input projection and embedding
@@ -488,7 +794,19 @@ class RTDETRDecoder(nn.Module):
         return y if self.export else (y, x)
 
     def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device="cpu", eps=1e-2):
-        """
+        """
+        Generates anchor bounding boxes for given shapes with specific grid size and validates them.
+
+        Args:
+            shapes (list): List of feature map shapes.
+            grid_size (float, optional): Base size of grid cells. Default is 0.05.
+            dtype (torch.dtype, optional): Data type for tensors. Default is torch.float32.
+            device (str, optional): Device to create tensors on. Default is "cpu".
+            eps (float, optional): Small value for numerical stability. Default is 1e-2.
+
+        Returns:
+            (tuple): Tuple containing anchors and valid mask tensors.
+        """
         anchors = []
         for i, (h, w) in enumerate(shapes):
             sy = torch.arange(end=h, dtype=dtype, device=device)
@@ -508,7 +826,15 @@ class RTDETRDecoder(nn.Module):
         return anchors, valid_mask
 
     def _get_encoder_input(self, x):
-        """
+        """
+        Processes and returns encoder inputs by getting projection features from input and concatenating them.
+
+        Args:
+            x (List[torch.Tensor]): List of feature maps from the backbone.
+
+        Returns:
+            (tuple): Tuple containing processed features and their shapes.
+        """
         # Get projection features
         x = [self.input_proj[i](feat) for i, feat in enumerate(x)]
         # Get encoder inputs
@@ -526,7 +852,18 @@ class RTDETRDecoder(nn.Module):
         return feats, shapes
 
     def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
-        """
+        """
+        Generates and prepares the input required for the decoder from the provided features and shapes.
+
+        Args:
+            feats (torch.Tensor): Processed features from encoder.
+            shapes (list): List of feature map shapes.
+            dn_embed (torch.Tensor, optional): Denoising embeddings. Default is None.
+            dn_bbox (torch.Tensor, optional): Denoising bounding boxes. Default is None.
+
+        Returns:
+            (tuple): Tuple containing embeddings, reference bounding boxes, encoded bounding boxes, and scores.
+        """
         bs = feats.shape[0]
         # Prepare input for decoder
         anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
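Taken together, the new `SwiGLUFFN` and `Residual` modules give `YOLOEDetect` a zero-initialized refinement branch (`reprta`) for text prompt embeddings: at initialization the branch contributes nothing, so embeddings pass through unchanged, and the head learns a correction during training. A minimal standalone sketch of that behavior, reassembled from the classes in the diff above (illustrative only, not imported from the package):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: gate half of the expanded features with SiLU, then project back."""
    def __init__(self, gc, ec, e=4):
        super().__init__()
        self.w12 = nn.Linear(gc, e * ec)      # joint projection for value and gate halves
        self.w3 = nn.Linear(e * ec // 2, ec)  # projection back to the embedding dimension

    def forward(self, x):
        x1, x2 = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(x1) * x2)

class Residual(nn.Module):
    """Residual wrapper; zero-initializing w3 makes the block start out as an identity."""
    def __init__(self, m):
        super().__init__()
        self.m = m
        nn.init.zeros_(self.m.w3.weight)
        nn.init.zeros_(self.m.w3.bias)

    def forward(self, x):
        return x + self.m(x)

reprta = Residual(SwiGLUFFN(512, 512))
tpe = F.normalize(torch.randn(1, 80, 512), dim=-1)  # e.g. text embeddings for 80 class prompts
out = F.normalize(reprta(tpe), dim=-1, p=2)         # mirrors YOLOEDetect.get_tpe before training
assert torch.allclose(out, tpe, atol=1e-6)          # identity at init thanks to the zero-init
```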