rslearn 0.0.16__py3-none-any.whl → 0.0.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -0
- rslearn/config/dataset.py +55 -4
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +9 -65
- rslearn/dataset/materialize.py +5 -5
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +26 -80
- rslearn/lightning_cli.py +10 -3
- rslearn/main.py +11 -36
- rslearn/models/anysat.py +11 -9
- rslearn/models/clay/clay.py +8 -9
- rslearn/models/clip.py +18 -15
- rslearn/models/component.py +99 -0
- rslearn/models/concatenate_features.py +21 -11
- rslearn/models/conv.py +15 -8
- rslearn/models/croma.py +13 -8
- rslearn/models/detr/detr.py +25 -14
- rslearn/models/dinov3.py +11 -6
- rslearn/models/faster_rcnn.py +19 -9
- rslearn/models/feature_center_crop.py +12 -9
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/galileo.py +23 -18
- rslearn/models/module_wrapper.py +26 -57
- rslearn/models/molmo.py +16 -14
- rslearn/models/multitask.py +102 -73
- rslearn/models/olmoearth_pretrain/model.py +20 -17
- rslearn/models/panopticon.py +8 -7
- rslearn/models/pick_features.py +18 -24
- rslearn/models/pooling_decoder.py +22 -14
- rslearn/models/presto/presto.py +16 -10
- rslearn/models/presto/single_file_presto.py +4 -10
- rslearn/models/prithvi.py +12 -8
- rslearn/models/resize_features.py +21 -7
- rslearn/models/sam2_enc.py +11 -9
- rslearn/models/satlaspretrain.py +15 -9
- rslearn/models/simple_time_series.py +31 -17
- rslearn/models/singletask.py +24 -17
- rslearn/models/ssl4eo_s12.py +15 -10
- rslearn/models/swin.py +22 -13
- rslearn/models/terramind.py +24 -7
- rslearn/models/trunk.py +6 -3
- rslearn/models/unet.py +18 -9
- rslearn/models/upsample.py +22 -9
- rslearn/train/all_patches_dataset.py +22 -18
- rslearn/train/dataset.py +69 -54
- rslearn/train/lightning_module.py +51 -32
- rslearn/train/model_context.py +54 -0
- rslearn/train/prediction_writer.py +111 -41
- rslearn/train/tasks/classification.py +34 -15
- rslearn/train/tasks/detection.py +24 -31
- rslearn/train/tasks/embedding.py +33 -29
- rslearn/train/tasks/multi_task.py +7 -7
- rslearn/train/tasks/per_pixel_regression.py +41 -19
- rslearn/train/tasks/regression.py +38 -21
- rslearn/train/tasks/segmentation.py +33 -15
- rslearn/train/tasks/task.py +3 -2
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/METADATA +58 -25
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/RECORD +65 -62
- rslearn/dataset/index.py +0 -173
- rslearn/models/registry.py +0 -22
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/WHEEL +0 -0
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/top_level.txt +0 -0
rslearn/models/conv.py
CHANGED

```diff
@@ -4,8 +4,12 @@ from typing import Any
 
 import torch
 
+from rslearn.train.model_context import ModelContext
 
-class Conv(torch.nn.Module):
+from .component import FeatureMaps, IntermediateComponent
+
+
+class Conv(IntermediateComponent):
     """A single convolutional layer.
 
     It inputs a set of feature maps; the conv layer is applied to each feature map
@@ -38,19 +42,22 @@ class Conv(torch.nn.Module):
         )
         self.activation = activation
 
-    def forward(self,
-        """
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Apply conv layer on each feature map.
 
         Args:
-
-
+            intermediates: the previous output, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
-
+            the resulting feature maps after applying the same Conv2d on each one.
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to Conv must be FeatureMaps")
+
         new_features = []
-        for feat_map in
+        for feat_map in intermediates.feature_maps:
            feat_map = self.layer(feat_map)
            feat_map = self.activation(feat_map)
            new_features.append(feat_map)
-        return new_features
+        return FeatureMaps(new_features)
```
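The Conv changes above show the component contract introduced in this release: an intermediate component receives the previous component's output plus a `ModelContext` and returns a `FeatureMaps`. Below is a minimal, self-contained sketch of that contract using stand-in dataclasses; the real classes live in `rslearn.models.component` and `rslearn.train.model_context`, whose full definitions are not part of this diff.

```python
from dataclasses import dataclass
from typing import Any

import torch


@dataclass
class FeatureMaps:
    """Stand-in for rslearn.models.component.FeatureMaps: a list of BxCxHxW tensors."""

    feature_maps: list[torch.Tensor]


@dataclass
class ModelContext:
    """Stand-in for rslearn.train.model_context.ModelContext; only .inputs is used here."""

    inputs: list[dict[str, Any]]


class ScaleFeatures(torch.nn.Module):
    """Toy intermediate component with the same forward signature as Conv above."""

    def __init__(self, factor: float = 2.0):
        super().__init__()
        self.factor = factor

    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
        if not isinstance(intermediates, FeatureMaps):
            raise ValueError("input must be FeatureMaps")
        # Apply the same elementwise operation to every feature map, as Conv does.
        return FeatureMaps([feat * self.factor for feat in intermediates.feature_maps])


if __name__ == "__main__":
    ctx = ModelContext(inputs=[{"image": torch.zeros(3, 8, 8)}])
    feats = FeatureMaps([torch.ones(1, 3, 8, 8)])
    out = ScaleFeatures()(feats, ctx)
    print(out.feature_maps[0].mean())  # tensor(2.)
```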
rslearn/models/croma.py
CHANGED

```diff
@@ -12,9 +12,11 @@ from einops import rearrange
 from upath import UPath
 
 from rslearn.log_utils import get_logger
+from rslearn.train.model_context import ModelContext
 from rslearn.train.transforms.transform import Transform
 from rslearn.utils.fsspec import open_atomic
 
+from .component import FeatureExtractor, FeatureMaps
 from .use_croma import PretrainedCROMA
 
 logger = get_logger(__name__)
@@ -76,7 +78,7 @@ MODALITY_BANDS = {
 }
 
 
-class Croma(torch.nn.Module):
+class Croma(FeatureExtractor):
     """CROMA backbones.
 
     There are two model sizes, base and large.
@@ -160,20 +162,23 @@ class Croma(torch.nn.Module):
            align_corners=False,
        )
 
-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Croma backbone.
 
-
-
-                "sentinel1" keys depending on the configured modality.
+        Args:
+            context: the model context. Input dicts must include either/both of
+                "sentinel2" or "sentinel1" keys depending on the configured modality.
+
+        Returns:
+            a FeatureMaps with one feature map at 1/8 the input resolution.
         """
         sentinel1: torch.Tensor | None = None
         sentinel2: torch.Tensor | None = None
         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL1]:
-            sentinel1 = torch.stack([inp["sentinel1"] for inp in inputs], dim=0)
+            sentinel1 = torch.stack([inp["sentinel1"] for inp in context.inputs], dim=0)
             sentinel1 = self._resize_image(sentinel1) if self.do_resizing else sentinel1
         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL2]:
-            sentinel2 = torch.stack([inp["sentinel2"] for inp in inputs], dim=0)
+            sentinel2 = torch.stack([inp["sentinel2"] for inp in context.inputs], dim=0)
             sentinel2 = self._resize_image(sentinel2) if self.do_resizing else sentinel2
 
         outputs = self.model(
@@ -200,7 +205,7 @@ class Croma(torch.nn.Module):
            w=num_patches_per_dim,
        )
 
-        return [features]
+        return FeatureMaps([features])
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
```
rslearn/models/detr/detr.py
CHANGED

```diff
@@ -13,6 +13,8 @@ import torch.nn.functional as F
 from torch import nn
 
 import rslearn.models.detr.box_ops as box_ops
+from rslearn.models.component import FeatureMaps, Predictor
+from rslearn.train.model_context import ModelContext, ModelOutput
 
 from .matcher import HungarianMatcher
 from .position_encoding import PositionEmbeddingSine
@@ -405,7 +407,7 @@ class PostProcess(nn.Module):
         return results
 
 
-class Detr(nn.Module):
+class Detr(Predictor):
     """DETR prediction module.
 
     This combines PositionEmbeddingSine, DetrPredictor, SetCriterion, and PostProcess.
@@ -440,33 +442,39 @@ class Detr(nn.Module):
 
     def forward(
         self,
-
-
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) ->
+    ) -> ModelOutput:
         """Compute the detection outputs and loss from features.
 
         DETR will use only the last feature map, which should correspond to the lowest
         resolution one.
 
         Args:
-
-
-
+            intermediates: the output from the previous component. It must be a FeatureMaps.
+            context: the model context. Input dicts must contain an "image" key which we will
+                be used to establish the original image size.
+            targets: must contain class key that stores the class label.
 
         Returns:
-
+            the model output.
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to Detr must be a FeatureMaps")
+
+        # We only use the last feature map (most fine-grained).
+        features = intermediates.feature_maps[-1]
+
         # Get image sizes.
         image_sizes = torch.tensor(
-            [[inp["image"].shape[2], inp["image"].shape[1]] for inp in inputs],
+            [[inp["image"].shape[2], inp["image"].shape[1]] for inp in context.inputs],
             dtype=torch.int32,
-            device=features
+            device=features.device,
         )
 
-
-
-        outputs = self.predictor(feat_map, pos_embedding)
+        pos_embedding = self.pos_embedding(features)
+        outputs = self.predictor(features, pos_embedding)
 
         if targets is not None:
             # Convert boxes from [x0, y0, x1, y1] to [cx, cy, w, h].
@@ -490,4 +498,7 @@ class Detr(nn.Module):
 
         results = self.postprocess(outputs, image_sizes)
 
-        return
+        return ModelOutput(
+            outputs=results,
+            loss_dict=losses,
+        )
```
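Detr (and FasterRCNN below) now subclass `Predictor`: `forward(intermediates, context, targets)` returns a `ModelOutput` bundling per-example outputs with a loss dict, rather than a bare tuple. A rough sketch of that return contract follows, using a stand-in `ModelOutput` and a toy head since the real dataclass in `rslearn.train.model_context` may carry fields not visible in this diff.

```python
from dataclasses import dataclass, field
from typing import Any

import torch


@dataclass
class ModelOutput:
    """Stand-in mirroring the ModelOutput(outputs=..., loss_dict=...) calls in the diff."""

    outputs: list[Any]
    loss_dict: dict[str, torch.Tensor] = field(default_factory=dict)


def toy_head(
    features: torch.Tensor, targets: list[dict[str, Any]] | None = None
) -> ModelOutput:
    """Toy predictor head: one score per example, plus an MSE loss when targets exist."""
    scores = features.mean(dim=(1, 2, 3))  # (B,)
    losses: dict[str, torch.Tensor] = {}
    if targets is not None:
        labels = torch.stack([t["label"] for t in targets])
        losses["mse"] = torch.nn.functional.mse_loss(scores, labels)
    return ModelOutput(outputs=[{"score": s} for s in scores], loss_dict=losses)


if __name__ == "__main__":
    out = toy_head(
        torch.randn(2, 64, 8, 8),
        targets=[{"label": torch.tensor(0.5)}, {"label": torch.tensor(1.0)}],
    )
    print(len(out.outputs), out.loss_dict["mse"])
```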
rslearn/models/dinov3.py
CHANGED

```diff
@@ -13,9 +13,12 @@ import torch
 import torchvision
 from einops import rearrange
 
+from rslearn.train.model_context import ModelContext
 from rslearn.train.transforms.normalize import Normalize
 from rslearn.train.transforms.transform import Transform
 
+from .component import FeatureExtractor, FeatureMaps
+
 
 class DinoV3Models(StrEnum):
     """Names for different DinoV3 images on torch hub."""
@@ -40,7 +43,7 @@ DINOV3_PTHS: dict[str, str] = {
 }
 
 
-class DinoV3(
+class DinoV3(FeatureExtractor):
     """DinoV3 Backbones.
 
     Must have the pretrained weights downloaded in checkpoint_dir for them to be loaded.
@@ -91,16 +94,18 @@ class DinoV3(torch.nn.Module):
         self.do_resizing = do_resizing
         self.model = self._load_model(size, checkpoint_dir)
 
-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Forward pass for the dinov3 model.
 
         Args:
-
+            context: the model context. Input dicts must include "image" key.
 
         Returns:
-
+            a FeatureMaps with one feature map.
         """
-        cur = torch.stack(
+        cur = torch.stack(
+            [inp["image"] for inp in context.inputs], dim=0
+        )  # (B, C, H, W)
 
         if self.do_resizing and (
             cur.shape[2] != self.image_size or cur.shape[3] != self.image_size
@@ -118,7 +123,7 @@ class DinoV3(torch.nn.Module):
         height, width = int(num_patches**0.5), int(num_patches**0.5)
         features = rearrange(features, "b (h w) d -> b d h w", h=height, w=width)
 
-        return [features]
+        return FeatureMaps([features])
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
```
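Backbones such as DinoV3, Croma, and Galileo now implement the `FeatureExtractor` role: instead of receiving a pre-stacked batch, they read the per-example input dicts from `context.inputs` and return a `FeatureMaps`. A minimal stand-in sketch of that stacking pattern is below (the real `ModelContext` likely exposes more than just `inputs`; the backbone here is a toy conv layer, not any rslearn model).

```python
from dataclasses import dataclass
from typing import Any

import torch


@dataclass
class FeatureMaps:
    feature_maps: list[torch.Tensor]


@dataclass
class ModelContext:
    inputs: list[dict[str, Any]]


class ToyBackbone(torch.nn.Module):
    """Toy feature extractor mirroring the forward(context) -> FeatureMaps pattern above."""

    def __init__(self, in_channels: int = 3, out_channels: int = 16):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1)

    def forward(self, context: ModelContext) -> FeatureMaps:
        # Stack the per-example "image" tensors into a single (B, C, H, W) batch.
        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
        return FeatureMaps([self.conv(images)])


if __name__ == "__main__":
    ctx = ModelContext(inputs=[{"image": torch.randn(3, 32, 32)} for _ in range(2)])
    print(ToyBackbone()(ctx).feature_maps[0].shape)  # torch.Size([2, 16, 16, 16])
```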
rslearn/models/faster_rcnn.py
CHANGED

```diff
@@ -6,6 +6,10 @@ from typing import Any
 import torch
 import torchvision
 
+from rslearn.train.model_context import ModelContext, ModelOutput
+
+from .component import FeatureMaps, Predictor
+
 
 class NoopTransform(torch.nn.Module):
     """A placeholder transform used with torchvision detection model."""
@@ -55,7 +59,7 @@ class NoopTransform(torch.nn.Module):
         return image_list, targets
 
 
-class FasterRCNN(
+class FasterRCNN(Predictor):
     """Faster R-CNN head for predicting bounding boxes.
 
     It inputs multi-scale features, using each feature map to predict ROIs and then
@@ -176,20 +180,23 @@ class FasterRCNN(torch.nn.Module):
 
     def forward(
         self,
-
-
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) ->
+    ) -> ModelOutput:
         """Compute the detection outputs and loss from features.
 
         Args:
-
-
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context. Input dicts must contain image key for original image size.
             targets: should contain class key that stores the class label.
 
         Returns:
             tuple of outputs and loss dict
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to FasterRCNN must be FeatureMaps")
+
         # Fix target labels to be 1 size in case it's empty.
         # For some reason this is needed.
         if targets:
@@ -203,11 +210,11 @@ class FasterRCNN(torch.nn.Module):
            ),
        )
 
-        image_list = [inp["image"] for inp in inputs]
+        image_list = [inp["image"] for inp in context.inputs]
         images, targets = self.noop_transform(image_list, targets)
 
         feature_dict = collections.OrderedDict()
-        for i, feat_map in enumerate(
+        for i, feat_map in enumerate(intermediates.feature_maps):
             feature_dict[f"feat{i}"] = feat_map
 
         proposals, proposal_losses = self.rpn(images, feature_dict, targets)
@@ -219,4 +226,7 @@
         losses.update(proposal_losses)
         losses.update(detector_losses)
 
-        return
+        return ModelOutput(
+            outputs=detections,
+            loss_dict=losses,
+        )
```

rslearn/models/feature_center_crop.py
CHANGED

```diff
@@ -2,10 +2,12 @@
 
 from typing import Any
 
-import torch
+from rslearn.train.model_context import ModelContext
 
+from .component import FeatureMaps, IntermediateComponent
 
-class FeatureCenterCrop(torch.nn.Module):
+
+class FeatureCenterCrop(IntermediateComponent):
     """Apply center cropping on the input feature maps."""
 
     def __init__(
@@ -24,20 +26,21 @@ class FeatureCenterCrop(torch.nn.Module):
         super().__init__()
         self.sizes = sizes
 
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
-    ) -> list[torch.Tensor]:
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
         """Apply center cropping on the feature maps.
 
         Args:
-
-
+            intermediates: output from the previous model component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
             center cropped feature maps.
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to FeatureCenterCrop must be FeatureMaps")
+
         new_features = []
-        for i, feat in enumerate(
+        for i, feat in enumerate(intermediates.feature_maps):
             height, width = self.sizes[i]
             if feat.shape[2] < height or feat.shape[3] < width:
                 raise ValueError(
@@ -47,4 +50,4 @@ class FeatureCenterCrop(torch.nn.Module):
             start_w = feat.shape[3] // 2 - width // 2
             feat = feat[:, :, start_h : start_h + height, start_w : start_w + width]
             new_features.append(feat)
-        return new_features
+        return FeatureMaps(new_features)
```
rslearn/models/fpn.py
CHANGED

```diff
@@ -1,12 +1,16 @@
 """Feature pyramid network."""
 
 import collections
+from typing import Any
 
-import torch
 import torchvision
 
+from rslearn.train.model_context import ModelContext
 
-class Fpn(torch.nn.Module):
+from .component import FeatureMaps, IntermediateComponent
+
+
+class Fpn(IntermediateComponent):
     """A feature pyramid network (FPN).
 
     The FPN inputs a multi-scale feature map. At each scale, it computes new features
@@ -32,20 +36,27 @@ class Fpn(torch.nn.Module):
            in_channels_list=in_channels, out_channels=out_channels
        )
 
-    def forward(self,
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
         """Compute outputs of the FPN.
 
         Args:
-
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
-            new multi-scale feature maps from the FPN
+            new multi-scale feature maps from the FPN.
         """
-
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to Fpn must be FeatureMaps")
+
+        feature_maps = intermediates.feature_maps
+        inp = collections.OrderedDict(
+            [(f"feat{i}", el) for i, el in enumerate(feature_maps)]
+        )
         output = self.fpn(inp)
         output = list(output.values())
 
         if self.prepend:
-            return output +
+            return FeatureMaps(output + feature_maps)
         else:
-            return output
+            return FeatureMaps(output)
```

rslearn/models/galileo/galileo.py
CHANGED

```diff
@@ -4,16 +4,16 @@ import math
 import tempfile
 from contextlib import nullcontext
 from enum import StrEnum
-from typing import
+from typing import cast
 
 import numpy as np
 import torch
-import torch.nn as nn
 from einops import rearrange, repeat
 from huggingface_hub import hf_hub_download
 from upath import UPath
 
 from rslearn.log_utils import get_logger
+from rslearn.models.component import FeatureExtractor, FeatureMaps
 from rslearn.models.galileo.single_file_galileo import (
     CONFIG_FILENAME,
     DW_BANDS,
@@ -39,6 +39,7 @@ from rslearn.models.galileo.single_file_galileo import (
     MaskedOutput,
     Normalizer,
 )
+from rslearn.train.model_context import ModelContext
 
 logger = get_logger(__name__)
 
@@ -70,7 +71,7 @@ AUTOCAST_DTYPE_MAP = {
 }
 
 
-class GalileoModel(
+class GalileoModel(FeatureExtractor):
     """Galileo backbones."""
 
     input_keys = [
@@ -410,11 +411,11 @@ class GalileoModel(nn.Module):
            months=months,
        )
 
-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Galileo backbone.
 
-
-
+        Args:
+            context: the model context. Input dicts should contain keys corresponding to Galileo.input_keys
             (also documented below) and values are tensors of the following shapes,
             per input key:
             "s1": B (T * C) H W
@@ -436,10 +437,12 @@ class GalileoModel(nn.Module):
         take a pool of the space_time unmasked tokens (i.e. of the s1 and s2 tokens).
         """
         stacked_inputs = {}
-        for key in inputs[0].keys():
+        for key in context.inputs[0].keys():
             # assume all the keys in an input are consistent
             if key in self.input_keys:
-                stacked_inputs[key] = torch.stack(
+                stacked_inputs[key] = torch.stack(
+                    [inp[key] for inp in context.inputs], dim=0
+                )
         s_t_channels = []
         for space_time_modality in ["s1", "s2"]:
             if space_time_modality not in stacked_inputs:
@@ -502,14 +505,14 @@ class GalileoModel(nn.Module):
         # Decide context based on self.autocast_dtype.
         device = galileo_input.s_t_x.device
         if self.autocast_dtype is None:
-
+            torch_context = nullcontext()
         else:
             assert device is not None
-
+            torch_context = torch.amp.autocast(
                 device_type=device.type, dtype=self.autocast_dtype
             )
 
-        with
+        with torch_context:
             outputs = self.model(
                 s_t_x=galileo_input.s_t_x,
                 s_t_m=galileo_input.s_t_m,
@@ -530,18 +533,20 @@ class GalileoModel(nn.Module):
            averaged = self.model.average_tokens(
                s_t_x, sp_x, t_x, st_x, s_t_m, sp_m, t_m, st_m
            )
-            return [repeat(averaged, "b d -> b d 1 1")]
+            return FeatureMaps([repeat(averaged, "b d -> b d 1 1")])
         else:
             s_t_x = outputs[0]
             # we will be assuming we only want s_t_x, and (for now) that we want s1 or s2 bands
             # s_t_x has shape [b, h, w, t, c_g, d]
             # and we want [b, d, h, w]
-            return
-
-
-
-
-
+            return FeatureMaps(
+                [
+                    rearrange(
+                        s_t_x[:, :, :, :, s_t_channels, :].mean(dim=3),
+                        "b h w c_g d -> b c_g d h w",
+                    ).mean(dim=1)
+                ]
+            )
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
```
rslearn/models/module_wrapper.py
CHANGED

```diff
@@ -1,67 +1,35 @@
-"""Module
+"""Module wrapper provided for backwards compatibility."""
 
 from typing import Any
 
 import torch
 
+from rslearn.train.model_context import ModelContext
 
-
-
+from .component import (
+    FeatureExtractor,
+    FeatureMaps,
+    IntermediateComponent,
+)
 
-    The module should input feature map and produce a new feature map.
 
-
-
-    """
-
-    def __init__(
-        self,
-        module: torch.nn.Module,
-    ):
-        """Initialize a DecoderModuleWrapper.
+class EncoderModuleWrapper(FeatureExtractor):
+    """Wraps one or more IntermediateComponents to function as the feature extractor.
 
-
-
-        """
-        super().__init__()
-        self.module = module
-
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[torch.Tensor]
-    ) -> list[torch.Tensor]:
-        """Apply the wrapped module on each feature map.
-
-        Args:
-            features: list of feature maps at different resolutions.
-            inputs: original inputs (ignored).
-
-        Returns:
-            new features
-        """
-        new_features = []
-        for feat_map in features:
-            feat_map = self.module(feat_map)
-            new_features.append(feat_map)
-        return new_features
-
-
-class EncoderModuleWrapper(torch.nn.Module):
-    """Wraps a module that is intended to be used as the decoder to work in encoder.
-
-    The module should input a feature map that corresponds to the original image, i.e.
-    the depth of the feature map would be the number of bands in the input image.
+    The first component should input a FeatureMaps, which will be computed from the
+    overall inputs by stacking the "image" key from each input dict.
     """
 
     def __init__(
         self,
-        module:
-        modules: list[
+        module: IntermediateComponent | None = None,
+        modules: list[IntermediateComponent] = [],
     ):
         """Initialize an EncoderModuleWrapper.
 
         Args:
-            module: the
-                must be set.
+            module: the IntermediateComponent to wrap for use as a FeatureExtractor.
+                Exactly one of module or modules must be set.
             modules: list of modules to wrap
         """
         super().__init__()
@@ -74,18 +42,19 @@ class EncoderModuleWrapper(torch.nn.Module):
         else:
             raise ValueError("one of module or modules must be set")
 
-    def forward(
-        self,
-        inputs: list[dict[str, Any]],
-    ) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> Any:
         """Compute outputs from the wrapped module.
 
-
-
-
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the image to convert to a FeatureMaps, which will be passed to the
+                first wrapped module.
+
+        Returns:
+            the output from the last wrapped module.
         """
-        images = torch.stack([inp["image"] for inp in inputs], dim=0)
-        cur = [images]
+        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+        cur: Any = FeatureMaps([images])
         for m in self.encoder_modules:
-            cur = m(cur,
+            cur = m(cur, context)
         return cur
```
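With the reworked API, `EncoderModuleWrapper` turns a chain of `IntermediateComponent`s into the feature extractor: it stacks each input dict's `"image"` into a single-entry `FeatureMaps` and threads that, along with the context, through the wrapped modules. Below is a hedged, self-contained sketch of that forward loop using stand-in classes rather than the real rslearn imports (the wrapper and component here are toys, not the library's implementations).

```python
from dataclasses import dataclass
from typing import Any

import torch


@dataclass
class FeatureMaps:
    feature_maps: list[torch.Tensor]


@dataclass
class ModelContext:
    inputs: list[dict[str, Any]]


class Doubler(torch.nn.Module):
    """Toy IntermediateComponent: doubles every feature map."""

    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
        return FeatureMaps([f * 2 for f in intermediates.feature_maps])


class WrapperSketch(torch.nn.Module):
    """Mirrors EncoderModuleWrapper.forward above: image batch -> FeatureMaps -> chained modules."""

    def __init__(self, modules: list[torch.nn.Module]):
        super().__init__()
        self.encoder_modules = torch.nn.ModuleList(modules)

    def forward(self, context: ModelContext) -> Any:
        # Stack per-example images into one batch and wrap it as the initial FeatureMaps.
        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
        cur: Any = FeatureMaps([images])
        for m in self.encoder_modules:
            cur = m(cur, context)
        return cur


if __name__ == "__main__":
    ctx = ModelContext(inputs=[{"image": torch.ones(1, 4, 4)}])
    out = WrapperSketch([Doubler(), Doubler()])(ctx)
    print(out.feature_maps[0][0, 0, 0, 0])  # tensor(4.)
```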