rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/rslearn/models/detr/position_encoding.py
@@ -0,0 +1,114 @@
+"""Various positional encodings for the transformer.
+
+This is copied from https://github.com/facebookresearch/detr/.
+The original code is:
+Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+
+import math
+
+import torch
+from torch import nn
+
+
+class PositionEmbeddingSine(nn.Module):
+    """Sinusoidal position embedding.
+
+    This is similar to the one used by the Attention is all you need paper, but
+    generalized to work on images.
+    """
+
+    def __init__(
+        self,
+        num_pos_feats: int = 64,
+        temperature: int = 10000,
+        normalize: bool = False,
+        scale: float | None = None,
+    ):
+        """Create a new PositionEmbeddingSine.
+
+        Args:
+            num_pos_feats: the number of features to use. Note that the output will
+                have 2x this many, one for x dimension and one for y dimension.
+            temperature: temperature parameter.
+            normalize: whether to normalize the resulting embeddings.
+            scale: how much to scale the embeddings, if normalizing. Defaults to 2*pi.
+        """
+        super().__init__()
+        self.num_pos_feats = num_pos_feats
+        self.temperature = temperature
+        self.normalize = normalize
+        if scale is not None and normalize is False:
+            raise ValueError("normalize should be True if scale is passed")
+        if scale is None:
+            scale = 2 * math.pi
+        self.scale = scale
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Compute position embeddings.
+
+        Args:
+            x: the feature map, NCHW. The embeddings will have the same height and
+                width.
+
+        Returns:
+            the position embeddings, as an NCHW tensor.
+        """
+        ones = torch.ones_like(x[:, 0, :, :])
+        y_embed = ones.cumsum(1, dtype=torch.float32)
+        x_embed = ones.cumsum(2, dtype=torch.float32)
+        if self.normalize:
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+        pos_x = x_embed[:, :, :, None] / dim_t
+        pos_y = y_embed[:, :, :, None] / dim_t
+        pos_x = torch.stack(
+            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos_y = torch.stack(
+            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        return pos
+
+
+class PositionEmbeddingLearned(nn.Module):
+    """Absolute pos embedding, learned."""
+
+    def __init__(self, num_pos_feats: int = 256):
+        """Create a new PositionEmbeddingLearned."""
+        super().__init__()
+        self.row_embed = nn.Embedding(50, num_pos_feats)
+        self.col_embed = nn.Embedding(50, num_pos_feats)
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        """Reset the parameters."""
+        nn.init.uniform_(self.row_embed.weight)
+        nn.init.uniform_(self.col_embed.weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Compute the position embedding."""
+        h, w = x.shape[-2:]
+        i = torch.arange(w, device=x.device)
+        j = torch.arange(h, device=x.device)
+        x_emb = self.col_embed(i)
+        y_emb = self.row_embed(j)
+        pos = (
+            torch.cat(
+                [
+                    x_emb.unsqueeze(0).repeat(h, 1, 1),
+                    y_emb.unsqueeze(1).repeat(1, w, 1),
+                ],
+                dim=-1,
+            )
+            .permute(2, 0, 1)
+            .unsqueeze(0)
+            .repeat(x.shape[0], 1, 1, 1)
+        )
+        return pos
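
A minimal usage sketch of the sinusoidal embedding above (not part of the released package; the batch size, feature-map shape, and num_pos_feats value are illustrative assumptions):

import torch

from rslearn.models.detr.position_encoding import PositionEmbeddingSine

# A batch of 2 feature maps with 256 channels at 16x16 resolution.
features = torch.randn(2, 256, 16, 16)
# num_pos_feats=128 yields 2 * 128 = 256 output channels (an x half and a
# y half), so the embedding can be added directly to the feature map.
embedder = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
pos = embedder(features)
print(pos.shape)  # torch.Size([2, 256, 16, 16])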
--- /dev/null
+++ b/rslearn/models/detr/transformer.py
@@ -0,0 +1,429 @@
+"""DETR Transformer class.
+
+This is copied from https://github.com/facebookresearch/detr/.
+The original code is:
+Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+
+import copy
+from collections.abc import Callable
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+class Transformer(nn.Module):
+    """Transformer implementation."""
+
+    def __init__(
+        self,
+        d_model: int = 512,
+        nhead: int = 8,
+        num_encoder_layers: int = 6,
+        num_decoder_layers: int = 6,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        activation: str = "relu",
+        normalize_before: bool = False,
+        return_intermediate_dec: bool = True,
+    ):
+        """Create a new Transformer."""
+        super().__init__()
+
+        encoder_layer = TransformerEncoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
+        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+        self.encoder = TransformerEncoder(
+            encoder_layer, num_encoder_layers, encoder_norm
+        )
+
+        decoder_layer = TransformerDecoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
+        decoder_norm = nn.LayerNorm(d_model)
+        self.decoder = TransformerDecoder(
+            decoder_layer,
+            num_decoder_layers,
+            decoder_norm,
+            return_intermediate=return_intermediate_dec,
+        )
+
+        self._reset_parameters()
+
+        self.d_model = d_model
+        self.nhead = nhead
+
+    def _reset_parameters(self) -> None:
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def forward(
+        self,
+        src: Tensor,
+        query_embed: Tensor,
+        mask: Tensor | None = None,
+        pos_embed: Tensor | None = None,
+    ) -> tuple[Tensor, Tensor]:
+        """Run forward pass through the transformer model.
+
+        Args:
+            src: the source features, NCHW.
+            query_embed: the query embedding to use for decoding.
+            mask: optional token mask.
+            pos_embed: NCHW positional embedding corresponding to src.
+        """
+        # flatten NxCxHxW to HWxNxC
+        bs, c, h, w = src.shape
+        src = src.flatten(2).permute(2, 0, 1)
+        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
+
+        if pos_embed is not None:
+            pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+        if mask is not None:
+            mask = mask.flatten(1)
+
+        tgt = torch.zeros_like(query_embed)
+        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+        hs = self.decoder(
+            tgt,
+            memory,
+            memory_key_padding_mask=mask,
+            pos=pos_embed,
+            query_pos=query_embed,
+        )
+        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
+
+
+class TransformerEncoder(nn.Module):
+    """Transformer encoder implementation."""
+
+    def __init__(
+        self,
+        encoder_layer: "TransformerEncoderLayer",
+        num_layers: int,
+        norm: nn.Module | None = None,
+    ):
+        """Create a new TransformerEncoder."""
+        super().__init__()
+        self.layers = _get_clones(encoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+
+    def forward(
+        self,
+        src: Tensor,
+        mask: Tensor | None = None,
+        src_key_padding_mask: Tensor | None = None,
+        pos: Tensor | None = None,
+    ) -> Tensor:
+        """Forward pass through the TransformerEncoder."""
+        output = src
+
+        for layer in self.layers:
+            output = layer(
+                output,
+                src_mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                pos=pos,
+            )
+
+        if self.norm is not None:
+            output = self.norm(output)
+
+        return output
+
+
+class TransformerDecoder(nn.Module):
+    """Transformer decoder implementation."""
+
+    def __init__(
+        self,
+        decoder_layer: "TransformerDecoderLayer",
+        num_layers: int,
+        norm: nn.Module | None = None,
+        return_intermediate: bool = False,
+    ):
+        """Create a new TransformerDecoder."""
+        super().__init__()
+        self.layers = _get_clones(decoder_layer, num_layers)
+        self.num_layers = num_layers
+        if norm is None:
+            self.norm = nn.Identity()
+        else:
+            self.norm = norm
+        self.return_intermediate = return_intermediate
+
+    def forward(
+        self,
+        tgt: Tensor,
+        memory: Tensor,
+        tgt_mask: Tensor | None = None,
+        memory_mask: Tensor | None = None,
+        tgt_key_padding_mask: Tensor | None = None,
+        memory_key_padding_mask: Tensor | None = None,
+        pos: Tensor | None = None,
+        query_pos: Tensor | None = None,
+    ) -> Tensor:
+        """Forward pass through the TransformerDecoder."""
+        output = tgt
+
+        intermediate = []
+
+        for layer in self.layers:
+            output = layer(
+                output,
+                memory,
+                tgt_mask=tgt_mask,
+                memory_mask=memory_mask,
+                tgt_key_padding_mask=tgt_key_padding_mask,
+                memory_key_padding_mask=memory_key_padding_mask,
+                pos=pos,
+                query_pos=query_pos,
+            )
+            if self.return_intermediate:
+                intermediate.append(self.norm(output))
+
+        output = self.norm(output)
+        if self.return_intermediate:
+            intermediate.pop()
+            intermediate.append(output)
+
+        if self.return_intermediate:
+            return torch.stack(intermediate)
+
+        return output.unsqueeze(0)
+
+
+class TransformerEncoderLayer(nn.Module):
+    """One layer in a TransformerEncoder."""
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        activation: str = "relu",
+        normalize_before: bool = False,
+    ):
+        """Create a new TransformerEncoderLayer."""
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+    def with_pos_embed(self, tensor: Tensor, pos: Tensor | None) -> Tensor:
+        """Add optional positional embedding to the tensor, if provided."""
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(
+        self,
+        src: Tensor,
+        src_mask: Tensor | None = None,
+        src_key_padding_mask: Tensor | None = None,
+        pos: Tensor | None = None,
+    ) -> Tensor:
+        """Forward pass with normalization after layers."""
+        q = k = self.with_pos_embed(src, pos)
+        src2 = self.self_attn(
+            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+        )[0]
+        src = src + self.dropout1(src2)
+        src = self.norm1(src)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+        src = src + self.dropout2(src2)
+        src = self.norm2(src)
+        return src
+
+    def forward_pre(
+        self,
+        src: Tensor,
+        src_mask: Tensor | None = None,
+        src_key_padding_mask: Tensor | None = None,
+        pos: Tensor | None = None,
+    ) -> Tensor:
+        """Forward pass with normalization before layers."""
+        src2 = self.norm1(src)
+        q = k = self.with_pos_embed(src2, pos)
+        src2 = self.self_attn(
+            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+        )[0]
+        src = src + self.dropout1(src2)
+        src2 = self.norm2(src)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+        src = src + self.dropout2(src2)
+        return src
+
+    def forward(
+        self,
+        src: Tensor,
+        src_mask: Tensor | None = None,
+        src_key_padding_mask: Tensor | None = None,
+        pos: Tensor | None = None,
+    ) -> Tensor:
+        """Forward pass through the TransformerEncoderLayer."""
+        if self.normalize_before:
+            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+class TransformerDecoderLayer(nn.Module):
+    """One layer in a TransformerDecoder."""
+
+    def __init__(
+        self,
+        d_model: int,
+        nhead: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        activation: str = "relu",
+        normalize_before: bool = False,
+    ):
+        """Create a new TransformerDecoderLayer."""
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+    def with_pos_embed(self, tensor: Tensor, pos: Tensor | None) -> Tensor:
+        """Add optional positional embedding to the tensor, if provided."""
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(
+        self,
+        tgt: Tensor,
+        memory: Tensor,
+        tgt_mask: Tensor | None = None,
+        memory_mask: Tensor | None = None,
+        tgt_key_padding_mask: Tensor | None = None,
+        memory_key_padding_mask: Tensor | None = None,
+        pos: Tensor | None = None,
+        query_pos: Tensor | None = None,
+    ) -> Tensor:
+        """Forward pass with normalization after layers."""
+        q = k = self.with_pos_embed(tgt, query_pos)
+        tgt2 = self.self_attn(
+            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+        )[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt = self.norm1(tgt)
+        tgt2 = self.multihead_attn(
+            query=self.with_pos_embed(tgt, query_pos),
+            key=self.with_pos_embed(memory, pos),
+            value=memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+        )[0]
+        tgt = tgt + self.dropout2(tgt2)
+        tgt = self.norm2(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+        tgt = tgt + self.dropout3(tgt2)
+        tgt = self.norm3(tgt)
+        return tgt
+
+    def forward_pre(
+        self,
+        tgt: Tensor,
+        memory: Tensor,
+        tgt_mask: Tensor | None = None,
+        memory_mask: Tensor | None = None,
+        tgt_key_padding_mask: Tensor | None = None,
+        memory_key_padding_mask: Tensor | None = None,
+        pos: Tensor | None = None,
+        query_pos: Tensor | None = None,
+    ) -> Tensor:
+        """Forward pass with normalization before layers."""
+        tgt2 = self.norm1(tgt)
+        q = k = self.with_pos_embed(tgt2, query_pos)
+        tgt2 = self.self_attn(
+            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+        )[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt2 = self.norm2(tgt)
+        tgt2 = self.multihead_attn(
+            query=self.with_pos_embed(tgt2, query_pos),
+            key=self.with_pos_embed(memory, pos),
+            value=memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+        )[0]
+        tgt = tgt + self.dropout2(tgt2)
+        tgt2 = self.norm3(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+        tgt = tgt + self.dropout3(tgt2)
+        return tgt
+
+    def forward(
+        self,
+        tgt: Tensor,
+        memory: Tensor,
+        tgt_mask: Tensor | None = None,
+        memory_mask: Tensor | None = None,
+        tgt_key_padding_mask: Tensor | None = None,
+        memory_key_padding_mask: Tensor | None = None,
+        pos: Tensor | None = None,
+        query_pos: Tensor | None = None,
+    ) -> Tensor:
+        """Forward pass through the TransformerDecoderLayer."""
+        if self.normalize_before:
+            return self.forward_pre(
+                tgt,
+                memory,
+                tgt_mask,
+                memory_mask,
+                tgt_key_padding_mask,
+                memory_key_padding_mask,
+                pos,
+                query_pos,
+            )
+        return self.forward_post(
+            tgt,
+            memory,
+            tgt_mask,
+            memory_mask,
+            tgt_key_padding_mask,
+            memory_key_padding_mask,
+            pos,
+            query_pos,
+        )
+
+
+def _get_clones(module: nn.Module, N: int) -> nn.ModuleList:
+    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
+    """Return an activation function given a string."""
+    if activation == "relu":
+        return F.relu
+    if activation == "gelu":
+        return F.gelu
+    if activation == "glu":
+        return F.glu
+    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
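
A minimal sketch of how the Transformer above could be driven end to end, assuming a 256-dimensional backbone feature map and 100 learned object queries (these numbers are illustrative, not taken from the diff):

import torch
from torch import nn

from rslearn.models.detr.position_encoding import PositionEmbeddingSine
from rslearn.models.detr.transformer import Transformer

d_model, num_queries = 256, 100
transformer = Transformer(d_model=d_model, nhead=8)
query_embed = nn.Embedding(num_queries, d_model)

# 2 images, d_model channels, 16x16 feature map from a backbone.
src = torch.randn(2, d_model, 16, 16)
pos = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True)(src)

# hs stacks the output of every decoder layer (return_intermediate_dec=True).
hs, memory = transformer(src, query_embed.weight, pos_embed=pos)
print(hs.shape)      # torch.Size([6, 2, 100, 256])
print(memory.shape)  # torch.Size([2, 256, 16, 16])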
--- /dev/null
+++ b/rslearn/models/detr/util.py
@@ -0,0 +1,24 @@
+"""Miscellaneous utilities for DETR."""
+
+import torch
+
+
+@torch.no_grad()
+def accuracy(
+    output: torch.Tensor, target: torch.Tensor, topk: tuple[int, ...] = (1,)
+) -> list[torch.Tensor]:
+    """Computes the precision@k for the specified values of k."""
+    if target.numel() == 0:
+        return [torch.zeros([], device=output.device)]
+    maxk = max(topk)
+    batch_size = target.size(0)
+
+    _, pred = output.topk(maxk, 1, True, True)
+    pred = pred.t()
+    correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+    res = []
+    for k in topk:
+        correct_k = correct[:k].view(-1).float().sum(0)
+        res.append(correct_k.mul_(100.0 / batch_size))
+    return res
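
A small sketch of how this helper might be called (the logits, labels, and topk values are made up for illustration):

import torch

from rslearn.models.detr.util import accuracy

# 4 samples, 10 classes; random logits and ground-truth labels.
logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
top1, top5 = accuracy(logits, labels, topk=(1, 5))
print(top1.item(), top5.item())  # each is a percentage in [0, 100]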