rslearn 0.0.16__py3-none-any.whl → 0.0.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -0
- rslearn/config/dataset.py +55 -4
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +9 -65
- rslearn/dataset/materialize.py +5 -5
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +26 -80
- rslearn/lightning_cli.py +10 -3
- rslearn/main.py +11 -36
- rslearn/models/anysat.py +11 -9
- rslearn/models/clay/clay.py +8 -9
- rslearn/models/clip.py +18 -15
- rslearn/models/component.py +99 -0
- rslearn/models/concatenate_features.py +21 -11
- rslearn/models/conv.py +15 -8
- rslearn/models/croma.py +13 -8
- rslearn/models/detr/detr.py +25 -14
- rslearn/models/dinov3.py +11 -6
- rslearn/models/faster_rcnn.py +19 -9
- rslearn/models/feature_center_crop.py +12 -9
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/galileo.py +23 -18
- rslearn/models/module_wrapper.py +26 -57
- rslearn/models/molmo.py +16 -14
- rslearn/models/multitask.py +102 -73
- rslearn/models/olmoearth_pretrain/model.py +20 -17
- rslearn/models/panopticon.py +8 -7
- rslearn/models/pick_features.py +18 -24
- rslearn/models/pooling_decoder.py +22 -14
- rslearn/models/presto/presto.py +16 -10
- rslearn/models/presto/single_file_presto.py +4 -10
- rslearn/models/prithvi.py +12 -8
- rslearn/models/resize_features.py +21 -7
- rslearn/models/sam2_enc.py +11 -9
- rslearn/models/satlaspretrain.py +15 -9
- rslearn/models/simple_time_series.py +31 -17
- rslearn/models/singletask.py +24 -17
- rslearn/models/ssl4eo_s12.py +15 -10
- rslearn/models/swin.py +22 -13
- rslearn/models/terramind.py +24 -7
- rslearn/models/trunk.py +6 -3
- rslearn/models/unet.py +18 -9
- rslearn/models/upsample.py +22 -9
- rslearn/train/all_patches_dataset.py +22 -18
- rslearn/train/dataset.py +69 -54
- rslearn/train/lightning_module.py +51 -32
- rslearn/train/model_context.py +54 -0
- rslearn/train/prediction_writer.py +111 -41
- rslearn/train/tasks/classification.py +34 -15
- rslearn/train/tasks/detection.py +24 -31
- rslearn/train/tasks/embedding.py +33 -29
- rslearn/train/tasks/multi_task.py +7 -7
- rslearn/train/tasks/per_pixel_regression.py +41 -19
- rslearn/train/tasks/regression.py +38 -21
- rslearn/train/tasks/segmentation.py +33 -15
- rslearn/train/tasks/task.py +3 -2
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/METADATA +58 -25
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/RECORD +65 -62
- rslearn/dataset/index.py +0 -173
- rslearn/models/registry.py +0 -22
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/WHEEL +0 -0
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/top_level.txt +0 -0
rslearn/models/presto/presto.py
CHANGED

@@ -2,14 +2,13 @@
 
 import logging
 import tempfile
-from typing import Any
 
 import torch
 from einops import rearrange, repeat
 from huggingface_hub import hf_hub_download
-from torch import nn
 from upath import UPath
 
+from rslearn.models.component import FeatureExtractor, FeatureMaps
 from rslearn.models.presto.single_file_presto import (
     ERA5_BANDS,
     NUM_DYNAMIC_WORLD_CLASSES,
@@ -21,6 +20,7 @@ from rslearn.models.presto.single_file_presto import (
     SRTM_BANDS,
 )
 from rslearn.models.presto.single_file_presto import Presto as SFPresto
+from rslearn.train.model_context import ModelContext
 
 logger = logging.getLogger(__name__)
 
@@ -36,7 +36,7 @@ HF_HUB_ID = "nasaharvest/presto"
 MODEL_FILENAME = "default_model.pt"
 
 
-class Presto(nn.Module):
+class Presto(FeatureExtractor):
     """Presto."""
 
     input_keys = [
@@ -184,22 +184,26 @@ class Presto(nn.Module):
         x = (x + PRESTO_ADD_BY.to(device=device)) / PRESTO_DIV_BY.to(device=device)
         return x, mask, dynamic_world.long(), months.long()
 
-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Presto backbone.
 
-
-
+        Args:
+            context: the model context. Input dicts should have some subset of Presto.input_keys.
+
+        Returns:
+            a FeatureMaps with one feature map that is at the same resolution as the
+            input (since Presto operates per-pixel).
         """
         stacked_inputs = {}
         latlons: torch.Tensor | None = None
-        for key in inputs[0].keys():
+        for key in context.inputs[0].keys():
             # assume all the keys in an input are consistent
             if key in self.input_keys:
                 if key == "latlon":
-                    latlons = torch.stack([inp[key] for inp in inputs], dim=0)
+                    latlons = torch.stack([inp[key] for inp in context.inputs], dim=0)
                 else:
                     stacked_inputs[key] = torch.stack(
-                        [inp[key] for inp in inputs], dim=0
+                        [inp[key] for inp in context.inputs], dim=0
                     )
 
         (
@@ -247,7 +251,9 @@
             )
             output_features[batch_idx : batch_idx + self.pixel_batch_size] = output_b
 
-        return
+        return FeatureMaps(
+            [rearrange(output_features, "(b h w) d -> b d h w", h=h, w=w, b=b)]
+        )
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
@@ -281,10 +281,7 @@ def get_sinusoid_encoding_table(
     sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
     sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
 
-
-        return torch.FloatTensor(sinusoid_table).cuda()
-    else:
-        return torch.FloatTensor(sinusoid_table)
+    return torch.FloatTensor(sinusoid_table)
 
 
 def get_month_encoding_table(d_hid: int) -> torch.Tensor:
@@ -296,10 +293,7 @@ def get_month_encoding_table(d_hid: int) -> torch.Tensor:
    cos_table = np.cos(np.stack([angles for _ in range(d_hid // 2)], axis=-1))
    month_table = np.concatenate([sin_table[:-1], cos_table[:-1]], axis=-1)
 
-
-        return torch.FloatTensor(month_table).cuda()
-    else:
-        return torch.FloatTensor(month_table)
+    return torch.FloatTensor(month_table)
 
 
 def month_to_tensor(
@@ -405,7 +399,7 @@ class Encoder(nn.Module):
         """initialize_weights."""
         pos_embed = get_sinusoid_encoding_table(
             self.pos_embed.shape[1], self.pos_embed.shape[-1]
-        )
+        ).to(device=self.pos_embed.device)
         self.pos_embed.data.copy_(pos_embed)
 
         # initialize nn.Linear and nn.LayerNorm
@@ -640,7 +634,7 @@ class Decoder(nn.Module):
         """initialize_weights."""
         pos_embed = get_sinusoid_encoding_table(
             self.pos_embed.shape[1], self.pos_embed.shape[-1]
-        )
+        ).to(device=self.pos_embed.device)
         self.pos_embed.data.copy_(pos_embed)
 
         # initialize nn.Linear and nn.LayerNorm
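The pattern above repeats across the model files in this release: `forward` now takes a `ModelContext`, reads the per-sample input dicts from `context.inputs`, and wraps its output in a `FeatureMaps`. The sketch below illustrates that contract with simplified stand-in classes; these are not rslearn's actual `ModelContext`/`FeatureMaps` implementations (whose constructors are not shown in this diff), only minimal dataclasses mirroring the attributes the diff accesses.

```python
from dataclasses import dataclass, field

import torch


@dataclass
class ModelContext:
    """Stand-in for rslearn's ModelContext: only the inputs attribute used above."""

    inputs: list[dict[str, torch.Tensor]] = field(default_factory=list)


@dataclass
class FeatureMaps:
    """Stand-in for rslearn's FeatureMaps: a list of [B, C, H, W] tensors."""

    feature_maps: list[torch.Tensor]


class ToyExtractor(torch.nn.Module):
    """A FeatureExtractor-style component following the same forward contract."""

    def __init__(self, in_channels: int = 3, out_channels: int = 8) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, 3, padding=1)

    def forward(self, context: ModelContext) -> FeatureMaps:
        # Stack the per-sample "image" tensors into one batch, as the models above do.
        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
        return FeatureMaps([self.conv(images)])


context = ModelContext(inputs=[{"image": torch.randn(3, 32, 32)} for _ in range(2)])
out = ToyExtractor()(context)
print([fm.shape for fm in out.feature_maps])  # [torch.Size([2, 8, 32, 32])]
```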
rslearn/models/prithvi.py
CHANGED

@@ -25,9 +25,12 @@ from timm.layers import to_2tuple
 from timm.models.vision_transformer import Block
 from torch.nn import functional as F
 
+from rslearn.train.model_context import ModelContext
 from rslearn.train.transforms.normalize import Normalize
 from rslearn.train.transforms.transform import Transform
 
+from .component import FeatureExtractor, FeatureMaps
+
 logger = logging.getLogger(__name__)
 
 
@@ -77,7 +80,7 @@ def get_config(cache_dir: Path, hf_hub_id: str, hf_hub_revision: str) -> dict[st
         return json.load(f)["pretrained_cfg"]
 
 
-class PrithviV2(nn.Module):
+class PrithviV2(FeatureExtractor):
     """An Rslearn wrapper for Prithvi 2.0."""
 
     INPUT_KEY = "image"
@@ -157,18 +160,18 @@ class PrithviV2(nn.Module):
         )
         return data
 
-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Prithvi V2 backbone.
 
         Args:
-
-            (Harmonized Landsat-Sentinel) data.
+            context: the model context. Input dicts must include "image" key containing
+                HLS (Harmonized Landsat-Sentinel) data.
 
         Returns:
-
-
+            a FeatureMaps with one map of shape [B, H/p_s, W/p_s, 11*1024] that contains stacked
+            feature maps across the 11 transformer blocks.
         """
-        x = torch.stack([inp[self.INPUT_KEY] for inp in inputs], dim=0)
+        x = torch.stack([inp[self.INPUT_KEY] for inp in context.inputs], dim=0)
         x = self._resize_data(x)
         num_timesteps = x.shape[1] // len(self.bands)
         x = rearrange(x, "b (t c) h w -> b c t h w", t=num_timesteps)
@@ -177,9 +180,10 @@
         # know the number of timesteps and don't need to recompute it.
         # in addition we average along the time dimension (instead of concatenating)
         # to keep the embeddings reasonably sized.
-
+        result = self.model.encoder.prepare_features_for_image_model(
             features, num_timesteps
         )
+        return FeatureMaps([torch.cat(result, dim=1)])
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
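PrithviV2 receives the time series with timesteps stacked on the channel axis and splits them back out with einops before calling the encoder. A small standalone illustration of that reshaping (the band count here is an arbitrary example; the real model reads it from `self.bands`):

```python
import torch
from einops import rearrange

bands = 6          # hypothetical band count for illustration
num_timesteps = 4
x = torch.randn(2, num_timesteps * bands, 224, 224)  # [B, (t c), H, W]

# Same rearrange as in PrithviV2.forward: recover an explicit time dimension.
x = rearrange(x, "b (t c) h w -> b c t h w", t=num_timesteps)
print(x.shape)  # torch.Size([2, 6, 4, 224, 224])
```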
rslearn/models/resize_features.py
CHANGED

@@ -1,9 +1,18 @@
 """The ResizeFeatures module."""
 
+from typing import Any
+
 import torch
 
+from rslearn.train.model_context import ModelContext
+
+from .component import (
+    FeatureMaps,
+    IntermediateComponent,
+)
+
 
-class ResizeFeatures(torch.nn.Module):
+class ResizeFeatures(IntermediateComponent):
     """Resize input features to new sizes."""
 
     def __init__(
@@ -30,16 +39,21 @@ class ResizeFeatures(torch.nn.Module):
         )
         self.layers = torch.nn.ModuleList(layers)
 
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[torch.Tensor]
-    ) -> list[torch.Tensor]:
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
         """Resize the input feature maps to new sizes.
 
         Args:
-
-
+            intermediates: the outputs from the previous component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
             resized feature maps
         """
-
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to ResizeFeatures must be a FeatureMaps")
+
+        feat_maps = intermediates.feature_maps
+        resized_feat_maps = [
+            self.layers[idx](feat_map) for idx, feat_map in enumerate(feat_maps)
+        ]
+        return FeatureMaps(resized_feat_maps)
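ResizeFeatures shows the `IntermediateComponent` side of the new API: validate that the previous component produced a `FeatureMaps`, transform each map, and re-wrap the result. The standalone sketch below follows the same guard-and-map pattern, but uses `F.interpolate` as a stand-in for the per-map layers that ResizeFeatures builds in `__init__`, and a stand-in `FeatureMaps` dataclass rather than the real class.

```python
from dataclasses import dataclass

import torch
import torch.nn.functional as F


@dataclass
class FeatureMaps:
    """Stand-in for rslearn's FeatureMaps."""

    feature_maps: list[torch.Tensor]


def resize_feature_maps(intermediates: object, sizes: list[tuple[int, int]]) -> FeatureMaps:
    # Same guard as ResizeFeatures.forward: the previous output must be a FeatureMaps.
    if not isinstance(intermediates, FeatureMaps):
        raise ValueError("input must be a FeatureMaps")
    resized = [
        F.interpolate(fm, size=size, mode="bilinear", align_corners=False)
        for fm, size in zip(intermediates.feature_maps, sizes)
    ]
    return FeatureMaps(resized)


maps = FeatureMaps([torch.randn(2, 8, 32, 32), torch.randn(2, 16, 16, 16)])
out = resize_feature_maps(maps, [(64, 64), (32, 32)])
print([fm.shape for fm in out.feature_maps])
```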
rslearn/models/sam2_enc.py
CHANGED

@@ -1,14 +1,15 @@
 """SegmentAnything2 encoders."""
 
-from typing import Any
-
 import torch
-import torch.nn as nn
 from sam2.build_sam import build_sam2
 from upath import UPath
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
+
 
-class SAM2Encoder(nn.Module):
+class SAM2Encoder(FeatureExtractor):
     """SAM2's image encoder."""
 
     def __init__(self, model_identifier: str) -> None:
@@ -84,18 +85,19 @@ class SAM2Encoder(nn.Module):
         del self.model.obj_ptr_proj
         del self.model.image_encoder.neck
 
-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Extract multi-scale features from a batch of images.
 
         Args:
-
+            context: the model context. Input dicts must have a key 'image' containing
+                the input for the SAM2 image encoder.
 
         Returns:
-
+            feature maps from the encoder.
         """
-        images = torch.stack([inp["image"] for inp in inputs], dim=0)
+        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
         features = self.encoder(images)
-        return features
+        return FeatureMaps(features)
 
     def get_backbone_channels(self) -> list[list[int]]:
         """Returns the output channels of the encoder at different scales.
rslearn/models/satlaspretrain.py
CHANGED

@@ -1,13 +1,15 @@
 """SatlasPretrain models."""
 
-from typing import Any
-
 import satlaspretrain_models
 import torch
 import torch.nn.functional as F
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
+
 
-class SatlasPretrain(torch.nn.Module):
+class SatlasPretrain(FeatureExtractor):
     """SatlasPretrain backbones."""
 
     def __init__(
@@ -64,15 +66,19 @@ class SatlasPretrain(torch.nn.Module):
         else:
             return data
 
-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the SatlasPretrain backbone.
 
-
-
-
+        Args:
+            context: the model context. Input dicts must contain an "image" key
+                containing the image input to the model.
+
+        Returns:
+            multi-resolution feature maps computed by the model.
         """
-        images = torch.stack([inp["image"] for inp in inputs], dim=0)
-
+        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+        feature_maps = self.model(self.maybe_resize(images))
+        return FeatureMaps(feature_maps)
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
rslearn/models/simple_time_series.py
CHANGED

@@ -4,11 +4,15 @@ from typing import Any
 
 import torch
 
+from rslearn.train.model_context import ModelContext
 
-class SimpleTimeSeries(torch.nn.Module):
-    """SimpleTimeSeries wraps another encoder and applies it on an image time series.
+from .component import FeatureExtractor, FeatureMaps
 
-
+
+class SimpleTimeSeries(FeatureExtractor):
+    """SimpleTimeSeries wraps another FeatureExtractor and applies it on an image time series.
+
+    It independently applies the other FeatureExtractor on each image in the time series to
     extract feature maps. It then provides a few ways to combine the features into one
     final feature map:
     - Temporal max pooling.
@@ -19,7 +23,7 @@ class SimpleTimeSeries(torch.nn.Module):
 
     def __init__(
         self,
-        encoder:
+        encoder: FeatureExtractor,
        image_channels: int | None = None,
        op: str = "max",
        groups: list[list[int]] | None = None,
@@ -31,9 +35,9 @@
         """Create a new SimpleTimeSeries.
 
         Args:
-            encoder: the underlying
-                function that returns the output channels, or backbone_channels must be
-
+            encoder: the underlying FeatureExtractor. It must provide get_backbone_channels
+                function that returns the output channels, or backbone_channels must be set.
+                It must output a FeatureMaps.
             image_channels: the number of channels per image of the time series. The
                 input should have multiple images concatenated on the channel axis, so
                 this parameter is used to distinguish the different images.
@@ -179,24 +183,27 @@
 
     def forward(
         self,
-
-    ) ->
+        context: ModelContext,
+    ) -> FeatureMaps:
         """Compute outputs from the backbone.
 
-
-
+        Args:
+            context: the model context. Input dicts must include "image" key containing the image time
                 series to process (with images concatenated on the channel dimension).
+
+        Returns:
+            the FeatureMaps aggregated temporally.
         """
         # First get features of each image.
         # To do so, we need to split up each grouped image into its component images (which have had their channels stacked).
         batched_inputs: list[dict[str, Any]] | None = None
-        n_batch = len(inputs)
+        n_batch = len(context.inputs)
         n_images: int | None = None
 
         if self.image_keys is not None:
             for image_key, image_channels in self.image_keys.items():
                 batched_images = self._get_batched_images(
-                    inputs, image_key, image_channels
+                    context.inputs, image_key, image_channels
                 )
 
                 if batched_inputs is None:
@@ -213,13 +220,20 @@
         else:
             assert self.image_channels is not None
             batched_images = self._get_batched_images(
-                inputs, self.image_key, self.image_channels
+                context.inputs, self.image_key, self.image_channels
             )
             batched_inputs = [{self.image_key: image} for image in batched_images]
             n_images = batched_images.shape[0] // n_batch
 
         assert n_images is not None
 
+        # Now we can apply the underlying FeatureExtractor.
+        # Its output must be a FeatureMaps.
+        encoder_output = self.encoder(batched_inputs)
+        if not isinstance(encoder_output, FeatureMaps):
+            raise ValueError(
+                "output of underlying FeatureExtractor in SimpleTimeSeries must be a FeatureMaps"
+            )
         all_features = [
             feat_map.reshape(
                 n_batch,
@@ -228,7 +242,7 @@
                 feat_map.shape[2],
                 feat_map.shape[3],
             )
-            for feat_map in
+            for feat_map in encoder_output.feature_maps
         ]
 
         # Groups defaults to flattening all the feature maps.
@@ -284,7 +298,7 @@
                 .permute(0, 3, 1, 2)
                 )
             else:
-                raise
+                raise ValueError(f"unknown aggregation op {self.op}")
 
             aggregated_features.append(group_features)
 
@@ -293,4 +307,4 @@
 
             output_features.append(aggregated_features)
 
-        return output_features
+        return FeatureMaps(output_features)
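The core of SimpleTimeSeries is the split / per-image encode / temporal aggregation flow described in its docstring. A toy standalone version of that flow with plain tensors, using a single convolution as a stand-in for the wrapped FeatureExtractor:

```python
import torch

n_batch, n_images, image_channels, size = 2, 4, 3, 32
# Input layout from the docstring: images concatenated on the channel axis.
x = torch.randn(n_batch, n_images * image_channels, size, size)

# Split the stacked channels into individual images folded into the batch dimension.
per_image = x.reshape(n_batch * n_images, image_channels, size, size)

# Toy stand-in for the wrapped FeatureExtractor.
encoder = torch.nn.Conv2d(image_channels, 8, 3, padding=1)
feats = encoder(per_image)  # [n_batch * n_images, 8, size, size]

# Reshape so the time dimension is explicit, then aggregate with temporal max pooling.
feats = feats.reshape(n_batch, n_images, *feats.shape[1:])
pooled = feats.max(dim=1).values  # [n_batch, 8, size, size]
print(pooled.shape)
```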
rslearn/models/singletask.py
CHANGED

@@ -4,6 +4,10 @@ from typing import Any
 
 import torch
 
+from rslearn.train.model_context import ModelContext, ModelOutput
+
+from .component import FeatureExtractor, IntermediateComponent, Predictor
+
 
 class SingleTaskModel(torch.nn.Module):
     """Standard model wrapper.
@@ -14,38 +18,41 @@ class SingleTaskModel(torch.nn.Module):
     outputs and targets from the last module (which also receives the targets).
     """
 
-    def __init__(
+    def __init__(
+        self,
+        encoder: list[FeatureExtractor | IntermediateComponent],
+        decoder: list[IntermediateComponent | Predictor],
+    ):
         """Initialize a new SingleTaskModel.
 
         Args:
-            encoder: modules to compute intermediate feature representations.
-
+            encoder: modules to compute intermediate feature representations. The first
+                module must be a FeatureExtractor, and following modules must be
+                IntermediateComponents.
+            decoder: modules to compute outputs and loss. The last module must be a
+                Predictor, while the previous modules must be IntermediateComponents.
         """
         super().__init__()
-        self.encoder = torch.nn.
+        self.encoder = torch.nn.ModuleList(encoder)
         self.decoder = torch.nn.ModuleList(decoder)
 
     def forward(
         self,
-
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) ->
+    ) -> ModelOutput:
         """Apply the sequence of modules on the inputs.
 
         Args:
-
+            context: the model context.
             targets: optional list of target dicts
-            info: optional dictionary of info to pass to the last module
 
         Returns:
-
+            the model output.
         """
-
-
+        cur = self.encoder[0](context)
+        for module in self.encoder[1:]:
+            cur = module(cur, context)
         for module in self.decoder[:-1]:
-            cur = module(cur,
-
-        return {
-            "outputs": outputs,
-            "loss_dict": loss_dict,
-        }
+            cur = module(cur, context)
+        return self.decoder[-1](cur, context, targets)
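The rewritten `SingleTaskModel.forward` fixes the component calling convention: the first encoder module receives only the context, every later encoder module and every decoder module except the last receives the previous output plus the context, and the final Predictor also receives the targets. A toy illustration of that composition order using plain callables (not rslearn classes):

```python
from typing import Any


def compose(encoder: list, decoder: list, context: Any, targets: Any = None) -> Any:
    cur = encoder[0](context)  # FeatureExtractor: takes only the context
    for module in encoder[1:]:
        cur = module(cur, context)  # IntermediateComponent: (intermediates, context)
    for module in decoder[:-1]:
        cur = module(cur, context)
    return decoder[-1](cur, context, targets)  # Predictor: also receives targets


# Toy pipeline: extract a value, double it, then "predict".
encoder = [lambda ctx: ctx["x"], lambda cur, ctx: cur * 2]
decoder = [lambda cur, ctx, targets: {"outputs": cur, "loss": 0.0}]
print(compose(encoder, decoder, {"x": 3}))  # {'outputs': 6, 'loss': 0.0}
```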
rslearn/models/ssl4eo_s12.py
CHANGED

@@ -1,12 +1,14 @@
 """SSL4EO-S12 models."""
 
-from typing import Any
-
 import torch
 import torchvision
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
+
 
-class Ssl4eoS12(torch.nn.Module):
+class Ssl4eoS12(FeatureExtractor):
     """The SSL4EO-S12 family of pretrained models."""
 
     def __init__(
@@ -74,19 +76,22 @@ class Ssl4eoS12(torch.nn.Module):
 
     def forward(
         self,
-
-    ) ->
+        context: ModelContext,
+    ) -> FeatureMaps:
         """Compute outputs from the backbone.
 
         If output_layers is set, then the outputs are multi-scale feature maps;
         otherwise, the model is being used for classification so the outputs are class
         probabilities and the loss.
 
-
-
-            process.
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the images to process.
+
+        Returns:
+            feature maps computed by the pre-trained model.
         """
-        x = torch.stack([inp["image"] for inp in inputs], dim=0)
+        x = torch.stack([inp["image"] for inp in context.inputs], dim=0)
         x = self.model.conv1(x)
         x = self.model.bn1(x)
         x = self.model.relu(x)
@@ -97,4 +102,4 @@
         layer3 = self.model.layer3(layer2)
         layer4 = self.model.layer4(layer3)
         all_features = [layer1, layer2, layer3, layer4]
-        return [all_features[idx] for idx in self.output_layers]
+        return FeatureMaps([all_features[idx] for idx in self.output_layers])
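Ssl4eoS12 runs the torchvision ResNet stem and the four residual stages by hand so it can collect intermediate activations and return only the stages selected by `output_layers`. The same multi-scale extraction can be reproduced on a plain torchvision ResNet-50 (random weights here so nothing is downloaded; the maxpool step is the standard ResNet stem and is not shown in the hunk above):

```python
import torch
import torchvision

model = torchvision.models.resnet50(weights=None)
output_layers = [0, 1, 2, 3]  # indices into [layer1, layer2, layer3, layer4]

x = torch.randn(2, 3, 224, 224)
x = model.conv1(x)
x = model.bn1(x)
x = model.relu(x)
x = model.maxpool(x)
layer1 = model.layer1(x)
layer2 = model.layer2(layer1)
layer3 = model.layer3(layer2)
layer4 = model.layer4(layer3)
all_features = [layer1, layer2, layer3, layer4]
feature_maps = [all_features[idx] for idx in output_layers]
# Downsample factors 4, 8, 16, 32 relative to the 224x224 input.
print([fm.shape for fm in feature_maps])
```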
rslearn/models/swin.py
CHANGED

@@ -1,7 +1,5 @@
 """Swin Transformer."""
 
-from typing import Any
-
 import torch
 import torchvision
 from torchvision.models.swin_transformer import (
@@ -13,8 +11,12 @@ from torchvision.models.swin_transformer import (
     Swin_V2_T_Weights,
 )
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps, FeatureVector
+
 
-class Swin(torch.nn.Module):
+class Swin(FeatureExtractor):
     """A Swin Transformer model.
 
     It can either be used stand-alone for classification, or as a feature extractor in
@@ -34,9 +36,12 @@ class Swin(torch.nn.Module):
         Args:
             arch: the architecture, e.g. "swin_v2_b" (default) or "swin_t"
             pretrained: set True to use ImageNet pre-trained weights
-            input_channels: number of input channels (default 3)
+            input_channels: number of input channels (default 3). If not 3, the first
+                layer is updated and will be randomly initialized even if pretrained is
+                set.
             output_layers: list of layers to output, default use as classification
-                model. For feature extraction, [1, 3, 5, 7] is
+                model (output FeatureVector). For feature extraction, [1, 3, 5, 7] is
+                recommended.
             num_outputs: number of output logits, defaults to 1000 which matches the
                 pretrained models.
         """
@@ -130,19 +135,23 @@
 
     def forward(
         self,
-
-    ) ->
+        context: ModelContext,
+    ) -> FeatureVector | FeatureMaps:
         """Compute outputs from the backbone.
 
         If output_layers is set, then the outputs are multi-scale feature maps;
         otherwise, the model is being used for classification so the outputs are class
         probabilities and the loss.
 
-
-
-            process.
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the image to process.
+
+        Returns:
+            a FeatureVector if the configured output_layers is None, or a FeatureMaps
+            otherwise containing one feature map per configured output layer.
         """
-        images = torch.stack([inp["image"] for inp in inputs], dim=0)
+        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
 
         if self.output_layers:
             layer_features = []
@@ -150,7 +159,7 @@
             for layer in self.model.features:
                 x = layer(x)
                 layer_features.append(x.permute(0, 3, 1, 2))
-            return [layer_features[idx] for idx in self.output_layers]
+            return FeatureMaps([layer_features[idx] for idx in self.output_layers])
 
         else:
-            return self.model(images)
+            return FeatureVector(self.model(images))
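Swin's feature-extraction branch iterates torchvision's `model.features` stages, which keep activations channels-last, hence the permute to NCHW before collecting each map. A standalone sketch of that branch on `swin_v2_t` with random weights, using the output layers recommended in the docstring:

```python
import torch
import torchvision

model = torchvision.models.swin_v2_t(weights=None)
output_layers = [1, 3, 5, 7]  # the stage outputs recommended in the docstring

images = torch.randn(2, 3, 256, 256)
layer_features = []
x = images
for layer in model.features:
    x = layer(x)
    # torchvision Swin keeps activations channels-last, so permute to NCHW.
    layer_features.append(x.permute(0, 3, 1, 2))
feature_maps = [layer_features[idx] for idx in output_layers]
print([fm.shape for fm in feature_maps])
```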