dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +138 -0
- tests/test_cuda.py +215 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +236 -0
- tests/test_integrations.py +154 -0
- tests/test_python.py +694 -0
- tests/test_solutions.py +187 -0
- ultralytics/__init__.py +30 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1023 -0
- ultralytics/cfg/datasets/Argoverse.yaml +77 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +443 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +106 -0
- ultralytics/cfg/datasets/VisDrone.yaml +77 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +42 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +127 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +22 -0
- ultralytics/cfg/trackers/bytetrack.yaml +14 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2945 -0
- ultralytics/data/base.py +438 -0
- ultralytics/data/build.py +258 -0
- ultralytics/data/converter.py +754 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +676 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +125 -0
- ultralytics/data/split_dota.py +325 -0
- ultralytics/data/utils.py +777 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1519 -0
- ultralytics/engine/model.py +1156 -0
- ultralytics/engine/predictor.py +502 -0
- ultralytics/engine/results.py +1840 -0
- ultralytics/engine/trainer.py +853 -0
- ultralytics/engine/tuner.py +243 -0
- ultralytics/engine/validator.py +377 -0
- ultralytics/hub/__init__.py +168 -0
- ultralytics/hub/auth.py +137 -0
- ultralytics/hub/google/__init__.py +176 -0
- ultralytics/hub/session.py +446 -0
- ultralytics/hub/utils.py +248 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +61 -0
- ultralytics/models/fastsam/predict.py +181 -0
- ultralytics/models/fastsam/utils.py +24 -0
- ultralytics/models/fastsam/val.py +40 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +102 -0
- ultralytics/models/nas/predict.py +58 -0
- ultralytics/models/nas/val.py +39 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +84 -0
- ultralytics/models/rtdetr/train.py +85 -0
- ultralytics/models/rtdetr/val.py +191 -0
- ultralytics/models/sam/__init__.py +6 -0
- ultralytics/models/sam/amg.py +260 -0
- ultralytics/models/sam/build.py +358 -0
- ultralytics/models/sam/model.py +170 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +515 -0
- ultralytics/models/sam/modules/encoders.py +854 -0
- ultralytics/models/sam/modules/memory_attention.py +299 -0
- ultralytics/models/sam/modules/sam.py +1006 -0
- ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
- ultralytics/models/sam/modules/transformer.py +351 -0
- ultralytics/models/sam/modules/utils.py +394 -0
- ultralytics/models/sam/predict.py +1605 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +455 -0
- ultralytics/models/utils/ops.py +268 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +88 -0
- ultralytics/models/yolo/classify/train.py +233 -0
- ultralytics/models/yolo/classify/val.py +215 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +124 -0
- ultralytics/models/yolo/detect/train.py +217 -0
- ultralytics/models/yolo/detect/val.py +451 -0
- ultralytics/models/yolo/model.py +354 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +66 -0
- ultralytics/models/yolo/obb/train.py +81 -0
- ultralytics/models/yolo/obb/val.py +283 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +79 -0
- ultralytics/models/yolo/pose/train.py +154 -0
- ultralytics/models/yolo/pose/val.py +394 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +113 -0
- ultralytics/models/yolo/segment/train.py +123 -0
- ultralytics/models/yolo/segment/val.py +428 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +119 -0
- ultralytics/models/yolo/world/train_world.py +176 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +169 -0
- ultralytics/models/yolo/yoloe/train.py +298 -0
- ultralytics/models/yolo/yoloe/train_seg.py +124 -0
- ultralytics/models/yolo/yoloe/val.py +191 -0
- ultralytics/nn/__init__.py +29 -0
- ultralytics/nn/autobackend.py +842 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +53 -0
- ultralytics/nn/modules/block.py +1966 -0
- ultralytics/nn/modules/conv.py +712 -0
- ultralytics/nn/modules/head.py +880 -0
- ultralytics/nn/modules/transformer.py +713 -0
- ultralytics/nn/modules/utils.py +164 -0
- ultralytics/nn/tasks.py +1627 -0
- ultralytics/nn/text_model.py +351 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +116 -0
- ultralytics/solutions/analytics.py +252 -0
- ultralytics/solutions/config.py +106 -0
- ultralytics/solutions/distance_calculation.py +124 -0
- ultralytics/solutions/heatmap.py +127 -0
- ultralytics/solutions/instance_segmentation.py +84 -0
- ultralytics/solutions/object_blurrer.py +90 -0
- ultralytics/solutions/object_counter.py +195 -0
- ultralytics/solutions/object_cropper.py +84 -0
- ultralytics/solutions/parking_management.py +273 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +120 -0
- ultralytics/solutions/security_alarm.py +154 -0
- ultralytics/solutions/similarity_search.py +172 -0
- ultralytics/solutions/solutions.py +724 -0
- ultralytics/solutions/speed_estimation.py +110 -0
- ultralytics/solutions/streamlit_inference.py +196 -0
- ultralytics/solutions/templates/similarity-search.html +160 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +68 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +124 -0
- ultralytics/trackers/bot_sort.py +260 -0
- ultralytics/trackers/byte_tracker.py +480 -0
- ultralytics/trackers/track.py +125 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +376 -0
- ultralytics/trackers/utils/kalman_filter.py +493 -0
- ultralytics/trackers/utils/matching.py +157 -0
- ultralytics/utils/__init__.py +1435 -0
- ultralytics/utils/autobatch.py +106 -0
- ultralytics/utils/autodevice.py +174 -0
- ultralytics/utils/benchmarks.py +695 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +234 -0
- ultralytics/utils/callbacks/clearml.py +153 -0
- ultralytics/utils/callbacks/comet.py +552 -0
- ultralytics/utils/callbacks/dvc.py +205 -0
- ultralytics/utils/callbacks/hub.py +108 -0
- ultralytics/utils/callbacks/mlflow.py +138 -0
- ultralytics/utils/callbacks/neptune.py +140 -0
- ultralytics/utils/callbacks/raytune.py +43 -0
- ultralytics/utils/callbacks/tensorboard.py +132 -0
- ultralytics/utils/callbacks/wb.py +185 -0
- ultralytics/utils/checks.py +897 -0
- ultralytics/utils/dist.py +119 -0
- ultralytics/utils/downloads.py +499 -0
- ultralytics/utils/errors.py +43 -0
- ultralytics/utils/export.py +219 -0
- ultralytics/utils/files.py +221 -0
- ultralytics/utils/instance.py +499 -0
- ultralytics/utils/loss.py +813 -0
- ultralytics/utils/metrics.py +1356 -0
- ultralytics/utils/ops.py +885 -0
- ultralytics/utils/patches.py +143 -0
- ultralytics/utils/plotting.py +1011 -0
- ultralytics/utils/tal.py +416 -0
- ultralytics/utils/torch_utils.py +990 -0
- ultralytics/utils/triton.py +116 -0
- ultralytics/utils/tuner.py +159 -0
@@ -0,0 +1,164 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
|
3
|
+
import copy
|
4
|
+
import math
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
import torch
|
8
|
+
import torch.nn as nn
|
9
|
+
import torch.nn.functional as F
|
10
|
+
from torch.nn.init import uniform_
|
11
|
+
|
12
|
+
# Names exported by a wildcard import of this module; the other helpers defined here are not listed.
__all__ = "multi_scale_deformable_attn_pytorch", "inverse_sigmoid"
|
13
|
+
|
14
|
+
|
15
|
+
def _get_clones(module, n):
    """
    Build an ``nn.ModuleList`` holding ``n`` independent deep copies of ``module``.

    Every copy is produced with ``copy.deepcopy``, so the clones do not share parameter storage with ``module`` or
    with each other.

    Args:
        module (nn.Module): Template module to replicate.
        n (int): How many copies to make.

    Returns:
        (nn.ModuleList): Container with ``n`` deep-copied modules, in creation order.

    Examples:
        >>> import torch.nn as nn
        >>> clones = _get_clones(nn.Linear(10, 10), 3)
        >>> len(clones)
        3
    """
    replicas = (copy.deepcopy(module) for _ in range(n))
    return nn.ModuleList(replicas)
