rslearn 0.0.1-py3-none-any.whl → 0.0.21-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/detr/detr.py
@@ -0,0 +1,504 @@
+ """DETR DEtection TRansformer decoder for object detection tasks.
+
+ Most of the modules here are adapted from here:
+ https://github.com/facebookresearch/detr/blob/29901c51d7fe8712168b8d0d64351170bc0f83e0/models/detr.py#L258
+ The original code is:
+ Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ """
+
+ from typing import Any
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ import rslearn.models.detr.box_ops as box_ops
+ from rslearn.models.component import FeatureMaps, Predictor
+ from rslearn.train.model_context import ModelContext, ModelOutput
+
+ from .matcher import HungarianMatcher
+ from .position_encoding import PositionEmbeddingSine
+ from .transformer import Transformer
+ from .util import accuracy
+
+ DEFAULT_WEIGHT_DICT: dict[str, float] = {
+     "loss_ce": 1,
+     "loss_bbox": 5,
+     "loss_giou": 2,
+ }
+
+
+ class MLP(nn.Module):
+     """Very simple multi-layer perceptron (also called FFN)."""
+
+     def __init__(
+         self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
+     ):
+         """Create a new MLP.
+
+         Args:
+             input_dim: input dimension.
+             hidden_dim: hidden dimension.
+             output_dim: output dimension.
+             num_layers: number of layers in this MLP.
+         """
+         super().__init__()
+         self.num_layers = num_layers
+         h = [hidden_dim] * (num_layers - 1)
+         self.layers = nn.ModuleList(
+             nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass through the MLP."""
+         for i, layer in enumerate(self.layers):
+             x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+         return x
+
+
+ class DetrPredictor(nn.Module):
+     """DETR prediction module.
+
+     This is DETR up to and excluding computing the loss.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         num_classes: int,
+         num_queries: int = 100,
+         transformer: Transformer = Transformer(),
+         aux_loss: bool = False,
+     ):
+         """Initializes the model.
+
+         Args:
+             in_channels: number of channels in features computed by the backbone.
+             num_classes: number of object classes
+             num_queries: number of object queries, ie detection slot. This is the maximal number of objects
+                 DETR can detect in a single image. For COCO, we recommend 100 queries.
+             transformer: torch module of the transformer architecture. See transformer.py
+             aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
+         """
+         super().__init__()
+         self.num_queries = num_queries
+         self.transformer = transformer
+         hidden_dim = transformer.d_model
+         self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
+         self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
+         self.query_embed = nn.Embedding(num_queries, hidden_dim)
+         self.input_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
+         self.aux_loss = aux_loss
+
+     def forward(
+         self, feat_map: torch.Tensor, pos_embedding: torch.Tensor
+     ) -> dict[str, torch.Tensor]:
+         """Compute the detection outputs.
+
+         Args:
+             feat_map: the input feature map.
+             pos_embedding: positional embedding.
+
+         Returns:
+             output dict containing predicted boxes, classification logits, and
+             aux_outputs (if aux_loss is enabled).
+         """
+         hs = self.transformer(
+             src=self.input_proj(feat_map),
+             query_embed=self.query_embed.weight,
+             pos_embed=pos_embedding,
+         )[0]
+
+         outputs_class = self.class_embed(hs)
+         outputs_coord = self.bbox_embed(hs).sigmoid()
+         out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
+         if self.aux_loss:
+             out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord)
+         return out
+
+     @torch.jit.unused
+     def _set_aux_loss(
+         self, outputs_class: torch.Tensor, outputs_coord: torch.Tensor
+     ) -> list[dict[str, torch.Tensor]]:
+         # this is a workaround to make torchscript happy, as torchscript
+         # doesn't support dictionary with non-homogeneous values, such
+         # as a dict having both a Tensor and a list.
+         return [
+             {"pred_logits": a, "pred_boxes": b}
+             for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
+         ]
+
+
+ class SetCriterion(nn.Module):
+     """SetCriterion computes the loss for DETR.
+
+     The process happens in two steps:
+     (1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+     (2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+     """
+
+     def __init__(
+         self,
+         num_classes: int,
+         matcher: HungarianMatcher = HungarianMatcher(),
+         weight_dict: dict[str, float] = DEFAULT_WEIGHT_DICT,
+         eos_coef: float = 0.1,
+         losses: list[str] = ["labels", "boxes", "cardinality"],
+     ):
+         """Create a SetCriterion.
+
+         Args:
+             num_classes: number of object categories, omitting the special no-object category
+             matcher: module able to compute a matching between targets and proposals
+             weight_dict: dict containing as key the names of the losses and as values their relative weight.
+             eos_coef: relative classification weight applied to the no-object category
+             losses: list of all the losses to be applied. See get_loss for list of available losses.
+         """
+         super().__init__()
+         self.num_classes = num_classes
+         self.matcher = matcher
+         self.weight_dict = weight_dict
+         self.eos_coef = eos_coef
+         self.losses = losses
+         empty_weight = torch.ones(self.num_classes + 1)
+         empty_weight[-1] = self.eos_coef
+         self.register_buffer("empty_weight", empty_weight)
+
+     def loss_labels(
+         self,
+         outputs: dict[str, torch.Tensor],
+         targets: list[dict[str, torch.Tensor]],
+         indices: list[tuple[torch.Tensor, torch.Tensor]],
+         num_boxes: int,
+         log: bool = True,
+     ) -> dict[str, torch.Tensor]:
+         """Compute classification loss (NLL).
+
+         Args:
+             outputs: the outputs from the model.
+             targets: target dicts, which must contain the key "labels" containing a tensor of dim [nb_target_boxes].
+             indices: the matching indices between outputs and targets.
+             num_boxes: number of boxes, ignored.
+             log: whether to add additional metrics to the loss dict for logging.
+
+         Returns:
+             loss dict, mapping from loss name to value. The actual loss is stored under
+             loss_ce.
+         """
+         assert "pred_logits" in outputs
+         src_logits = outputs["pred_logits"]
+
+         idx = self._get_src_permutation_idx(indices)
+         target_classes_o = torch.cat(
+             [t["labels"][J] for t, (_, J) in zip(targets, indices)]
+         )
+         target_classes = torch.full(
+             src_logits.shape[:2],
+             self.num_classes,
+             dtype=torch.int64,
+             device=src_logits.device,
+         )
+         target_classes[idx] = target_classes_o
+
+         loss_ce = F.cross_entropy(
+             src_logits.transpose(1, 2), target_classes, self.empty_weight
+         )
+         losses = {"loss_ce": loss_ce}
+
+         if log:
+             # TODO this should probably be a separate loss, not hacked in this one here
+             losses["class_error"] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
+         return losses
+
+     @torch.no_grad()
+     def loss_cardinality(
+         self,
+         outputs: dict[str, torch.Tensor],
+         targets: list[dict[str, torch.Tensor]],
+         indices: list[tuple[torch.Tensor, torch.Tensor]],
+         num_boxes: int,
+     ) -> dict[str, torch.Tensor]:
+         """Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes.
+
+         This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
+         """
+         pred_logits = outputs["pred_logits"]
+         device = pred_logits.device
+         tgt_lengths = torch.as_tensor(
+             [len(v["labels"]) for v in targets], device=device
+         )
+         # Count the number of predictions that are NOT "no-object" (which is the last class)
+         card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
+         card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
+         losses = {"cardinality_error": card_err}
+         return losses
+
+     def loss_boxes(
+         self,
+         outputs: dict[str, torch.Tensor],
+         targets: list[dict[str, torch.Tensor]],
+         indices: list[tuple[torch.Tensor, torch.Tensor]],
+         num_boxes: int,
+     ) -> dict[str, torch.Tensor]:
+         """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
+
+         targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
+         The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
+         """
+         assert "pred_boxes" in outputs
+         idx = self._get_src_permutation_idx(indices)
+         src_boxes = outputs["pred_boxes"][idx]
+         target_boxes = torch.cat(
+             [t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0
+         )
+
+         loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
+
+         losses = {}
+         losses["loss_bbox"] = loss_bbox.sum() / num_boxes
+
+         loss_giou = 1 - torch.diag(
+             box_ops.generalized_box_iou(
+                 box_ops.box_cxcywh_to_xyxy(src_boxes),
+                 box_ops.box_cxcywh_to_xyxy(target_boxes),
+             )
+         )
+         losses["loss_giou"] = loss_giou.sum() / num_boxes
+         return losses
+
+     def _get_src_permutation_idx(
+         self, indices: list[tuple[torch.Tensor, torch.Tensor]]
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         # permute predictions following indices
+         batch_idx = torch.cat(
+             [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]
+         )
+         src_idx = torch.cat([src for (src, _) in indices])
+         return batch_idx, src_idx
+
+     def _get_tgt_permutation_idx(
+         self, indices: list[tuple[torch.Tensor, torch.Tensor]]
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         # permute targets following indices
+         batch_idx = torch.cat(
+             [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]
+         )
+         tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+         return batch_idx, tgt_idx
+
+     def get_loss(
+         self,
+         loss: str,
+         outputs: dict[str, torch.Tensor],
+         targets: list[dict[str, torch.Tensor]],
+         indices: list[tuple[torch.Tensor, torch.Tensor]],
+         num_boxes: int,
+         **kwargs: Any,
+     ) -> dict[str, torch.Tensor]:
+         """Compute the specified loss.
+
+         Args:
+             loss: the name of the loss to compute.
+             outputs: the outputs from the model.
+             targets: the targets.
+             indices: the corresponding output/target indices from the matcher.
+             num_boxes: the number of target boxes.
+             kwargs: additional arguments to pass to the loss function.
+
+         Returns:
+             the loss dict.
+         """
+         loss_map = {
+             "labels": self.loss_labels,
+             "cardinality": self.loss_cardinality,
+             "boxes": self.loss_boxes,
+         }
+         assert loss in loss_map, f"do you really want to compute {loss} loss?"
+         return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
+
+     def forward(
+         self, outputs: dict[str, Any], targets: list[dict[str, torch.Tensor]]
+     ) -> dict[str, torch.Tensor]:
+         """This performs the loss computation.
+
+         Args:
+             outputs: dict of tensors, see the output specification of the model for the format
+             targets: list of dicts, such that len(targets) == batch_size.
+                 The expected keys in each dict depends on the losses applied, see each loss' doc
+         """
+         outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
+
+         # Retrieve the matching between the outputs of the last layer and the targets
+         indices = self.matcher(outputs_without_aux, targets)
+
+         num_boxes = sum(len(t["labels"]) for t in targets)
+         num_boxes = torch.as_tensor([num_boxes])
+         num_boxes = torch.clamp(num_boxes, min=1).item()
+
+         # Compute all the requested losses
+         losses = {}
+         for loss in self.losses:
+             losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
+
+         # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+         if "aux_outputs" in outputs:
+             for i, aux_outputs in enumerate(outputs["aux_outputs"]):
+                 indices = self.matcher(aux_outputs, targets)
+                 for loss in self.losses:
+                     if loss == "masks":
+                         # Intermediate masks losses are too costly to compute, we ignore them.
+                         continue
+                     kwargs = {}
+                     if loss == "labels":
+                         # Logging is enabled only for the last layer
+                         kwargs = {"log": False}
+                     l_dict = self.get_loss(
+                         loss, aux_outputs, targets, indices, num_boxes, **kwargs
+                     )
+                     l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
+                     losses.update(l_dict)
+
+         # Apply weights.
+         # We only keep the ones present in weight dict, since there may be others that
+         # are only produced for logging purposes (not that we're logging them).
+         final_losses = {
+             k: loss * self.weight_dict[k]
+             for k, loss in losses.items()
+             if k in self.weight_dict
+         }
+         return final_losses
+
+
+ class PostProcess(nn.Module):
+     """PostProcess converts the model output into the COCO format used by rslearn."""
+
+     @torch.no_grad()
+     def forward(
+         self, outputs: dict[str, torch.Tensor], target_sizes: torch.Tensor
+     ) -> list[dict[str, torch.Tensor]]:
+         """Forward pass for PostProcess to perform the output format conversion.
+
+         Args:
+             outputs: raw outputs of the model
+             target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch.
+                 For evaluation, this must be the original image size (before any data augmentation).
+                 For visualization, this should be the image size after data augment, but before padding.
+         """
+         out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]
+
+         assert len(out_logits) == len(target_sizes)
+         assert target_sizes.shape[1] == 2
+
+         prob = F.softmax(out_logits, -1)
+         scores, labels = prob[..., :-1].max(-1)
+
+         # convert to [x0, y0, x1, y1] format
+         boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
+         # and from relative [0, 1] to absolute [0, height] coordinates
+         img_h, img_w = target_sizes.unbind(1)
+         scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+         boxes = boxes * scale_fct[:, None, :]
+
+         results = [
+             {"scores": cur_scores, "labels": cur_labels, "boxes": cur_boxes}
+             for cur_scores, cur_labels, cur_boxes in zip(scores, labels, boxes)
+         ]
+
+         return results
+
+
+ class Detr(Predictor):
+     """DETR prediction module.
+
+     This combines PositionEmbeddingSine, DetrPredictor, SetCriterion, and PostProcess.
+
+     This is the module that should be used as a decoder component in rslearn.
+     """
+
+     def __init__(self, predictor: DetrPredictor, criterion: SetCriterion):
+         """Create a Detr.
+
+         Args:
+             predictor: the DetrPredictor.
+             criterion: the SetCriterion.
+         """
+         super().__init__()
+         self.predictor = predictor
+         self.criterion = criterion
+         self.pos_embedding = PositionEmbeddingSine(
+             num_pos_feats=predictor.transformer.d_model // 2, normalize=True
+         )
+         self.postprocess = PostProcess()
+
+         if predictor.aux_loss:
+             # Hack to make sure it's included in the weight dict for the criterion.
+             aux_weight_dict = {}
+             num_dec_layers = len(predictor.transformer.decoder.layers)
+             for i in range(num_dec_layers - 1):
+                 aux_weight_dict.update(
+                     {f"{k}_{i}": v for k, v in self.criterion.weight_dict.items()}
+                 )
+             self.criterion.weight_dict.update(aux_weight_dict)
+
+     def forward(
+         self,
+         intermediates: Any,
+         context: ModelContext,
+         targets: list[dict[str, Any]] | None = None,
+     ) -> ModelOutput:
+         """Compute the detection outputs and loss from features.
+
+         DETR will use only the last feature map, which should correspond to the lowest
+         resolution one.
+
+         Args:
+             intermediates: the output from the previous component. It must be a FeatureMaps.
+             context: the model context. Input dicts must contain an "image" key which we will
+                 be used to establish the original image size.
+             targets: must contain class key that stores the class label.
+
+         Returns:
+             the model output.
+         """
+         if not isinstance(intermediates, FeatureMaps):
+             raise ValueError("input to Detr must be a FeatureMaps")
+
+         # We only use the last feature map (most fine-grained).
+         features = intermediates.feature_maps[-1]
+
+         # Get image sizes.
+         image_sizes = torch.tensor(
+             [[inp["image"].shape[2], inp["image"].shape[1]] for inp in context.inputs],
+             dtype=torch.int32,
+             device=features.device,
+         )
+
+         pos_embedding = self.pos_embedding(features)
+         outputs = self.predictor(features, pos_embedding)
+
+         if targets is not None:
+             # Convert boxes from [x0, y0, x1, y1] to [cx, cy, w, h].
+             converted_targets = []
+             for target, image_size in zip(targets, image_sizes):
+                 boxes = target["boxes"]
+                 img_w, img_h = image_size
+                 scale_fct = torch.stack([img_w, img_h, img_w, img_h])
+                 boxes = boxes / scale_fct
+                 boxes = box_ops.box_xyxy_to_cxcywh(boxes)
+                 converted_targets.append(
+                     {
+                         "boxes": boxes,
+                         "labels": target["labels"],
+                     }
+                 )
+
+             losses = self.criterion(outputs, converted_targets)
+         else:
+             losses = {}
+
+         results = self.postprocess(outputs, image_sizes)
+
+         return ModelOutput(
+             outputs=results,
+             loss_dict=losses,
+         )
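Editorial sketch, not part of the package diff: the snippet below exercises DetrPredictor and SetCriterion from the new detr.py on their own, feeding a dummy backbone feature map through the predictor and turning the predictions plus normalized (cx, cy, w, h) targets into the weighted loss dict. The channel count, spatial size, and class/box counts are arbitrary assumptions, and the PositionEmbeddingSine call mirrors Detr.forward above, which passes a plain feature tensor.

import torch

from rslearn.models.detr.detr import DetrPredictor, SetCriterion
from rslearn.models.detr.position_encoding import PositionEmbeddingSine

num_classes = 3  # object categories, excluding the implicit no-object class (assumed)

predictor = DetrPredictor(in_channels=256, num_classes=num_classes)
criterion = SetCriterion(num_classes=num_classes)
pos_embedding = PositionEmbeddingSine(
    num_pos_feats=predictor.transformer.d_model // 2, normalize=True
)

# Dummy backbone output: [batch, channels, height, width] (shapes are illustrative).
feat_map = torch.randn(2, 256, 16, 16)
outputs = predictor(feat_map, pos_embedding(feat_map))
# Expected per the forward above:
#   outputs["pred_logits"]: [2, 100, num_classes + 1]
#   outputs["pred_boxes"]:  [2, 100, 4], (cx, cy, w, h) in [0, 1]

# Targets use normalized (cx, cy, w, h) boxes, as SetCriterion's loss_boxes expects.
targets = [
    {"boxes": torch.rand(4, 4), "labels": torch.randint(0, num_classes, (4,))},
    {"boxes": torch.rand(2, 4), "labels": torch.randint(0, num_classes, (2,))},
]
loss_dict = criterion(outputs, targets)
total_loss = sum(loss_dict.values())  # weighted loss_ce + loss_bbox + loss_giou

In normal rslearn use, the Detr wrapper performs these steps itself as a decoder component, including converting [x0, y0, x1, y1] targets to the normalized form before calling the criterion.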
rslearn/models/detr/matcher.py
@@ -0,0 +1,107 @@
+ """Modules to compute the matching cost and solve the corresponding LSAP.
+
+ This is copied from https://github.com/facebookresearch/detr/.
+ The original code is:
+ Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ """
+
+ import torch
+ from scipy.optimize import linear_sum_assignment
+ from torch import nn
+
+ from .box_ops import box_cxcywh_to_xyxy, generalized_box_iou
+
+
+ class HungarianMatcher(nn.Module):
+     """This class computes an assignment between the targets and the predictions of the network.
+
+     For efficiency reasons, the targets don't include the no_object. Because of this, in general,
+     there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
+     while the others are un-matched (and thus treated as non-objects).
+     """
+
+     def __init__(
+         self, cost_class: float = 1, cost_bbox: float = 5, cost_giou: float = 2
+     ):
+         """Creates the matcher.
+
+         Params:
+             cost_class: This is the relative weight of the classification error in the matching cost
+             cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
+             cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
+         """
+         super().__init__()
+         self.cost_class = cost_class
+         self.cost_bbox = cost_bbox
+         self.cost_giou = cost_giou
+         assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
+             "all costs cant be 0"
+         )
+
+     @torch.no_grad()
+     def forward(
+         self, outputs: dict[str, torch.Tensor], targets: list[dict[str, torch.Tensor]]
+     ) -> list[tuple[torch.Tensor, torch.Tensor]]:
+         """Performs the matching.
+
+         Params:
+             outputs: This is a dict that contains at least these entries:
+                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
+
+             targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
+                     objects in the target) containing the class labels
+                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
+
+         Returns:
+             A list of size batch_size, containing tuples of (index_i, index_j) where:
+                 - index_i is the indices of the selected predictions (in order)
+                 - index_j is the indices of the corresponding selected targets (in order)
+             For each batch element, it holds:
+                 len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+         """
+         bs, num_queries = outputs["pred_logits"].shape[:2]
+
+         # We flatten to compute the cost matrices in a batch
+         out_prob = (
+             outputs["pred_logits"].flatten(0, 1).softmax(-1)
+         )  # [batch_size * num_queries, num_classes]
+         out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
+
+         # Also concat the target labels and boxes
+         tgt_ids = torch.cat([v["labels"] for v in targets])
+         tgt_bbox = torch.cat([v["boxes"] for v in targets])
+
+         # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+         # but approximate it in 1 - proba[target class].
+         # The 1 is a constant that doesn't change the matching, it can be ommitted.
+         cost_class = -out_prob[:, tgt_ids]
+
+         # Compute the L1 cost between boxes
+         cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
+
+         # Compute the giou cost betwen boxes
+         cost_giou = -generalized_box_iou(
+             box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
+         )
+
+         # Final cost matrix
+         C = (
+             self.cost_bbox * cost_bbox
+             + self.cost_class * cost_class
+             + self.cost_giou * cost_giou
+         )
+         C = C.view(bs, num_queries, -1).cpu()
+
+         sizes = [len(v["boxes"]) for v in targets]
+         indices = [
+             linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))
+         ]
+         return [
+             (
+                 torch.as_tensor(i, dtype=torch.int64),
+                 torch.as_tensor(j, dtype=torch.int64),
+             )
+             for i, j in indices
+         ]
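Editorial sketch, not part of the package diff: a toy call to HungarianMatcher on its own, with shapes following the docstring above; the number of classes, queries, and targets is an arbitrary assumption.

import torch

from rslearn.models.detr.matcher import HungarianMatcher

matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)

outputs = {
    # [batch_size, num_queries, num_classes + 1] classification logits
    "pred_logits": torch.randn(1, 100, 4),
    # [batch_size, num_queries, 4] boxes in normalized (cx, cy, w, h)
    "pred_boxes": torch.rand(1, 100, 4),
}
targets = [{"labels": torch.tensor([0, 2]), "boxes": torch.rand(2, 4)}]

indices = matcher(outputs, targets)
pred_idx, tgt_idx = indices[0]
# pred_idx and tgt_idx are int64 tensors of length min(num_queries, num_target_boxes) == 2:
# the two queries assigned to the two ground-truth boxes at minimum total cost.

Downstream, SetCriterion uses these (pred_idx, tgt_idx) pairs to decide which query predictions are supervised against which ground-truth boxes; unmatched queries are trained toward the no-object class.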