rslearn 0.0.1-py3-none-any.whl → 0.0.21-py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/detr/position_encoding.py

@@ -0,0 +1,114 @@
+"""Various positional encodings for the transformer.
+
+This is copied from https://github.com/facebookresearch/detr/.
+The original code is:
+Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+
+import math
+
+import torch
+from torch import nn
+
+
+class PositionEmbeddingSine(nn.Module):
+    """Sinusoidal position embedding.
+
+    This is similar to the one used by the Attention is all you need paper, but
+    generalized to work on images.
+    """
+
+    def __init__(
+        self,
+        num_pos_feats: int = 64,
+        temperature: int = 10000,
+        normalize: bool = False,
+        scale: float | None = None,
+    ):
+        """Create a new PositionEmbeddingSine.
+
+        Args:
+            num_pos_feats: the number of features to use. Note that the output will
+                have 2x this many, one for x dimension and one for y dimension.
+            temperature: temperature parameter.
+            normalize: whether to normalize the resulting embeddings.
+            scale: how much to scale the embeddings, if normalizing. Defaults to 2*pi.
+        """
+        super().__init__()
+        self.num_pos_feats = num_pos_feats
+        self.temperature = temperature
+        self.normalize = normalize
+        if scale is not None and normalize is False:
+            raise ValueError("normalize should be True if scale is passed")
+        if scale is None:
+            scale = 2 * math.pi
+        self.scale = scale
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Compute position embeddings.
+
+        Args:
+            x: the feature map, NCHW. The embeddings will have the same height and
+                width.
+
+        Returns:
+            the position embeddings, as an NCHW tensor.
+        """
+        ones = torch.ones_like(x[:, 0, :, :])
+        y_embed = ones.cumsum(1, dtype=torch.float32)
+        x_embed = ones.cumsum(2, dtype=torch.float32)
+        if self.normalize:
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+        pos_x = x_embed[:, :, :, None] / dim_t
+        pos_y = y_embed[:, :, :, None] / dim_t
+        pos_x = torch.stack(
+            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos_y = torch.stack(
+            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        return pos
+
+
+class PositionEmbeddingLearned(nn.Module):
+    """Absolute pos embedding, learned."""
+
+    def __init__(self, num_pos_feats: int = 256):
+        """Create a new PositionEmbeddingLearned."""
+        super().__init__()
+        self.row_embed = nn.Embedding(50, num_pos_feats)
+        self.col_embed = nn.Embedding(50, num_pos_feats)
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        """Reset the parameters."""
+        nn.init.uniform_(self.row_embed.weight)
+        nn.init.uniform_(self.col_embed.weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Compute the position embedding."""
+        h, w = x.shape[-2:]
+        i = torch.arange(w, device=x.device)
+        j = torch.arange(h, device=x.device)
+        x_emb = self.col_embed(i)
+        y_emb = self.row_embed(j)
+        pos = (
+            torch.cat(
+                [
+                    x_emb.unsqueeze(0).repeat(h, 1, 1),
+                    y_emb.unsqueeze(1).repeat(1, w, 1),
+                ],
+                dim=-1,
+            )
+            .permute(2, 0, 1)
+            .unsqueeze(0)
+            .repeat(x.shape[0], 1, 1, 1)
+        )
+        return pos
rslearn/models/detr/transformer.py

@@ -0,0 +1,429 @@
+"""DETR Transformer class.
+
+This is copied from https://github.com/facebookresearch/detr/.
+The original code is:
+Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+
+import copy
+from collections.abc import Callable
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+class Transformer(nn.Module):
+    """Transformer implementation."""
+
+    def __init__(
+        self,
+        d_model: int = 512,
+        nhead: int = 8,
+        num_encoder_layers: int = 6,
+        num_decoder_layers: int = 6,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        activation: str = "relu",
+        normalize_before: bool = False,
+        return_intermediate_dec: bool = True,
+    ):
+        """Create a new Transformer."""
+        super().__init__()
+
+        encoder_layer = TransformerEncoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
+        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+        self.encoder = TransformerEncoder(
+            encoder_layer, num_encoder_layers, encoder_norm
+        )
+
+        decoder_layer = TransformerDecoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
+        decoder_norm = nn.LayerNorm(d_model)
+        self.decoder = TransformerDecoder(
+            decoder_layer,
+            num_decoder_layers,
+            decoder_norm,
+            return_intermediate=return_intermediate_dec,
+        )
+
+        self._reset_parameters()
+
+        self.d_model = d_model
+        self.nhead = nhead
+
+    def _reset_parameters(self) -> None:
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def forward(
+        self,
+        src: Tensor,
+        query_embed: Tensor,
+        mask: Tensor | None = None,
+        pos_embed: Tensor | None = None,
+    ) -> tuple[Tensor, Tensor]:
+        """Run forward pass through the transformer model.
+
+        Args:
+            src: the source features, NCHW.
+            query_embed: the query embedding to use for decoding.
+            mask: optional token mask.
+            pos_embed: NCHW positional embedding corresponding to src.
+        """
+        # flatten NxCxHxW to HWxNxC
+        bs, c, h, w = src.shape
+        src = src.flatten(2).permute(2, 0, 1)
+        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
+
+        if pos_embed is not None:
+            pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+        if mask is not None:
+            mask = mask.flatten(1)
+
+        tgt = torch.zeros_like(query_embed)
+        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+        hs = self.decoder(
+            tgt,
+            memory,
+            memory_key_padding_mask=mask,
+            pos=pos_embed,
+            query_pos=query_embed,
+        )
+        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
+
+
+class TransformerEncoder(nn.Module):
+    """Transformer encoder implementation."""
+
+    def __init__(
+        self,
+        encoder_layer: "TransformerEncoderLayer",
+        num_layers: int,
+        norm: nn.Module | None = None,
+    ):
+        """Create a new TransformerEncoder."""
+        super().__init__()
+        self.layers = _get_clones(encoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+
+    def forward(
+        self,
+        src: Tensor,
+        mask: Tensor | None = None,
+        src_key_padding_mask: Tensor | None = None,
+        pos: Tensor | None = None,
+    ) -> Tensor:
+        """Forward pass through the TransformerEncoder."""
+        output = src
+
+        for layer in self.layers:
+            output = layer(
+                output,
+                src_mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                pos=pos,
+            )
+
+        if self.norm is not None:
+            output = self.norm(output)
+
+        return output
+
+
+class TransformerDecoder(nn.Module):
+    """Transformer decoder implementation."""
+
+    def __init__(
+        self,
+        decoder_layer: "TransformerDecoderLayer",
+        num_layers: int,
+        norm: nn.Module | None = None,
+        return_intermediate: bool = False,
+    ):
+        """Create a new TransformerDecoder."""
+        super().__init__()
+        self.layers = _get_clones(decoder_layer, num_layers)
+        self.num_layers = num_layers
+        if norm is None:
+            self.norm = nn.Identity()
+        else:
+            self.norm = norm
+        self.return_intermediate = return_intermediate
+
+    def forward(
+        self,
+        tgt: Tensor,
+        memory: Tensor,
+        tgt_mask: Tensor | None = None,
+        memory_mask: Tensor | None = None,
+        tgt_key_padding_mask: Tensor | None = None,
+        memory_key_padding_mask: Tensor | None = None,
+        pos: Tensor | None = None,
+        query_pos: Tensor | None = None,
+    ) -> Tensor:
+        """Forward pass through the TransformerDecoder."""
+        output = tgt
+
+        intermediate = []
+
+        for layer in self.layers:
+            output = layer(
+                output,
+                memory,
+                tgt_mask=tgt_mask,
+                memory_mask=memory_mask,
+                tgt_key_padding_mask=tgt_key_padding_mask,
+                memory_key_padding_mask=memory_key_padding_mask,
+                pos=pos,
+                query_pos=query_pos,
+            )
+            if self.return_intermediate:
+                intermediate.append(self.norm(output))
+
+        output = self.norm(output)
+        if self.return_intermediate:
+            intermediate.pop()
+            intermediate.append(output)
+
+        if self.return_intermediate:
+            return torch.stack(intermediate)
+
+        return output.unsqueeze(0)
+
+
+class TransformerEncoderLayer(nn.Module):
+    """One layer in a TransformerEncoder."""
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        activation: str = "relu",
+        normalize_before: bool = False,
+    ):
+        """Create a new TransformerEncoderLayer."""
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+    def with_pos_embed(self, tensor: Tensor, pos: Tensor | None) -> Tensor:
+        """Add optional positional embedding to the tensor, if provided."""
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(
+        self,
+        src: Tensor,
+        src_mask: Tensor | None = None,
+        src_key_padding_mask: Tensor | None = None,
+        pos: Tensor | None = None,
+    ) -> Tensor:
+        """Forward pass with normalization after layers."""
+        q = k = self.with_pos_embed(src, pos)
+        src2 = self.self_attn(
+            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+        )[0]
+        src = src + self.dropout1(src2)
+        src = self.norm1(src)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+        src = src + self.dropout2(src2)
+        src = self.norm2(src)
+        return src
+
+    def forward_pre(
+        self,
+        src: Tensor,
+        src_mask: Tensor | None = None,
+        src_key_padding_mask: Tensor | None = None,
+        pos: Tensor | None = None,
+    ) -> Tensor:
+        """Forward pass with normalization before layers."""
+        src2 = self.norm1(src)
+        q = k = self.with_pos_embed(src2, pos)
+        src2 = self.self_attn(
+            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+        )[0]
+        src = src + self.dropout1(src2)
+        src2 = self.norm2(src)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+        src = src + self.dropout2(src2)
+        return src
+
+    def forward(
+        self,
+        src: Tensor,
+        src_mask: Tensor | None = None,
+        src_key_padding_mask: Tensor | None = None,
+        pos: Tensor | None = None,
+    ) -> Tensor:
+        """Forward pass through the TransformerEncoderLayer."""
+        if self.normalize_before:
+            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+class TransformerDecoderLayer(nn.Module):
+    """One layer in a TransformerDecoder."""
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        activation: str = "relu",
+        normalize_before: bool = False,
+    ):
+        """Create a new TransformerDecoderLayer."""
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+    def with_pos_embed(self, tensor: Tensor, pos: Tensor | None) -> Tensor:
+        """Add optional positional embedding to the tensor, if provided."""
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(
+        self,
+        tgt: Tensor,
+        memory: Tensor,
+        tgt_mask: Tensor | None = None,
+        memory_mask: Tensor | None = None,
+        tgt_key_padding_mask: Tensor | None = None,
+        memory_key_padding_mask: Tensor | None = None,
+        pos: Tensor | None = None,
+        query_pos: Tensor | None = None,
+    ) -> Tensor:
+        """Forward pass with normalization after layers."""
+        q = k = self.with_pos_embed(tgt, query_pos)
+        tgt2 = self.self_attn(
+            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+        )[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt = self.norm1(tgt)
+        tgt2 = self.multihead_attn(
+            query=self.with_pos_embed(tgt, query_pos),
+            key=self.with_pos_embed(memory, pos),
+            value=memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+        )[0]
+        tgt = tgt + self.dropout2(tgt2)
+        tgt = self.norm2(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+        tgt = tgt + self.dropout3(tgt2)
+        tgt = self.norm3(tgt)
+        return tgt
+
+    def forward_pre(
+        self,
+        tgt: Tensor,
+        memory: Tensor,
+        tgt_mask: Tensor | None = None,
+        memory_mask: Tensor | None = None,
+        tgt_key_padding_mask: Tensor | None = None,
+        memory_key_padding_mask: Tensor | None = None,
+        pos: Tensor | None = None,
+        query_pos: Tensor | None = None,
+    ) -> Tensor:
+        """Forward pass with normalization before layers."""
+        tgt2 = self.norm1(tgt)
+        q = k = self.with_pos_embed(tgt2, query_pos)
+        tgt2 = self.self_attn(
+            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+        )[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt2 = self.norm2(tgt)
+        tgt2 = self.multihead_attn(
+            query=self.with_pos_embed(tgt2, query_pos),
+            key=self.with_pos_embed(memory, pos),
+            value=memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+        )[0]
+        tgt = tgt + self.dropout2(tgt2)
+        tgt2 = self.norm3(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+        tgt = tgt + self.dropout3(tgt2)
+        return tgt
+
+    def forward(
+        self,
+        tgt: Tensor,
+        memory: Tensor,
+        tgt_mask: Tensor | None = None,
+        memory_mask: Tensor | None = None,
+        tgt_key_padding_mask: Tensor | None = None,
+        memory_key_padding_mask: Tensor | None = None,
+        pos: Tensor | None = None,
+        query_pos: Tensor | None = None,
+    ) -> Tensor:
+        """Forward pass through the TransformerDecoderLayer."""
+        if self.normalize_before:
+            return self.forward_pre(
+                tgt,
+                memory,
+                tgt_mask,
+                memory_mask,
+                tgt_key_padding_mask,
+                memory_key_padding_mask,
+                pos,
+                query_pos,
+            )
+        return self.forward_post(
+            tgt,
+            memory,
+            tgt_mask,
+            memory_mask,
+            tgt_key_padding_mask,
+            memory_key_padding_mask,
+            pos,
+            query_pos,
+        )
+
+
+def _get_clones(module: nn.Module, N: int) -> nn.ModuleList:
+    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
+    """Return an activation function given a string."""
+    if activation == "relu":
+        return F.relu
+    if activation == "gelu":
+        return F.gelu
+    if activation == "glu":
+        return F.glu
+    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
rslearn/models/detr/util.py

@@ -0,0 +1,24 @@
+"""Miscellaneous utilities for DETR."""
+
+import torch
+
+
+@torch.no_grad()
+def accuracy(
+    output: torch.Tensor, target: torch.Tensor, topk: tuple[int, ...] = (1,)
+) -> list[torch.Tensor]:
+    """Computes the precision@k for the specified values of k."""
+    if target.numel() == 0:
+        return [torch.zeros([], device=output.device)]
+    maxk = max(topk)
+    batch_size = target.size(0)
+
+    _, pred = output.topk(maxk, 1, True, True)
+    pred = pred.t()
+    correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+    res = []
+    for k in topk:
+        correct_k = correct[:k].view(-1).float().sum(0)
+        res.append(correct_k.mul_(100.0 / batch_size))
+    return res
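
accuracy returns one scalar tensor per requested k, each the percentage of samples whose true label appears among the model's top-k logits. A small illustrative example (not part of the diff; the import path is assumed from the file listing above, and the values are arbitrary):

import torch

from rslearn.models.detr.util import accuracy  # assumed import path

# Logits for 4 samples over 10 classes, with ground-truth labels.
logits = torch.randn(4, 10)
labels = torch.tensor([3, 7, 0, 2])

top1, top5 = accuracy(logits, labels, topk=(1, 5))
# Each result is a scalar percentage in [0, 100]; e.g. top1 is
# 100 * (number of samples whose argmax matches the label) / 4.
print(top1.item(), top5.item())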