|
34
|
+
|
35
|
+
|
36
|
+
def bias_init_with_prob(prior_prob=0.01):
    """
    Initialize conv/fc bias value according to a given probability value.

    The returned value is the logit (inverse sigmoid) of ``prior_prob``, i.e. ``log(p / (1 - p))``. Passing a layer
    output equal to this bias through a sigmoid therefore yields ``prior_prob``. It's commonly used in object detection
    models to initialize classification layers with a specific positive prediction probability.

    Args:
        prior_prob (float, optional): Prior probability for bias initialization. It should lie strictly between 0
            and 1; the endpoints do not produce a finite bias.

    Returns:
        (float): Bias initialization value calculated from the prior probability.

    Examples:
        >>> bias = bias_init_with_prob(0.01)
        >>> print(f"Bias initialization value: {bias:.4f}")
        Bias initialization value: -4.5951
    """
    # -log((1 - p) / p) == log(p / (1 - p)), the logit of p.
    return float(-np.log((1 - prior_prob) / prior_prob))  # return bias_init
|
56
|
+
|
57
|
+
|
58
|
+
def linear_init(module):
    """
    Initialize the weights and biases of a linear module.

    The weight and, if present, the bias are filled in place from a uniform distribution U(-b, b). Here
    b = 1 / sqrt(fan_in), and fan_in is the number of inputs feeding each output unit. ``nn.Linear`` stores its
    weight as (out_features, in_features), so for it fan_in is the size of the second dimension.

    Args:
        module (nn.Module): Linear module to initialize. It must expose a ``weight`` tensor and may expose a ``bias``
            (``None`` or missing is allowed).

    Returns:
        (nn.Module): The same module, initialized in place.

    Examples:
        >>> import torch.nn as nn
        >>> linear = nn.Linear(10, 5)
        >>> initialized_linear = linear_init(linear)
        >>> initialized_linear is linear
        True
    """
    weight = module.weight
    # Fan-in = product of every dim except the leading (output) dim, e.g. in_features for nn.Linear.
    # A 1-D weight has no separate input dim, so keep using its length as the original implementation did.
    fan_in = math.prod(weight.shape[1:]) if weight.dim() > 1 else weight.shape[0]
    bound = 1 / math.sqrt(fan_in)
    uniform_(weight, -bound, bound)
    if getattr(module, "bias", None) is not None:
        uniform_(module.bias, -bound, bound)
    return module
