dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +8 -10
- tests/test_cuda.py +9 -10
- tests/test_engine.py +29 -2
- tests/test_exports.py +69 -21
- tests/test_integrations.py +8 -11
- tests/test_python.py +109 -71
- tests/test_solutions.py +170 -159
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +57 -64
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/Objects365.yaml +19 -15
- ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +19 -21
- ultralytics/cfg/datasets/VisDrone.yaml +5 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +24 -2
- ultralytics/cfg/datasets/coco.yaml +2 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +7 -7
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +96 -94
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +286 -476
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +151 -26
- ultralytics/data/converter.py +38 -50
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +41 -45
- ultralytics/engine/exporter.py +462 -462
- ultralytics/engine/model.py +150 -191
- ultralytics/engine/predictor.py +30 -40
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +193 -120
- ultralytics/engine/tuner.py +77 -63
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +19 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +7 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +22 -40
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +206 -79
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2268 -366
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +15 -41
- ultralytics/models/yolo/classify/val.py +34 -32
- ultralytics/models/yolo/detect/predict.py +8 -11
- ultralytics/models/yolo/detect/train.py +13 -32
- ultralytics/models/yolo/detect/val.py +75 -63
- ultralytics/models/yolo/model.py +37 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +42 -39
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +10 -22
- ultralytics/models/yolo/pose/val.py +40 -59
- ultralytics/models/yolo/segment/predict.py +16 -20
- ultralytics/models/yolo/segment/train.py +3 -12
- ultralytics/models/yolo/segment/val.py +106 -56
- ultralytics/models/yolo/world/train.py +12 -16
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +31 -56
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +16 -21
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +152 -80
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +133 -217
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +64 -116
- ultralytics/nn/modules/transformer.py +79 -89
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +111 -156
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +13 -17
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +4 -7
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +70 -70
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +151 -87
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +19 -15
- ultralytics/utils/downloads.py +29 -41
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +16 -16
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +15 -24
- ultralytics/utils/metrics.py +131 -160
- ultralytics/utils/nms.py +21 -30
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +122 -119
- ultralytics/utils/tal.py +28 -44
- ultralytics/utils/torch_utils.py +70 -187
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0

ultralytics/models/sam/sam3/vitdet.py (new file, 547 added lines)

@@ -0,0 +1,547 @@

```python
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

"""
ViTDet backbone adapted from Detectron2.

This module implements Vision Transformer (ViT) backbone for object detection.

Rope embedding code adopted from:
1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
2. https://github.com/naver-ai/rope-vit
3. https://github.com/lucidrains/rotary-embedding-torch
"""

from __future__ import annotations

import math
from functools import partial
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch import Tensor

from ultralytics.models.sam.modules.blocks import PatchEmbed
from ultralytics.models.sam.modules.utils import (
    apply_rotary_enc,
    compute_axial_cis,
    concat_rel_pos,
    get_abs_pos,
    window_partition,
    window_unpartition,
)
from ultralytics.utils.checks import check_requirements

from .model_misc import LayerScale
```
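Aside: the 2D axial RoPE referenced above splits each attention head's channels between an x-axis rotation and a y-axis rotation, encoded as unit complex numbers. The snippet below is our own minimal sketch of that idea, not the package code; the in-package equivalents are `compute_axial_cis` and `apply_rotary_enc` imported above, whose exact signatures may differ.

```python
import torch


def axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0) -> torch.Tensor:
    """Unit complex rotation factors for every cell of an end_x * end_y grid (dim = per-head dim)."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: dim // 4].float() / dim))
    t = torch.arange(end_x * end_y, dtype=torch.float32)
    tx, ty = t % end_x, torch.div(t, end_x, rounding_mode="floor")  # x and y grid coordinates
    angles = torch.cat([torch.outer(tx, freqs), torch.outer(ty, freqs)], dim=-1)  # (L, dim//2)
    return torch.polar(torch.ones_like(angles), angles)  # e^{i * angle}


def rotate(q: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate query/key features: view channel pairs as complex numbers and multiply."""
    q_ = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
    return torch.view_as_real(q_ * freqs_cis).flatten(-2).type_as(q)


q = torch.randn(1, 8, 7 * 7, 64)        # (B, heads, H*W, head_dim)
q_rot = rotate(q, axial_cis(64, 7, 7))  # same shape, positions now encoded in phase
assert q_rot.shape == q.shape
```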
```python
class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings and 2d-rope."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: tuple[int, int] | None = None,
        cls_token: bool = False,
        use_rope: bool = False,
        rope_theta: float = 10000.0,
        rope_pt_size: tuple[int, int] | None = None,
        rope_interp: bool = False,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple[int, int] | None): Input resolution for calculating the relative positional
                parameter size or rope size.
            cls_token (bool): Whether a cls_token is present.
            use_rope (bool): Whether to use 2d-rope (independent of use_rel_pos; the two can be used together).
            rope_theta (float): Controls the frequencies of rope.
            rope_pt_size (tuple[int, int] | None): Size of rope in a previous stage of training, needed for
                interpolation or tiling.
            rope_interp (bool): Whether to interpolate (or extrapolate) rope to match the input size.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.cls_token = cls_token

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        # rel_pos embeddings and rope
        self.use_rel_pos = use_rel_pos
        self.input_size = input_size

        self.use_rope = use_rope
        self.rope_theta = rope_theta
        self.rope_pt_size = rope_pt_size
        self.rope_interp = rope_interp

        # init rel_pos embeddings and rope
        self._setup_rel_pos(rel_pos_zero_init, input_size)
        self._setup_rope_freqs(input_size)

    def _setup_rel_pos(self, rel_pos_zero_init: bool = True, input_size: tuple[int, int] | None = None) -> None:
        """Set up relative positional embeddings."""
        if not self.use_rel_pos:
            self.rel_pos_h = None
            self.rel_pos_w = None
            return

        assert input_size is not None
        assert self.cls_token is False, "not supported"
        # initialize relative positional embeddings
        self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim))
        self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim))

        if not rel_pos_zero_init:
            nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
            nn.init.trunc_normal_(self.rel_pos_w, std=0.02)

        # Precompute the relative coords
        H, W = input_size
        q_coords = torch.arange(H)[:, None]
        k_coords = torch.arange(W)[None, :]
        relative_coords = (q_coords - k_coords) + (H - 1)
        self.relative_coords = relative_coords.long()
```
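`relative_coords` is a static index map: entry `(i, j)` selects the embedding for offset `i - j`, shifted by `H - 1` so every index is non-negative. A small illustration (our sketch, not library code):

```python
import torch

H, head_dim = 4, 8
rel_pos_h = torch.zeros(2 * H - 1, head_dim)  # one embedding per offset in [-(H-1), H-1]
q_coords = torch.arange(H)[:, None]
k_coords = torch.arange(H)[None, :]
relative_coords = (q_coords - k_coords) + (H - 1)  # (H, H) index map
pairwise = rel_pos_h[relative_coords]  # (H, H, head_dim): one embedding per (query, key) pair
print(relative_coords)
# tensor([[3, 2, 1, 0],
#         [4, 3, 2, 1],
#         [5, 4, 3, 2],
#         [6, 5, 4, 3]])
```

The `Attention` class continues below with the rope setup and the forward pass.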
```python
    def _setup_rope_freqs(self, input_size: tuple[int, int] | None = None) -> None:
        """Set up 2d-rope frequencies."""
        if not self.use_rope:
            self.freqs_cis = None
            return

        assert input_size is not None
        # determine rope input size
        if self.rope_pt_size is None:
            self.rope_pt_size = input_size

        # initialize 2d rope freqs
        self.compute_cis = partial(
            compute_axial_cis,
            dim=self.head_dim,
            theta=self.rope_theta,
        )

        # interpolate rope
        scale_pos = 1.0
        if self.rope_interp:
            scale_pos = self.rope_pt_size[0] / input_size[0]
        # get scaled freqs_cis
        freqs_cis = self.compute_cis(
            end_x=input_size[0],
            end_y=input_size[1],
            scale_pos=scale_pos,
        )
        if self.cls_token:
            t = torch.zeros(
                self.head_dim // 2,
                dtype=torch.float32,
                device=freqs_cis.device,
            )
            cls_freqs_cis = torch.polar(torch.ones_like(t), t)[None, :]
            freqs_cis = torch.cat([cls_freqs_cis, freqs_cis], dim=0)

        self.freqs_cis = freqs_cis

    def _apply_rope(self, q, k) -> tuple[Tensor, Tensor]:
        """Apply 2d-rope to q and k."""
        if not self.use_rope:
            return q, k

        assert self.freqs_cis is not None
        return apply_rotary_enc(q, k, freqs_cis=self.freqs_cis.to(q.device))

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of attention block."""
        s = 1 if self.cls_token else 0  # used to exclude cls_token
        if x.ndim == 4:
            B, H, W, _ = x.shape
            assert s == 0  # no cls_token
            L = H * W
            ndim = 4
        else:
            assert x.ndim == 3
            B, L, _ = x.shape
            ndim = 3
            H = W = math.sqrt(L - s)

        # qkv with shape (3, B, nHead, L, C)
        qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, -1)
        # q, k, v with shape (B, nHead, L, C)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # handle rope and rel pos embeddings
        q, k = self._apply_rope(q, k)
        if self.use_rel_pos:
            q, k = concat_rel_pos(
                q.flatten(0, 1),
                k.flatten(0, 1),
                (H, W),
                x.shape[1:3],
                self.rel_pos_h,
                self.rel_pos_w,
                rescale=True,
                relative_coords=self.relative_coords,
            )

        # sdpa expects [B, nheads, H*W, C] so we transpose back
        q = q.reshape(B, self.num_heads, H * W, -1)
        k = k.reshape(B, self.num_heads, H * W, -1)

        x = F.scaled_dot_product_attention(q, k, v)

        if ndim == 4:
            x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        else:
            x = x.view(B, self.num_heads, L, -1).permute(0, 2, 1, 3).reshape(B, L, -1)

        x = self.proj(x)

        return x
```
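A hedged usage sketch of `Attention` as defined above (our example, not package code; assumes this wheel is installed so the module imports resolve):

```python
import torch
from ultralytics.models.sam.sam3.vitdet import Attention

# Global-attention configuration on a 14x14 token grid with 2d-rope enabled
attn = Attention(dim=256, num_heads=8, use_rope=True, input_size=(14, 14), rope_pt_size=(14, 14))
x = torch.randn(1, 14, 14, 256)  # (B, H, W, C) token map
y = attn(x)
print(y.shape)  # expected: torch.Size([1, 14, 14, 256])
```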
```python
class Block(nn.Module):
    """Transformer block with support for window attention."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_path: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: tuple[int, int] | None = None,
        use_rope: bool = False,
        rope_pt_size: tuple[int, int] | None = None,
        rope_interp: bool = False,
        cls_token: bool = False,
        dropout: float = 0.0,
        init_values: float | None = None,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If 0, window attention is not used.
            input_size (tuple[int, int] | None): Input resolution for calculating the relative positional
                parameter size.
            use_rope (bool): Whether to use 2d-rope (independent of use_rel_pos; the two can be used together).
            rope_pt_size (tuple[int, int] | None): Size of rope in a previous stage of training, needed for
                interpolation or tiling.
            rope_interp (bool): Whether to interpolate (or extrapolate) rope to match the target input size;
                the source size is expected in rope_pt_size.
            cls_token (bool): Whether a cls_token is present.
            dropout (float): Dropout rate.
            init_values (float | None): Layer scale init value; None for no layer scale.
        """
        super().__init__()

        check_requirements("timm")
        from timm.layers import DropPath, Mlp

        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
            use_rope=use_rope,
            rope_pt_size=rope_pt_size,
            rope_interp=rope_interp,
            cls_token=cls_token,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=(dropout, 0.0),
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.window_size = window_size

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the transformer block."""
        shortcut = x
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.ls1(self.attn(x))
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + self.dropout(self.drop_path(x))
        x = x + self.dropout(self.drop_path(self.ls2(self.mlp(self.norm2(x)))))

        return x
```
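`Block.forward` relies on the window partition round trip imported from `ultralytics/models/sam/modules/utils.py`: pad the `(B, H, W, C)` map to a multiple of the window size, split it into independent windows for attention, then undo both steps. A self-contained sketch of that pattern (our simplified version; the in-package helpers may differ in details):

```python
import torch
import torch.nn.functional as F


def partition(x: torch.Tensor, ws: int):
    """Split a (B, H, W, C) map into (B * n_windows, ws, ws, C) windows, padding as needed."""
    B, H, W, C = x.shape
    pad_h, pad_w = (-H) % ws, (-W) % ws
    x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))  # pad W, then H (last dims first)
    Hp, Wp = H + pad_h, W + pad_w
    x = x.view(B, Hp // ws, ws, Wp // ws, ws, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, C), (Hp, Wp)


def unpartition(windows: torch.Tensor, ws: int, pad_hw, hw):
    """Reassemble windows into the padded map, then crop back to the original size."""
    (Hp, Wp), (H, W) = pad_hw, hw
    B = windows.shape[0] // ((Hp // ws) * (Wp // ws))
    x = windows.view(B, Hp // ws, Wp // ws, ws, ws, -1)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)[:, :H, :W]


x = torch.randn(1, 30, 30, 8)
w, pad_hw = partition(x, 14)  # (9, 14, 14, 8): a 3x3 grid of windows after padding to 42x42
assert torch.equal(unpartition(w, 14, pad_hw, (30, 30)), x)  # exact round trip
```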
```python
class ViT(nn.Module):
    """Vision Transformer (ViT) backbone from ViTDet: "Exploring Plain Vision Transformer Backbones for Object
    Detection", https://arxiv.org/abs/2203.16527.
    """

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_path_rate: float = 0.0,
        norm_layer: Callable[..., nn.Module] | str = "LayerNorm",
        act_layer: Callable[..., nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        tile_abs_pos: bool = True,
        rel_pos_blocks: tuple[int, ...] | bool = (2, 5, 8, 11),
        rel_pos_zero_init: bool = True,
        window_size: int = 14,
        global_att_blocks: tuple[int, ...] = (2, 5, 8, 11),
        use_rope: bool = False,
        rope_pt_size: int | None = None,
        use_interp_rope: bool = False,
        pretrain_img_size: int = 224,
        pretrain_use_cls_token: bool = True,
        retain_cls_token: bool = True,
        dropout: float = 0.0,
        return_interm_layers: bool = False,
        init_values: float | None = None,  # for layerscale
        ln_pre: bool = False,
        ln_post: bool = False,
        bias_patch_embed: bool = True,
        compile_mode: str | None = None,
        use_act_checkpoint: bool = True,
    ):
        """
        Args:
            img_size (int): Input image size. Only relevant for rel pos or rope.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path_rate (float): Stochastic depth rate.
            norm_layer (nn.Module | str): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            tile_abs_pos (bool): If True, tile absolute positional embeddings instead of interpolating them.
            rel_pos_blocks (tuple[int, ...] | bool): Indexes of blocks with rel pos embeddings; True enables
                them in every block.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_att_blocks (tuple[int, ...]): Indexes of blocks using global attention (other blocks use
                window attention).
            use_rope (bool): Whether to use 2d-rope (independent of rel_pos_blocks; the two can be used together).
            rope_pt_size (int | None): Size of rope in a previous stage of training, needed for interpolation
                or tiling.
            use_interp_rope (bool): Whether to interpolate (or extrapolate) rope to match the target input size;
                the source size is expected in rope_pt_size.
            pretrain_img_size (int): Input image size of the pretraining model.
            pretrain_use_cls_token (bool): If True, the pretraining model uses a class token.
            retain_cls_token (bool): Whether the cls_token should be retained.
            dropout (float): Dropout rate. Applied in the residual blocks of attn and mlp, and inside the mlp.
            return_interm_layers (bool): Whether to return intermediate layers (all global attention blocks).
            init_values (float | None): Layer scale init value; None for no layer scale.
            ln_pre (bool): If True, apply layer norm before the transformer blocks.
            ln_post (bool): If True, apply layer norm after the transformer blocks.
            bias_patch_embed (bool): If True, use a bias in the patch embedding conv.
            compile_mode (str | None): Mode for compiling the forward with torch.compile; None disables compilation.
            use_act_checkpoint (bool): If True, use activation checkpointing.
        """
        super().__init__()
        self.pretrain_use_cls_token = pretrain_use_cls_token

        window_block_indexes = [i for i in range(depth) if i not in global_att_blocks]
        self.full_attn_ids = list(global_att_blocks)
        self.rel_pos_blocks = [False] * depth
        if isinstance(rel_pos_blocks, bool) and rel_pos_blocks:
            self.rel_pos_blocks = [True] * depth
        else:
            for i in rel_pos_blocks:
                self.rel_pos_blocks[i] = True

        self.retain_cls_token = retain_cls_token
        if self.retain_cls_token:
            assert pretrain_use_cls_token
            assert len(window_block_indexes) == 0, "windowing not supported with cls token"

            assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token"

        scale = embed_dim**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, 1, embed_dim))

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-5)

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=bias_patch_embed,
        )

        # Handle absolute positional embedding
        self.tile_abs_pos = tile_abs_pos
        self.use_abs_pos = use_abs_pos
        if self.tile_abs_pos:
            assert self.use_abs_pos

        if self.use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
            num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
            self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
        else:
            self.pos_embed = None

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.patch_size = patch_size
        self.window_size = window_size
        self.blocks = nn.ModuleList()
        cur_stage = 1
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=self.rel_pos_blocks[i],
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i in window_block_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
                use_rope=use_rope,
                rope_pt_size=((window_size, window_size) if rope_pt_size is None else (rope_pt_size, rope_pt_size)),
                rope_interp=use_interp_rope,
                cls_token=self.retain_cls_token,
                dropout=dropout,
                init_values=init_values,
            )

            if i not in window_block_indexes:
                cur_stage += 1

            self.use_act_checkpoint = use_act_checkpoint

            self.blocks.append(block)

        self.return_interm_layers = return_interm_layers
        self.channel_list = [embed_dim] * len(self.full_attn_ids) if return_interm_layers else [embed_dim]

        if self.pos_embed is not None:
            nn.init.trunc_normal_(self.pos_embed, std=0.02)

        self.ln_pre = norm_layer(embed_dim) if ln_pre else nn.Identity()
        self.ln_post = norm_layer(embed_dim) if ln_post else nn.Identity()

        self.apply(self._init_weights)

        if compile_mode is not None:
            self.forward = torch.compile(self.forward, mode=compile_mode, fullgraph=True)
            if self.use_act_checkpoint and self.training:
                torch._dynamo.config.optimize_ddp = False
```
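The constructor keeps the pretrained 14x14 positional grid (224 / 16) and leaves the geometry mismatch to `get_abs_pos` at forward time, where `tile_abs_pos` selects tiling over the usual resizing. A rough, self-contained illustration of the two options (our sketch; the package's actual logic lives in `get_abs_pos`), using the 63x63 grid a 1008-pixel input would produce:

```python
import torch
import torch.nn.functional as F

pos = torch.randn(1, 14 * 14, 768)  # pretrained positional grid, cls token already dropped
grid = pos.reshape(1, 14, 14, 768).permute(0, 3, 1, 2)  # (1, 768, 14, 14)

# Option 1: resample the grid to the new 63x63 feature resolution
interp = F.interpolate(grid, size=(63, 63), mode="bicubic", align_corners=False)

# Option 2: tile copies of the grid until it covers 63x63, then crop
reps = -(-63 // 14)  # ceil(63 / 14) = 5
tiled = grid.repeat(1, 1, reps, reps)[:, :, :63, :63]

print(interp.shape, tiled.shape)  # both torch.Size([1, 768, 63, 63])
```

The `ViT` class continues below with weight init, the forward pass, and `set_imgsz`.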
```python
    @staticmethod
    def _init_weights(m: nn.Module) -> None:
        """Initialize the weights."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """ViT forward pass; returns the feature maps."""
        x = self.patch_embed(x)
        h, w = x.shape[1], x.shape[2]

        s = 0
        if self.retain_cls_token:
            # If cls_token is retained, we don't maintain the spatial shape
            x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1)
            s = 1

        if self.pos_embed is not None:
            x = x + get_abs_pos(
                self.pos_embed,
                self.pretrain_use_cls_token,
                (h, w),
                self.retain_cls_token,
                tiling=self.tile_abs_pos,
            )

        x = self.ln_pre(x)

        outputs = []
        for i, blk in enumerate(self.blocks):
            if self.use_act_checkpoint and self.training:
                x = checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
            if (i == self.full_attn_ids[-1]) or (self.return_interm_layers and i in self.full_attn_ids):
                if i == self.full_attn_ids[-1]:
                    x = self.ln_post(x)

                feats = x[:, s:]
                if feats.ndim == 4:
                    feats = feats.permute(0, 3, 1, 2)
                else:
                    assert feats.ndim == 3
                    h = w = math.sqrt(feats.shape[1])
                    feats = feats.reshape(feats.shape[0], h, w, feats.shape[-1]).permute(0, 3, 1, 2)

                outputs.append(feats)

        return outputs

    def set_imgsz(self, imgsz: list[int] = [1008, 1008]):
        """Set up rel pos embeddings and rope freqs for a new input image size."""
        for block in self.blocks:
            if block.window_size != 0:
                continue
            block.attn._setup_rel_pos(input_size=(imgsz[0] // self.patch_size, imgsz[1] // self.patch_size))
            block.attn._setup_rope_freqs(input_size=(imgsz[0] // self.patch_size, imgsz[1] // self.patch_size))
```
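For orientation, a hedged usage sketch of the new backbone (our example, not from the package; it assumes this wheel plus timm are installed, and passes `retain_cls_token=False` because the default windowed blocks assert that no cls token is kept):

```python
import torch
from ultralytics.models.sam.sam3.vitdet import ViT

# ViT-B/16-style configuration following the defaults shown above
backbone = ViT(img_size=224, retain_cls_token=False, use_act_checkpoint=False)
backbone.eval()
with torch.inference_mode():
    feats = backbone(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])  # expected: [torch.Size([1, 768, 14, 14])]
```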