dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +138 -0
- tests/test_cuda.py +215 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +236 -0
- tests/test_integrations.py +154 -0
- tests/test_python.py +694 -0
- tests/test_solutions.py +187 -0
- ultralytics/__init__.py +30 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1023 -0
- ultralytics/cfg/datasets/Argoverse.yaml +77 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +443 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +106 -0
- ultralytics/cfg/datasets/VisDrone.yaml +77 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +42 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +127 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +22 -0
- ultralytics/cfg/trackers/bytetrack.yaml +14 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2945 -0
- ultralytics/data/base.py +438 -0
- ultralytics/data/build.py +258 -0
- ultralytics/data/converter.py +754 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +676 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +125 -0
- ultralytics/data/split_dota.py +325 -0
- ultralytics/data/utils.py +777 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1519 -0
- ultralytics/engine/model.py +1156 -0
- ultralytics/engine/predictor.py +502 -0
- ultralytics/engine/results.py +1840 -0
- ultralytics/engine/trainer.py +853 -0
- ultralytics/engine/tuner.py +243 -0
- ultralytics/engine/validator.py +377 -0
- ultralytics/hub/__init__.py +168 -0
- ultralytics/hub/auth.py +137 -0
- ultralytics/hub/google/__init__.py +176 -0
- ultralytics/hub/session.py +446 -0
- ultralytics/hub/utils.py +248 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +61 -0
- ultralytics/models/fastsam/predict.py +181 -0
- ultralytics/models/fastsam/utils.py +24 -0
- ultralytics/models/fastsam/val.py +40 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +102 -0
- ultralytics/models/nas/predict.py +58 -0
- ultralytics/models/nas/val.py +39 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +84 -0
- ultralytics/models/rtdetr/train.py +85 -0
- ultralytics/models/rtdetr/val.py +191 -0
- ultralytics/models/sam/__init__.py +6 -0
- ultralytics/models/sam/amg.py +260 -0
- ultralytics/models/sam/build.py +358 -0
- ultralytics/models/sam/model.py +170 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +515 -0
- ultralytics/models/sam/modules/encoders.py +854 -0
- ultralytics/models/sam/modules/memory_attention.py +299 -0
- ultralytics/models/sam/modules/sam.py +1006 -0
- ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
- ultralytics/models/sam/modules/transformer.py +351 -0
- ultralytics/models/sam/modules/utils.py +394 -0
- ultralytics/models/sam/predict.py +1605 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +455 -0
- ultralytics/models/utils/ops.py +268 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +88 -0
- ultralytics/models/yolo/classify/train.py +233 -0
- ultralytics/models/yolo/classify/val.py +215 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +124 -0
- ultralytics/models/yolo/detect/train.py +217 -0
- ultralytics/models/yolo/detect/val.py +451 -0
- ultralytics/models/yolo/model.py +354 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +66 -0
- ultralytics/models/yolo/obb/train.py +81 -0
- ultralytics/models/yolo/obb/val.py +283 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +79 -0
- ultralytics/models/yolo/pose/train.py +154 -0
- ultralytics/models/yolo/pose/val.py +394 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +113 -0
- ultralytics/models/yolo/segment/train.py +123 -0
- ultralytics/models/yolo/segment/val.py +428 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +119 -0
- ultralytics/models/yolo/world/train_world.py +176 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +169 -0
- ultralytics/models/yolo/yoloe/train.py +298 -0
- ultralytics/models/yolo/yoloe/train_seg.py +124 -0
- ultralytics/models/yolo/yoloe/val.py +191 -0
- ultralytics/nn/__init__.py +29 -0
- ultralytics/nn/autobackend.py +842 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +53 -0
- ultralytics/nn/modules/block.py +1966 -0
- ultralytics/nn/modules/conv.py +712 -0
- ultralytics/nn/modules/head.py +880 -0
- ultralytics/nn/modules/transformer.py +713 -0
- ultralytics/nn/modules/utils.py +164 -0
- ultralytics/nn/tasks.py +1627 -0
- ultralytics/nn/text_model.py +351 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +116 -0
- ultralytics/solutions/analytics.py +252 -0
- ultralytics/solutions/config.py +106 -0
- ultralytics/solutions/distance_calculation.py +124 -0
- ultralytics/solutions/heatmap.py +127 -0
- ultralytics/solutions/instance_segmentation.py +84 -0
- ultralytics/solutions/object_blurrer.py +90 -0
- ultralytics/solutions/object_counter.py +195 -0
- ultralytics/solutions/object_cropper.py +84 -0
- ultralytics/solutions/parking_management.py +273 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +120 -0
- ultralytics/solutions/security_alarm.py +154 -0
- ultralytics/solutions/similarity_search.py +172 -0
- ultralytics/solutions/solutions.py +724 -0
- ultralytics/solutions/speed_estimation.py +110 -0
- ultralytics/solutions/streamlit_inference.py +196 -0
- ultralytics/solutions/templates/similarity-search.html +160 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +68 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +124 -0
- ultralytics/trackers/bot_sort.py +260 -0
- ultralytics/trackers/byte_tracker.py +480 -0
- ultralytics/trackers/track.py +125 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +376 -0
- ultralytics/trackers/utils/kalman_filter.py +493 -0
- ultralytics/trackers/utils/matching.py +157 -0
- ultralytics/utils/__init__.py +1435 -0
- ultralytics/utils/autobatch.py +106 -0
- ultralytics/utils/autodevice.py +174 -0
- ultralytics/utils/benchmarks.py +695 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +234 -0
- ultralytics/utils/callbacks/clearml.py +153 -0
- ultralytics/utils/callbacks/comet.py +552 -0
- ultralytics/utils/callbacks/dvc.py +205 -0
- ultralytics/utils/callbacks/hub.py +108 -0
- ultralytics/utils/callbacks/mlflow.py +138 -0
- ultralytics/utils/callbacks/neptune.py +140 -0
- ultralytics/utils/callbacks/raytune.py +43 -0
- ultralytics/utils/callbacks/tensorboard.py +132 -0
- ultralytics/utils/callbacks/wb.py +185 -0
- ultralytics/utils/checks.py +897 -0
- ultralytics/utils/dist.py +119 -0
- ultralytics/utils/downloads.py +499 -0
- ultralytics/utils/errors.py +43 -0
- ultralytics/utils/export.py +219 -0
- ultralytics/utils/files.py +221 -0
- ultralytics/utils/instance.py +499 -0
- ultralytics/utils/loss.py +813 -0
- ultralytics/utils/metrics.py +1356 -0
- ultralytics/utils/ops.py +885 -0
- ultralytics/utils/patches.py +143 -0
- ultralytics/utils/plotting.py +1011 -0
- ultralytics/utils/tal.py +416 -0
- ultralytics/utils/torch_utils.py +990 -0
- ultralytics/utils/triton.py +116 -0
- ultralytics/utils/tuner.py +159 -0
@@ -0,0 +1,299 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
|
3
|
+
import copy
|
4
|
+
from typing import Optional
|
5
|
+
|
6
|
+
import torch
|
7
|
+
from torch import Tensor, nn
|
8
|
+
|
9
|
+
from .blocks import RoPEAttention
|
10
|
+
|
11
|
+
|
12
|
+
class MemoryAttentionLayer(nn.Module):
    """
    Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.

    This class combines self-attention, cross-attention, and feedforward components to process input tensors and
    generate memory-based attention outputs. Each of the three sub-blocks normalizes its input with a LayerNorm,
    applies its transform, and adds the (dropout-regularized) result back to its input as a residual.

    Attributes:
        d_model (int): Dimensionality of the model.
        dim_feedforward (int): Dimensionality of the feedforward network.
        dropout_value (float): Dropout rate for regularization.
        self_attn (RoPEAttention): Self-attention mechanism using RoPE (Rotary Position Embedding).
        cross_attn_image (RoPEAttention): Cross-attention from the target tokens (queries) to the memory tokens
            (keys/values).
        linear1 (nn.Linear): First linear layer of the feedforward network.
        dropout (nn.Dropout): Dropout applied between the two feedforward linear layers.
        linear2 (nn.Linear): Second linear layer of the feedforward network.
        norm1 (nn.LayerNorm): Layer normalization applied before self-attention.
        norm2 (nn.LayerNorm): Layer normalization applied before cross-attention.
        norm3 (nn.LayerNorm): Layer normalization applied before the feedforward network.
        dropout1 (nn.Dropout): Dropout layer after self-attention.
        dropout2 (nn.Dropout): Dropout layer after cross-attention.
        dropout3 (nn.Dropout): Dropout layer after feedforward network.
        activation (nn.ReLU): Activation function for the feedforward network.
        pos_enc_at_attn (bool): Flag to add positional encoding to self-attention queries and keys.
        pos_enc_at_cross_attn_queries (bool): Flag to add positional encoding to cross-attention queries.
        pos_enc_at_cross_attn_keys (bool): Flag to add positional encoding to cross-attention keys.

    Methods:
        forward: Performs the full memory attention operation on input tensors.
        _forward_sa: Performs self-attention on input tensor.
        _forward_ca: Performs cross-attention between target and memory tensors.

    Examples:
        >>> layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)
        >>> tgt = torch.randn(1, 100, 256)
        >>> memory = torch.randn(1, 100, 64)
        >>> pos = torch.randn(1, 100, 64)  # must match the memory feature size (it is added to memory)
        >>> query_pos = torch.randn(1, 100, 256)
        >>> output = layer(tgt, memory, pos, query_pos)
        >>> print(output.shape)
        torch.Size([1, 100, 256])
    """

    def __init__(
        self,
        d_model: int = 256,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        pos_enc_at_attn: bool = False,
        pos_enc_at_cross_attn_keys: bool = True,
        pos_enc_at_cross_attn_queries: bool = False,
        kv_in_dim: int = 64,
    ):
        """
        Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.

        Args:
            d_model (int): Dimensionality of the model. Used as the embedding dimension of both attention modules so
                their outputs can be added back to the residual stream.
            dim_feedforward (int): Dimensionality of the feedforward network.
            dropout (float): Dropout rate for regularization.
            pos_enc_at_attn (bool): Whether to add positional encoding to self-attention queries and keys.
            pos_enc_at_cross_attn_keys (bool): Whether to add positional encoding to cross-attention keys.
            pos_enc_at_cross_attn_queries (bool): Whether to add positional encoding to cross-attention queries.
            kv_in_dim (int): Input feature dimension of the cross-attention keys/values, i.e. of the memory tensor
                (and of its positional encoding when `pos_enc_at_cross_attn_keys` is set).
        """
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout_value = dropout
        # Attention outputs are summed with `tgt`, so their embedding size must equal d_model
        # (previously hard-coded to 256, which broke any other d_model).
        self.self_attn = RoPEAttention(embedding_dim=d_model, num_heads=1, downsample_rate=1)
        self.cross_attn_image = RoPEAttention(
            rope_k_repeat=True,
            embedding_dim=d_model,
            num_heads=1,
            downsample_rate=1,
            kv_in_dim=kv_in_dim,
        )

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = nn.ReLU()

        # Where to add pos enc
        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

    def _forward_sa(self, tgt: Tensor, query_pos: Optional[Tensor]) -> Tensor:
        """
        Perform pre-norm self-attention on `tgt` with a residual connection.

        `query_pos` is added to the queries and keys (never the values) only when `pos_enc_at_attn` is enabled and
        an encoding is actually provided.
        """
        tgt2 = self.norm1(tgt)
        q = k = tgt2 + query_pos if self.pos_enc_at_attn and query_pos is not None else tgt2
        tgt2 = self.self_attn(q, k, v=tgt2)
        return tgt + self.dropout1(tgt2)

    def _forward_ca(
        self,
        tgt: Tensor,
        memory: Tensor,
        query_pos: Optional[Tensor],
        pos: Optional[Tensor],
        num_k_exclude_rope: int = 0,
    ) -> Tensor:
        """
        Perform pre-norm cross-attention from `tgt` (queries) to `memory` (keys/values) with a residual connection.

        Args:
            tgt (Tensor): Target tokens used as queries.
            memory (Tensor): Memory tokens used as keys and values.
            query_pos (Optional[Tensor]): Positional encoding added to the queries when
                `pos_enc_at_cross_attn_queries` is set; skipped when None.
            pos (Optional[Tensor]): Positional encoding added to the keys when `pos_enc_at_cross_attn_keys` is set;
                skipped when None.
            num_k_exclude_rope (int): Number of key tokens forwarded to the attention as `num_k_exclude_rope`;
                only passed when positive.

        Returns:
            (Tensor): Updated target tokens.
        """
        kwds = {}
        if num_k_exclude_rope > 0:
            assert isinstance(self.cross_attn_image, RoPEAttention)
            kwds = {"num_k_exclude_rope": num_k_exclude_rope}

        # Cross-Attention
        tgt2 = self.norm2(tgt)
        q = tgt2 + query_pos if self.pos_enc_at_cross_attn_queries and query_pos is not None else tgt2
        k = memory + pos if self.pos_enc_at_cross_attn_keys and pos is not None else memory
        tgt2 = self.cross_attn_image(q=q, k=k, v=memory, **kwds)
        return tgt + self.dropout2(tgt2)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        num_k_exclude_rope: int = 0,
    ) -> torch.Tensor:
        """
        Process input tensors through self-attention, cross-attention, and feedforward network layers.

        Args:
            tgt (Tensor): Target tokens (last dimension `d_model`).
            memory (Tensor): Memory tokens attended to by the cross-attention.
            pos (Optional[Tensor]): Positional encoding for the memory keys.
            query_pos (Optional[Tensor]): Positional encoding for the target tokens.
            num_k_exclude_rope (int): Forwarded to the cross-attention when positive.

        Returns:
            (torch.Tensor): Updated target tokens with the same shape as `tgt`.
        """
        tgt = self._forward_sa(tgt, query_pos)
        tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
        # MLP
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        return tgt + self.dropout3(tgt2)
|
155
|
+
|
156
|
+
|
157
|
+
class MemoryAttention(nn.Module):
    """
    Memory attention module for processing sequential data with self and cross-attention mechanisms.

    This class stacks `num_layers` deep copies of a memory attention layer. In every layer the current tokens
    attend to themselves and to the memory tokens, and the final output is layer-normalized. Inputs and outputs are
    sequence-first, i.e. (seq_len, batch_size, channels); when `batch_first` is True they are transposed to
    batch-first internally before being passed to the layers and transposed back afterwards.

    Attributes:
        d_model (int): The dimension of the model's hidden state.
        layers (nn.ModuleList): Independent deep copies of the provided attention layer.
        num_layers (int): The number of attention layers.
        norm (nn.LayerNorm): Layer normalization applied to the output.
        pos_enc_at_input (bool): Whether to add (scaled) positional encoding to the input.
        batch_first (bool): Whether the layers expect batch-first input.

    Methods:
        forward: Processes input tensors through the attention layers.

    Examples:
        >>> d_model = 256
        >>> layer = MemoryAttentionLayer(d_model)
        >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
        >>> curr = torch.randn(16, 32, d_model)  # (seq_len, batch_size, d_model)
        >>> memory = torch.randn(32, 32, 64)  # (mem_len, batch_size, mem_dim); mem_dim = layer's kv_in_dim (64)
        >>> curr_pos = torch.randn(16, 32, d_model)
        >>> memory_pos = torch.randn(32, 32, 64)
        >>> output = attention(curr, memory, curr_pos, memory_pos)
        >>> print(output.shape)
        torch.Size([16, 32, 256])
    """

    def __init__(
        self,
        d_model: int,
        pos_enc_at_input: bool,
        layer: nn.Module,
        num_layers: int,
        batch_first: bool = True,  # Do layers expect batch first input?
    ):
        """
        Initialize MemoryAttention with specified layers and normalization for sequential data processing.

        Args:
            d_model (int): The dimension of the model's hidden state.
            pos_enc_at_input (bool): Whether to add positional encoding to the input before the first layer.
            layer (nn.Module): Template attention layer; it is deep-copied `num_layers` times so the layers do not
                share parameters. `forward` reads its `cross_attn_image` attribute.
            num_layers (int): The number of attention layers.
            batch_first (bool): Whether the layers expect batch-first input.
        """
        super().__init__()
        self.d_model = d_model
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.pos_enc_at_input = pos_enc_at_input
        self.batch_first = batch_first

    def forward(
        self,
        curr: torch.Tensor,  # self-attention inputs
        memory: torch.Tensor,  # cross-attention inputs
        curr_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
        memory_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
    ) -> torch.Tensor:
        """
        Process inputs through attention layers, applying self and cross-attention with positional encoding.

        Args:
            curr (torch.Tensor): Sequence-first self-attention input (seq_len, batch_size, d_model). A single-element
                list wrapping it is also accepted, in which case `curr_pos` must be a single-element list too.
            memory (torch.Tensor): Sequence-first cross-attention input (mem_len, batch_size, mem_dim).
            curr_pos (Optional[Tensor]): Positional encoding for self-attention inputs.
            memory_pos (Optional[Tensor]): Positional encoding for cross-attention inputs.
            num_obj_ptr_tokens (int): Number of object pointer tokens to exclude from rotary position embedding.

        Returns:
            (torch.Tensor): Normalized output in sequence-first layout (seq_len, batch_size, d_model).

        Examples:
            >>> d_model = 256
            >>> layer = MemoryAttentionLayer(d_model)
            >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
            >>> curr = torch.randn(16, 32, d_model)
            >>> memory = torch.randn(32, 32, 64)
            >>> output = attention(curr, memory, torch.randn(16, 32, d_model), torch.randn(32, 32, 64))
            >>> print(output.shape)
            torch.Size([16, 32, 256])
        """
        if isinstance(curr, list):
            assert isinstance(curr_pos, list)
            assert len(curr) == len(curr_pos) == 1
            curr, curr_pos = curr[0], curr_pos[0]

        assert curr.shape[1] == memory.shape[1], "Batch size must be the same for curr and memory"

        output = curr
        if self.pos_enc_at_input and curr_pos is not None:
            # Input positional encoding is added with a fixed 0.1 scale.
            output = output + 0.1 * curr_pos

        if self.batch_first:
            # Convert to batch first; positional encodings are optional, so only transpose the ones provided.
            output = output.transpose(0, 1)
            memory = memory.transpose(0, 1)
            if curr_pos is not None:
                curr_pos = curr_pos.transpose(0, 1)
            if memory_pos is not None:
                memory_pos = memory_pos.transpose(0, 1)

        for layer in self.layers:
            kwds = {}
            if isinstance(layer.cross_attn_image, RoPEAttention):
                kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}

            output = layer(
                tgt=output,
                memory=memory,
                pos=memory_pos,
                query_pos=curr_pos,
                **kwds,
            )
        normed_output = self.norm(output)

        if self.batch_first:
            # Convert back to seq first (only the returned tensor needs restoring).
            normed_output = normed_output.transpose(0, 1)

        return normed_output
|