|
80
|
+
|
81
|
+
|
82
|
+
def inverse_sigmoid(x, eps=1e-5):
    """
    Compute the logit (inverse of the sigmoid) of a tensor element-wise.

    The input is first clipped to [0, 1]. Both the clipped value and its complement are then floored at ``eps``,
    which keeps the ratio and its logarithm finite at the ends of the interval.

    Args:
        x (torch.Tensor): Input tensor, nominally holding probabilities in [0, 1].
        eps (float, optional): Lower bound applied to ``x`` and ``1 - x`` before taking their ratio.

    Returns:
        (torch.Tensor): ``log(x / (1 - x))`` evaluated on the clipped input.

    Examples:
        >>> x = torch.tensor([0.2, 0.5, 0.8])
        >>> inverse_sigmoid(x)
        tensor([-1.3863,  0.0000,  1.3863])
    """
    probs = x.clamp(0, 1)
    numerator = probs.clamp(min=eps)
    denominator = (1 - probs).clamp(min=eps)
    return (numerator / denominator).log()
|
105
|
+
|
106
|
+
|
107
|
+
def multi_scale_deformable_attn_pytorch(
    value: torch.Tensor,
    value_spatial_shapes: torch.Tensor,
    sampling_locations: torch.Tensor,
    attention_weights: torch.Tensor,
) -> torch.Tensor:
    """
    Implement multi-scale deformable attention in PyTorch.

    This function performs deformable attention across multiple feature map scales, allowing the model to attend to
    different spatial locations with learned offsets. For every level, each head's value map is bilinearly sampled at
    the given locations. Locations outside the map contribute zeros. The sampled features are then combined with the
    attention weights, summing over all levels and points.

    Args:
        value (torch.Tensor): The value tensor with shape (bs, num_keys, num_heads, embed_dims). Here embed_dims is
            the per-head channel count, and num_keys must equal the sum of H * W over all levels.
        value_spatial_shapes (torch.Tensor): Spatial shapes (H, W) of each level, with shape (num_levels, 2). The
            levels are laid out consecutively along the num_keys axis of ``value``, in this order.
        sampling_locations (torch.Tensor): Normalized sampling locations in [0, 1] with shape
            (bs, num_queries, num_heads, num_levels, num_points, 2). The last dim follows the ``F.grid_sample``
            (x, y) convention.
        attention_weights (torch.Tensor): The attention weights with shape
            (bs, num_queries, num_heads, num_levels, num_points).

    Returns:
        (torch.Tensor): The output tensor with shape (bs, num_queries, num_heads * embed_dims).

    References:
        https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py
    """
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
    # Split the flattened keys back into one (bs, H_*W_, num_heads, embed_dims) chunk per level.
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
    # grid_sample expects coordinates in [-1, 1]; map the [0, 1] locations onto that range.
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        # bs, H_*W_, num_heads, embed_dims ->
        # bs, H_*W_, num_heads*embed_dims ->
        # bs, num_heads*embed_dims, H_*W_ ->
        # bs*num_heads, embed_dims, H_, W_
        value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
        # bs, num_queries, num_heads, num_points, 2 ->
        # bs, num_heads, num_queries, num_points, 2 ->
        # bs*num_heads, num_queries, num_points, 2
        sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # bs*num_heads, embed_dims, num_queries, num_points
        sampling_value_l_ = F.grid_sample(
            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
        )
        sampling_value_list.append(sampling_value_l_)
    # (bs, num_queries, num_heads, num_levels, num_points) ->
    # (bs, num_heads, num_queries, num_levels, num_points) ->
    # (bs*num_heads, 1, num_queries, num_levels*num_points)
    attention_weights = attention_weights.transpose(1, 2).reshape(
        bs * num_heads, 1, num_queries, num_levels * num_points
    )
    # stack -> (bs*num_heads, embed_dims, num_queries, num_levels, num_points); flatten the last two dims,
    # weight, sum over levels*points -> (bs*num_heads, embed_dims, num_queries); regroup heads into channels ->
    # (bs, num_heads*embed_dims, num_queries)
    output = (
        (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
        .sum(-1)
        .view(bs, num_heads * embed_dims, num_queries)
    )
    return output.transpose(1, 2).contiguous()
|