rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/detr/detr.py (new file)
@@ -0,0 +1,504 @@
"""DETR DEtection TRansformer decoder for object detection tasks.

Most of the modules here are adapted from here:
https://github.com/facebookresearch/detr/blob/29901c51d7fe8712168b8d0d64351170bc0f83e0/models/detr.py#L258

The original code is:
Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""

from typing import Any

import torch
import torch.nn.functional as F
from torch import nn

import rslearn.models.detr.box_ops as box_ops
from rslearn.models.component import FeatureMaps, Predictor
from rslearn.train.model_context import ModelContext, ModelOutput

from .matcher import HungarianMatcher
from .position_encoding import PositionEmbeddingSine
from .transformer import Transformer
from .util import accuracy

DEFAULT_WEIGHT_DICT: dict[str, float] = {
    "loss_ce": 1,
    "loss_bbox": 5,
    "loss_giou": 2,
}


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
    ):
        """Create a new MLP.

        Args:
            input_dim: input dimension.
            hidden_dim: hidden dimension.
            output_dim: output dimension.
            num_layers: number of layers in this MLP.
        """
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the MLP."""
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class DetrPredictor(nn.Module):
    """DETR prediction module.

    This is DETR up to and excluding computing the loss.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        num_queries: int = 100,
        transformer: Transformer = Transformer(),
        aux_loss: bool = False,
    ):
        """Initializes the model.

        Args:
            in_channels: number of channels in features computed by the backbone.
            num_classes: number of object classes
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                DETR can detect in a single image. For COCO, we recommend 100 queries.
            transformer: torch module of the transformer architecture. See transformer.py
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.input_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
        self.aux_loss = aux_loss

    def forward(
        self, feat_map: torch.Tensor, pos_embedding: torch.Tensor
    ) -> dict[str, torch.Tensor]:
        """Compute the detection outputs.

        Args:
            feat_map: the input feature map.
            pos_embedding: positional embedding.

        Returns:
            output dict containing predicted boxes, classification logits, and
                aux_outputs (if aux_loss is enabled).
        """
        hs = self.transformer(
            src=self.input_proj(feat_map),
            query_embed=self.query_embed.weight,
            pos_embed=pos_embedding,
        )[0]

        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
        if self.aux_loss:
            out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord)
        return out

    @torch.jit.unused
    def _set_aux_loss(
        self, outputs_class: torch.Tensor, outputs_coord: torch.Tensor
    ) -> list[dict[str, torch.Tensor]]:
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [
            {"pred_logits": a, "pred_boxes": b}
            for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
        ]


class SetCriterion(nn.Module):
    """SetCriterion computes the loss for DETR.

    The process happens in two steps:
        (1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        (2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(
        self,
        num_classes: int,
        matcher: HungarianMatcher = HungarianMatcher(),
        weight_dict: dict[str, float] = DEFAULT_WEIGHT_DICT,
        eos_coef: float = 0.1,
        losses: list[str] = ["labels", "boxes", "cardinality"],
    ):
        """Create a SetCriterion.

        Args:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)

    def loss_labels(
        self,
        outputs: dict[str, torch.Tensor],
        targets: list[dict[str, torch.Tensor]],
        indices: list[tuple[torch.Tensor, torch.Tensor]],
        num_boxes: int,
        log: bool = True,
    ) -> dict[str, torch.Tensor]:
        """Compute classification loss (NLL).

        Args:
            outputs: the outputs from the model.
            targets: target dicts, which must contain the key "labels" containing a tensor of dim [nb_target_boxes].
            indices: the matching indices between outputs and targets.
            num_boxes: number of boxes, ignored.
            log: whether to add additional metrics to the loss dict for logging.

        Returns:
            loss dict, mapping from loss name to value. The actual loss is stored under
                loss_ce.
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat(
            [t["labels"][J] for t, (_, J) in zip(targets, indices)]
        )
        target_classes = torch.full(
            src_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=src_logits.device,
        )
        target_classes[idx] = target_classes_o

        loss_ce = F.cross_entropy(
            src_logits.transpose(1, 2), target_classes, self.empty_weight
        )
        losses = {"loss_ce": loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses["class_error"] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(
        self,
        outputs: dict[str, torch.Tensor],
        targets: list[dict[str, torch.Tensor]],
        indices: list[tuple[torch.Tensor, torch.Tensor]],
        num_boxes: int,
    ) -> dict[str, torch.Tensor]:
        """Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes.

        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
        """
        pred_logits = outputs["pred_logits"]
        device = pred_logits.device
        tgt_lengths = torch.as_tensor(
            [len(v["labels"]) for v in targets], device=device
        )
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {"cardinality_error": card_err}
        return losses

    def loss_boxes(
        self,
        outputs: dict[str, torch.Tensor],
        targets: list[dict[str, torch.Tensor]],
        indices: list[tuple[torch.Tensor, torch.Tensor]],
        num_boxes: int,
    ) -> dict[str, torch.Tensor]:
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.

        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert "pred_boxes" in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat(
            [t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0
        )

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")

        losses = {}
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(
            box_ops.generalized_box_iou(
                box_ops.box_cxcywh_to_xyxy(src_boxes),
                box_ops.box_cxcywh_to_xyxy(target_boxes),
            )
        )
        losses["loss_giou"] = loss_giou.sum() / num_boxes
        return losses

    def _get_src_permutation_idx(
        self, indices: list[tuple[torch.Tensor, torch.Tensor]]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # permute predictions following indices
        batch_idx = torch.cat(
            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]
        )
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(
        self, indices: list[tuple[torch.Tensor, torch.Tensor]]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # permute targets following indices
        batch_idx = torch.cat(
            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]
        )
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(
        self,
        loss: str,
        outputs: dict[str, torch.Tensor],
        targets: list[dict[str, torch.Tensor]],
        indices: list[tuple[torch.Tensor, torch.Tensor]],
        num_boxes: int,
        **kwargs: Any,
    ) -> dict[str, torch.Tensor]:
        """Compute the specified loss.

        Args:
            loss: the name of the loss to compute.
            outputs: the outputs from the model.
            targets: the targets.
            indices: the corresponding output/target indices from the matcher.
            num_boxes: the number of target boxes.
            kwargs: additional arguments to pass to the loss function.

        Returns:
            the loss dict.
        """
        loss_map = {
            "labels": self.loss_labels,
            "cardinality": self.loss_cardinality,
            "boxes": self.loss_boxes,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def forward(
        self, outputs: dict[str, Any], targets: list[dict[str, torch.Tensor]]
    ) -> dict[str, torch.Tensor]:
        """This performs the loss computation.

        Args:
            outputs: dict of tensors, see the output specification of the model for the format
            targets: list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes])
        num_boxes = torch.clamp(num_boxes, min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == "masks":
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    kwargs = {}
                    if loss == "labels":
                        # Logging is enabled only for the last layer
                        kwargs = {"log": False}
                    l_dict = self.get_loss(
                        loss, aux_outputs, targets, indices, num_boxes, **kwargs
                    )
                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
                    losses.update(l_dict)

        # Apply weights.
        # We only keep the ones present in weight dict, since there may be others that
        # are only produced for logging purposes (not that we're logging them).
        final_losses = {
            k: loss * self.weight_dict[k]
            for k, loss in losses.items()
            if k in self.weight_dict
        }
        return final_losses


class PostProcess(nn.Module):
    """PostProcess converts the model output into the COCO format used by rslearn."""

    @torch.no_grad()
    def forward(
        self, outputs: dict[str, torch.Tensor], target_sizes: torch.Tensor
    ) -> list[dict[str, torch.Tensor]]:
        """Forward pass for PostProcess to perform the output format conversion.

        Args:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch.
                For evaluation, this must be the original image size (before any data augmentation).
                For visualization, this should be the image size after data augment, but before padding.
        """
        out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        prob = F.softmax(out_logits, -1)
        scores, labels = prob[..., :-1].max(-1)

        # convert to [x0, y0, x1, y1] format
        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        results = [
            {"scores": cur_scores, "labels": cur_labels, "boxes": cur_boxes}
            for cur_scores, cur_labels, cur_boxes in zip(scores, labels, boxes)
        ]

        return results


class Detr(Predictor):
    """DETR prediction module.

    This combines PositionEmbeddingSine, DetrPredictor, SetCriterion, and PostProcess.

    This is the module that should be used as a decoder component in rslearn.
    """

    def __init__(self, predictor: DetrPredictor, criterion: SetCriterion):
        """Create a Detr.

        Args:
            predictor: the DetrPredictor.
            criterion: the SetCriterion.
        """
        super().__init__()
        self.predictor = predictor
        self.criterion = criterion
        self.pos_embedding = PositionEmbeddingSine(
            num_pos_feats=predictor.transformer.d_model // 2, normalize=True
        )
        self.postprocess = PostProcess()

        if predictor.aux_loss:
            # Hack to make sure it's included in the weight dict for the criterion.
            aux_weight_dict = {}
            num_dec_layers = len(predictor.transformer.decoder.layers)
            for i in range(num_dec_layers - 1):
                aux_weight_dict.update(
                    {f"{k}_{i}": v for k, v in self.criterion.weight_dict.items()}
                )
            self.criterion.weight_dict.update(aux_weight_dict)

    def forward(
        self,
        intermediates: Any,
        context: ModelContext,
        targets: list[dict[str, Any]] | None = None,
    ) -> ModelOutput:
        """Compute the detection outputs and loss from features.

        DETR will use only the last feature map, which should correspond to the lowest
        resolution one.

        Args:
            intermediates: the output from the previous component. It must be a FeatureMaps.
            context: the model context. Input dicts must contain an "image" key which will
                be used to establish the original image size.
            targets: must contain class key that stores the class label.

        Returns:
            the model output.
        """
        if not isinstance(intermediates, FeatureMaps):
            raise ValueError("input to Detr must be a FeatureMaps")

        # We only use the last feature map (most fine-grained).
        features = intermediates.feature_maps[-1]

        # Get image sizes.
        image_sizes = torch.tensor(
            [[inp["image"].shape[2], inp["image"].shape[1]] for inp in context.inputs],
            dtype=torch.int32,
            device=features.device,
        )

        pos_embedding = self.pos_embedding(features)
        outputs = self.predictor(features, pos_embedding)

        if targets is not None:
            # Convert boxes from [x0, y0, x1, y1] to [cx, cy, w, h].
            converted_targets = []
            for target, image_size in zip(targets, image_sizes):
                boxes = target["boxes"]
                img_w, img_h = image_size
                scale_fct = torch.stack([img_w, img_h, img_w, img_h])
                boxes = boxes / scale_fct
                boxes = box_ops.box_xyxy_to_cxcywh(boxes)
                converted_targets.append(
                    {
                        "boxes": boxes,
                        "labels": target["labels"],
                    }
                )

            losses = self.criterion(outputs, converted_targets)
        else:
            losses = {}

        results = self.postprocess(outputs, image_sizes)

        return ModelOutput(
            outputs=results,
            loss_dict=losses,
        )
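For orientation, here is a minimal usage sketch, not part of the diff, of how the classes added in this file might be wired together as an rslearn decoder head. The constructor signatures follow the code above; the in_channels and num_classes values are made up, and the Transformer and PositionEmbeddingSine defaults come from rslearn/models/detr/transformer.py and position_encoding.py, which are added in this release but not reproduced here.

# Hypothetical wiring sketch; in_channels=256 and num_classes=3 are illustrative only.
from rslearn.models.detr.detr import Detr, DetrPredictor, SetCriterion

predictor = DetrPredictor(in_channels=256, num_classes=3, num_queries=100, aux_loss=True)
criterion = SetCriterion(num_classes=3)
decoder = Detr(predictor=predictor, criterion=criterion)

# At train time Detr.forward(intermediates, context, targets) expects the backbone's
# FeatureMaps, a ModelContext whose input dicts carry an "image" key, and per-image
# target dicts with absolute [x0, y0, x1, y1] "boxes" plus "labels"; it returns a
# ModelOutput holding COCO-format detections and the weighted loss dict from SetCriterion.

Note that with aux_loss=True, Detr's constructor copies each entry of criterion.weight_dict into per-decoder-layer keys (loss_ce_0, loss_bbox_0, and so on), so the auxiliary losses produced by SetCriterion.forward are weighted the same way as the final-layer losses.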
rslearn/models/detr/matcher.py (new file)
@@ -0,0 +1,107 @@
"""Modules to compute the matching cost and solve the corresponding LSAP.

This is copied from https://github.com/facebookresearch/detr/.
The original code is:
Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""

import torch
from scipy.optimize import linear_sum_assignment
from torch import nn

from .box_ops import box_cxcywh_to_xyxy, generalized_box_iou


class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network.

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(
        self, cost_class: float = 1, cost_bbox: float = 5, cost_giou: float = 2
    ):
        """Creates the matcher.

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, (
            "all costs cant be 0"
        )

    @torch.no_grad()
    def forward(
        self, outputs: dict[str, torch.Tensor], targets: list[dict[str, torch.Tensor]]
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """Performs the matching.

        Params:
            outputs: This is a dict that contains at least these entries:
                "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                    objects in the target) containing the class labels
                "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = (
            outputs["pred_logits"].flatten(0, 1).softmax(-1)
        )  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, it can be ommitted.
        cost_class = -out_prob[:, tgt_ids]

        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # Compute the giou cost betwen boxes
        cost_giou = -generalized_box_iou(
            box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
        )

        # Final cost matrix
        C = (
            self.cost_bbox * cost_bbox
            + self.cost_class * cost_class
            + self.cost_giou * cost_giou
        )
        C = C.view(bs, num_queries, -1).cpu()

        sizes = [len(v["boxes"]) for v in targets]
        indices = [
            linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))
        ]
        return [
            (
                torch.as_tensor(i, dtype=torch.int64),
                torch.as_tensor(j, dtype=torch.int64),
            )
            for i, j in indices
        ]