rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +2 -9
- rslearn/config/__init__.py +2 -0
- rslearn/config/dataset.py +64 -20
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +34 -84
- rslearn/dataset/materialize.py +5 -5
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +26 -80
- rslearn/lightning_cli.py +22 -11
- rslearn/main.py +12 -37
- rslearn/models/anysat.py +11 -9
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +8 -9
- rslearn/models/clip.py +18 -15
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +21 -11
- rslearn/models/conv.py +15 -8
- rslearn/models/croma.py +13 -8
- rslearn/models/detr/detr.py +25 -14
- rslearn/models/dinov3.py +11 -6
- rslearn/models/faster_rcnn.py +19 -9
- rslearn/models/feature_center_crop.py +12 -9
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/galileo.py +23 -18
- rslearn/models/module_wrapper.py +26 -57
- rslearn/models/molmo.py +16 -14
- rslearn/models/multitask.py +102 -73
- rslearn/models/olmoearth_pretrain/model.py +135 -38
- rslearn/models/panopticon.py +8 -7
- rslearn/models/pick_features.py +18 -24
- rslearn/models/pooling_decoder.py +22 -14
- rslearn/models/presto/presto.py +16 -10
- rslearn/models/presto/single_file_presto.py +4 -10
- rslearn/models/prithvi.py +12 -8
- rslearn/models/resize_features.py +21 -7
- rslearn/models/sam2_enc.py +11 -9
- rslearn/models/satlaspretrain.py +15 -9
- rslearn/models/simple_time_series.py +37 -17
- rslearn/models/singletask.py +24 -17
- rslearn/models/ssl4eo_s12.py +15 -10
- rslearn/models/swin.py +22 -13
- rslearn/models/terramind.py +24 -7
- rslearn/models/trunk.py +6 -3
- rslearn/models/unet.py +18 -9
- rslearn/models/upsample.py +22 -9
- rslearn/train/all_patches_dataset.py +89 -37
- rslearn/train/dataset.py +105 -97
- rslearn/train/lightning_module.py +51 -32
- rslearn/train/model_context.py +54 -0
- rslearn/train/prediction_writer.py +111 -41
- rslearn/train/scheduler.py +15 -0
- rslearn/train/tasks/classification.py +34 -15
- rslearn/train/tasks/detection.py +24 -31
- rslearn/train/tasks/embedding.py +33 -29
- rslearn/train/tasks/multi_task.py +7 -7
- rslearn/train/tasks/per_pixel_regression.py +41 -19
- rslearn/train/tasks/regression.py +38 -21
- rslearn/train/tasks/segmentation.py +33 -15
- rslearn/train/tasks/task.py +3 -2
- rslearn/train/transforms/resize.py +74 -0
- rslearn/utils/geometry.py +73 -0
- rslearn/utils/jsonargparse.py +66 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
- rslearn/dataset/index.py +0 -173
- rslearn/models/registry.py +0 -22
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
rslearn/models/pick_features.py
CHANGED
@@ -2,45 +2,39 @@
 
 from typing import Any
 
-import torch
+from rslearn.train.model_context import ModelContext
 
+from .component import (
+    FeatureMaps,
+    IntermediateComponent,
+)
 
-class PickFeatures(torch.nn.Module):
+
+class PickFeatures(IntermediateComponent):
     """Picks a subset of feature maps in a multi-scale feature map list."""
 
-    def __init__(self, indexes: list[int]
+    def __init__(self, indexes: list[int]):
         """Create a new PickFeatures.
 
         Args:
            indexes: the indexes of the input feature map list to select.
-            collapse: return one feature map instead of list. If enabled, indexes must
-                consist of one index. This is mainly useful for using PickFeatures as
-                the final module in the decoder, since the final prediction is expected
-                to be one feature map for most tasks like segmentation.
         """
         super().__init__()
         self.indexes = indexes
-        self.collapse = collapse
-
-        if self.collapse and len(self.indexes) != 1:
-            raise ValueError("if collapse is enabled, must get exactly one index")
 
     def forward(
         self,
-        features: list[torch.Tensor],
-        inputs: list[dict[str, Any]],
-        targets: list[dict[str, Any]] | None = None,
-    ) -> list[torch.Tensor]:
+        intermediates: Any,
+        context: ModelContext,
+    ) -> FeatureMaps:
         """Pick a subset of the features.
 
         Args:
-            targets: targets, not used
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context.
         """
-        return new_features
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to PickFeatures must be FeatureMaps")
+
+        new_features = [intermediates.feature_maps[idx] for idx in self.indexes]
+        return FeatureMaps(new_features)
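The PickFeatures change above illustrates the component interface used throughout 0.0.19: an IntermediateComponent receives the previous component's output plus a ModelContext instead of raw tensor lists and input dicts. Below is a minimal sketch of calling the new module directly; constructing ModelContext with inputs=/metadatas= keyword arguments is an assumption based on the usage visible in the simple_time_series.py diff further down, and the empty metadata list is purely illustrative.

```python
import torch

from rslearn.models.component import FeatureMaps
from rslearn.models.pick_features import PickFeatures
from rslearn.train.model_context import ModelContext

# Three fake multi-scale feature maps for a batch of two samples.
features = FeatureMaps([torch.randn(2, 64, size, size) for size in (32, 16, 8)])
# PickFeatures ignores the context, so a minimal one is enough here; the
# keyword arguments are assumed from usages elsewhere in this diff.
context = ModelContext(inputs=[{}, {}], metadatas=[])

picker = PickFeatures(indexes=[0, 2])
picked = picker(features, context)
print(len(picked.feature_maps))  # 2: the first and last scales were kept
```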
rslearn/models/pooling_decoder.py
CHANGED

@@ -4,8 +4,16 @@ from typing import Any
 
 import torch
 
+from rslearn.train.model_context import ModelContext
 
-class PoolingDecoder(torch.nn.Module):
+from .component import (
+    FeatureMaps,
+    FeatureVector,
+    IntermediateComponent,
+)
+
+
+class PoolingDecoder(IntermediateComponent):
     """Decoder that computes flat vector from a 2D feature map.
 
     It inputs multi-scale features, but only uses the last feature map. Then applies a
@@ -57,25 +65,26 @@ class PoolingDecoder(torch.nn.Module):
 
         self.output_layer = torch.nn.Linear(prev_channels, out_channels)
 
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
-    ) -> torch.Tensor:
+    def forward(self, intermediates: Any, context: ModelContext) -> Any:
         """Compute flat output vector from multi-scale feature map.
 
         Args:
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
             flat feature vector
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to PoolingDecoder must be a FeatureMaps")
+
         # Only use last feature map.
-        features = features[-1]
+        features = intermediates.feature_maps[-1]
 
         features = self.conv_layers(features)
         features = torch.amax(features, dim=(2, 3))
         features = self.fc_layers(features)
-        return self.output_layer(features)
+        return FeatureVector(self.output_layer(features))
 
 
 class SegmentationPoolingDecoder(PoolingDecoder):
@@ -108,14 +117,13 @@ class SegmentationPoolingDecoder(PoolingDecoder):
         super().__init__(in_channels=in_channels, out_channels=out_channels, **kwargs)
         self.image_key = image_key
 
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
-    ) -> torch.Tensor:
+    def forward(self, intermediates: Any, context: ModelContext) -> Any:
         """Extend PoolingDecoder forward to upsample the output to a segmentation mask.
 
         This only works when all of the pixels have the same segmentation target.
         """
-        output_probs = super().forward(
+        output_probs = super().forward(intermediates, context)
         # BC -> BCHW
-        h, w = inputs[0][self.image_key].shape[1:3]
+        h, w = context.inputs[0][self.image_key].shape[1:3]
+        feat_map = output_probs.feature_vector[:, :, None, None].repeat([1, 1, h, w])
+        return FeatureMaps([feat_map])
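SegmentationPoolingDecoder keeps its old broadcast trick, only now the image size comes from context.inputs and the result is wrapped in a FeatureMaps. The BC -> BCHW expansion on its own looks like this (shapes invented for illustration):

```python
import torch

batch, num_classes, height, width = 2, 5, 64, 64
# One pooled prediction vector per sample, i.e. shape (B, C).
output_probs = torch.randn(batch, num_classes)

# Insert singleton spatial dims, then repeat them out to the image size so
# every pixel carries the same per-sample logits.
feat_map = output_probs[:, :, None, None].repeat([1, 1, height, width])
assert feat_map.shape == (batch, num_classes, height, width)
```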
rslearn/models/presto/presto.py
CHANGED
@@ -2,14 +2,13 @@
 
 import logging
 import tempfile
-from typing import Any
 
 import torch
 from einops import rearrange, repeat
 from huggingface_hub import hf_hub_download
-from torch import nn
 from upath import UPath
 
+from rslearn.models.component import FeatureExtractor, FeatureMaps
 from rslearn.models.presto.single_file_presto import (
     ERA5_BANDS,
     NUM_DYNAMIC_WORLD_CLASSES,
@@ -21,6 +20,7 @@ from rslearn.models.presto.single_file_presto import (
     SRTM_BANDS,
 )
 from rslearn.models.presto.single_file_presto import Presto as SFPresto
+from rslearn.train.model_context import ModelContext
 
 logger = logging.getLogger(__name__)
 
@@ -36,7 +36,7 @@ HF_HUB_ID = "nasaharvest/presto"
 MODEL_FILENAME = "default_model.pt"
 
 
-class Presto(nn.Module):
+class Presto(FeatureExtractor):
     """Presto."""
 
     input_keys = [
@@ -184,22 +184,26 @@ class Presto(nn.Module):
         x = (x + PRESTO_ADD_BY.to(device=device)) / PRESTO_DIV_BY.to(device=device)
         return x, mask, dynamic_world.long(), months.long()
 
-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Presto backbone.
 
+        Args:
+            context: the model context. Input dicts should have some subset of Presto.input_keys.
+
+        Returns:
+            a FeatureMaps with one feature map that is at the same resolution as the
+            input (since Presto operates per-pixel).
         """
         stacked_inputs = {}
         latlons: torch.Tensor | None = None
-        for key in inputs[0].keys():
+        for key in context.inputs[0].keys():
             # assume all the keys in an input are consistent
             if key in self.input_keys:
                 if key == "latlon":
-                    latlons = torch.stack([inp[key] for inp in inputs], dim=0)
+                    latlons = torch.stack([inp[key] for inp in context.inputs], dim=0)
                 else:
                     stacked_inputs[key] = torch.stack(
-                        [inp[key] for inp in inputs], dim=0
+                        [inp[key] for inp in context.inputs], dim=0
                     )
 
         (
@@ -247,7 +251,9 @@ class Presto(nn.Module):
             )
             output_features[batch_idx : batch_idx + self.pixel_batch_size] = output_b
 
-        return
+        return FeatureMaps(
+            [rearrange(output_features, "(b h w) d -> b d h w", h=h, w=w, b=b)]
+        )
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
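Presto embeds each pixel independently, so the new forward flattens the B*H*W pixels into one batch and then folds the embeddings back into a single feature map before wrapping them in FeatureMaps. The einops pattern used in the return statement behaves like this (sizes invented for illustration):

```python
import torch
from einops import rearrange

b, h, w, d = 2, 4, 4, 128
# One embedding per pixel, with all pixels flattened into the batch axis.
per_pixel = torch.randn(b * h * w, d)

# "(b h w) d -> b d h w": unflatten the pixel axis and move the embedding
# dimension into the channel position of a standard feature map.
feat_map = rearrange(per_pixel, "(b h w) d -> b d h w", b=b, h=h, w=w)
assert feat_map.shape == (b, d, h, w)
```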
rslearn/models/presto/single_file_presto.py
CHANGED

@@ -281,10 +281,7 @@ def get_sinusoid_encoding_table(
     sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
     sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
 
-    if torch.cuda.is_available():
-        return torch.FloatTensor(sinusoid_table).cuda()
-    else:
-        return torch.FloatTensor(sinusoid_table)
+    return torch.FloatTensor(sinusoid_table)
 
 
 def get_month_encoding_table(d_hid: int) -> torch.Tensor:
@@ -296,10 +293,7 @@ def get_month_encoding_table(d_hid: int) -> torch.Tensor:
     cos_table = np.cos(np.stack([angles for _ in range(d_hid // 2)], axis=-1))
     month_table = np.concatenate([sin_table[:-1], cos_table[:-1]], axis=-1)
 
-    if torch.cuda.is_available():
-        return torch.FloatTensor(month_table).cuda()
-    else:
-        return torch.FloatTensor(month_table)
+    return torch.FloatTensor(month_table)
 
 
 def month_to_tensor(
@@ -405,7 +399,7 @@ class Encoder(nn.Module):
         """initialize_weights."""
         pos_embed = get_sinusoid_encoding_table(
             self.pos_embed.shape[1], self.pos_embed.shape[-1]
-        )
+        ).to(device=self.pos_embed.device)
         self.pos_embed.data.copy_(pos_embed)
 
         # initialize nn.Linear and nn.LayerNorm
@@ -640,7 +634,7 @@ class Decoder(nn.Module):
         """initialize_weights."""
         pos_embed = get_sinusoid_encoding_table(
             self.pos_embed.shape[1], self.pos_embed.shape[-1]
-        )
+        ).to(device=self.pos_embed.device)
         self.pos_embed.data.copy_(pos_embed)
 
         # initialize nn.Linear and nn.LayerNorm
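The single_file_presto.py change drops the unconditional .cuda() calls from the encoding-table helpers; instead the caller moves the table onto whatever device the positional-embedding parameter already lives on. The general pattern, sketched independently of Presto (the helper name below is a stand-in):

```python
import torch

# The parameter may live on CPU, CUDA, MPS, etc., depending on the module.
pos_embed = torch.nn.Parameter(torch.zeros(16, 128))

def build_table(num_positions: int, dim: int) -> torch.Tensor:
    # Stand-in for get_sinusoid_encoding_table: build on CPU and let the
    # caller decide where the tensor should live.
    return torch.randn(num_positions, dim)

# Follow the parameter's device instead of assuming CUDA is available.
table = build_table(16, 128).to(device=pos_embed.device)
pos_embed.data.copy_(table)
```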
rslearn/models/prithvi.py
CHANGED
@@ -25,9 +25,12 @@ from timm.layers import to_2tuple
 from timm.models.vision_transformer import Block
 from torch.nn import functional as F
 
+from rslearn.train.model_context import ModelContext
 from rslearn.train.transforms.normalize import Normalize
 from rslearn.train.transforms.transform import Transform
 
+from .component import FeatureExtractor, FeatureMaps
+
 logger = logging.getLogger(__name__)
 
 
@@ -77,7 +80,7 @@ def get_config(cache_dir: Path, hf_hub_id: str, hf_hub_revision: str) -> dict[st
         return json.load(f)["pretrained_cfg"]
 
 
-class PrithviV2(nn.Module):
+class PrithviV2(FeatureExtractor):
     """An Rslearn wrapper for Prithvi 2.0."""
 
     INPUT_KEY = "image"
@@ -157,18 +160,18 @@ class PrithviV2(nn.Module):
         )
         return data
 
-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Prithvi V2 backbone.
 
         Args:
-                (Harmonized Landsat-Sentinel) data.
+            context: the model context. Input dicts must include "image" key containing
+                HLS (Harmonized Landsat-Sentinel) data.
 
         Returns:
+            a FeatureMaps with one map of shape [B, H/p_s, W/p_s, 11*1024] that contains stacked
+            feature maps across the 11 transformer blocks.
         """
-        x = torch.stack([inp[self.INPUT_KEY] for inp in inputs], dim=0)
+        x = torch.stack([inp[self.INPUT_KEY] for inp in context.inputs], dim=0)
         x = self._resize_data(x)
         num_timesteps = x.shape[1] // len(self.bands)
         x = rearrange(x, "b (t c) h w -> b c t h w", t=num_timesteps)
@@ -177,9 +180,10 @@
         # know the number of timesteps and don't need to recompute it.
         # in addition we average along the time dimension (instead of concatenating)
         # to keep the embeddings reasonably sized.
-        return self.model.encoder.prepare_features_for_image_model(
+        result = self.model.encoder.prepare_features_for_image_model(
            features, num_timesteps
         )
+        return FeatureMaps([torch.cat(result, dim=1)])
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
rslearn/models/resize_features.py
CHANGED

@@ -1,9 +1,18 @@
 """The ResizeFeatures module."""
 
+from typing import Any
+
 import torch
 
+from rslearn.train.model_context import ModelContext
+
+from .component import (
+    FeatureMaps,
+    IntermediateComponent,
+)
+
 
-class ResizeFeatures(
+class ResizeFeatures(IntermediateComponent):
     """Resize input features to new sizes."""
 
     def __init__(
@@ -30,16 +39,21 @@ class ResizeFeatures(torch.nn.Module):
         )
         self.layers = torch.nn.ModuleList(layers)
 
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[torch.Tensor]
-    ) -> list[torch.Tensor]:
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
         """Resize the input feature maps to new sizes.
 
         Args:
+            intermediates: the outputs from the previous component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
             resized feature maps
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to ResizeFeatures must be a FeatureMaps")
+
+        feat_maps = intermediates.feature_maps
+        resized_feat_maps = [
+            self.layers[idx](feat_map) for idx, feat_map in enumerate(feat_maps)
+        ]
+        return FeatureMaps(resized_feat_maps)
rslearn/models/sam2_enc.py
CHANGED
@@ -1,14 +1,15 @@
 """SegmentAnything2 encoders."""
 
-from typing import Any
-
 import torch
-import torch.nn as nn
 from sam2.build_sam import build_sam2
 from upath import UPath
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
+
 
-class SAM2Encoder(
+class SAM2Encoder(FeatureExtractor):
     """SAM2's image encoder."""
 
     def __init__(self, model_identifier: str) -> None:
@@ -84,18 +85,19 @@ class SAM2Encoder(nn.Module):
         del self.model.obj_ptr_proj
         del self.model.image_encoder.neck
 
-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Extract multi-scale features from a batch of images.
 
         Args:
+            context: the model context. Input dicts must have a key 'image' containing
+                the input for the SAM2 image encoder.
 
         Returns:
+            feature maps from the encoder.
         """
-        images = torch.stack([inp["image"] for inp in inputs], dim=0)
+        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
         features = self.encoder(images)
-        return features
+        return FeatureMaps(features)
 
     def get_backbone_channels(self) -> list[list[int]]:
         """Returns the output channels of the encoder at different scales.
rslearn/models/satlaspretrain.py
CHANGED
@@ -1,13 +1,15 @@
 """SatlasPretrain models."""
 
-from typing import Any
-
 import satlaspretrain_models
 import torch
 import torch.nn.functional as F
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
+
 
-class SatlasPretrain(
+class SatlasPretrain(FeatureExtractor):
     """SatlasPretrain backbones."""
 
     def __init__(
@@ -64,15 +66,19 @@ class SatlasPretrain(torch.nn.Module):
         else:
             return data
 
-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the SatlasPretrain backbone.
 
+        Args:
+            context: the model context. Input dicts must contain an "image" key
+                containing the image input to the model.
+
+        Returns:
+            multi-resolution feature maps computed by the model.
         """
-        images = torch.stack([inp["image"] for inp in inputs], dim=0)
-        return self.model(self.maybe_resize(images))
+        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+        feature_maps = self.model(self.maybe_resize(images))
+        return FeatureMaps(feature_maps)
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
rslearn/models/simple_time_series.py
CHANGED

@@ -4,11 +4,15 @@ from typing import Any
 
 import torch
 
+from rslearn.train.model_context import ModelContext
 
-class SimpleTimeSeries(torch.nn.Module):
-    """SimpleTimeSeries wraps another encoder and applies it on an image time series.
+from .component import FeatureExtractor, FeatureMaps
 
+
+class SimpleTimeSeries(FeatureExtractor):
+    """SimpleTimeSeries wraps another FeatureExtractor and applies it on an image time series.
+
+    It independently applies the other FeatureExtractor on each image in the time series to
     extract feature maps. It then provides a few ways to combine the features into one
     final feature map:
     - Temporal max pooling.
@@ -19,7 +23,7 @@ class SimpleTimeSeries(torch.nn.Module):
 
     def __init__(
         self,
-        encoder:
+        encoder: FeatureExtractor,
         image_channels: int | None = None,
         op: str = "max",
         groups: list[list[int]] | None = None,
@@ -31,9 +35,9 @@
         """Create a new SimpleTimeSeries.
 
         Args:
-            encoder: the underlying
-                function that returns the output channels, or backbone_channels must be
+            encoder: the underlying FeatureExtractor. It must provide get_backbone_channels
+                function that returns the output channels, or backbone_channels must be set.
+                It must output a FeatureMaps.
             image_channels: the number of channels per image of the time series. The
                 input should have multiple images concatenated on the channel axis, so
                 this parameter is used to distinguish the different images.
@@ -179,24 +183,27 @@
 
     def forward(
         self,
-        inputs: list[dict[str, Any]],
-    ) ->
+        context: ModelContext,
+    ) -> FeatureMaps:
         """Compute outputs from the backbone.
 
+        Args:
+            context: the model context. Input dicts must include "image" key containing the image time
                 series to process (with images concatenated on the channel dimension).
+
+        Returns:
+            the FeatureMaps aggregated temporally.
         """
         # First get features of each image.
         # To do so, we need to split up each grouped image into its component images (which have had their channels stacked).
         batched_inputs: list[dict[str, Any]] | None = None
-        n_batch = len(inputs)
+        n_batch = len(context.inputs)
         n_images: int | None = None
 
         if self.image_keys is not None:
             for image_key, image_channels in self.image_keys.items():
                 batched_images = self._get_batched_images(
-                    inputs, image_key, image_channels
+                    context.inputs, image_key, image_channels
                 )
 
                 if batched_inputs is None:
@@ -213,13 +220,26 @@
         else:
             assert self.image_channels is not None
             batched_images = self._get_batched_images(
-                inputs, self.image_key, self.image_channels
+                context.inputs, self.image_key, self.image_channels
             )
             batched_inputs = [{self.image_key: image} for image in batched_images]
             n_images = batched_images.shape[0] // n_batch
 
         assert n_images is not None
 
+        # Now we can apply the underlying FeatureExtractor.
+        # Its output must be a FeatureMaps.
+        assert batched_inputs is not None
+        encoder_output = self.encoder(
+            ModelContext(
+                inputs=batched_inputs,
+                metadatas=context.metadatas,
+            )
+        )
+        if not isinstance(encoder_output, FeatureMaps):
+            raise ValueError(
+                "output of underlying FeatureExtractor in SimpleTimeSeries must be a FeatureMaps"
+            )
         all_features = [
             feat_map.reshape(
                 n_batch,
@@ -228,7 +248,7 @@
                 feat_map.shape[2],
                 feat_map.shape[3],
             )
-            for feat_map in
+            for feat_map in encoder_output.feature_maps
         ]
 
         # Groups defaults to flattening all the feature maps.
@@ -284,7 +304,7 @@
                 .permute(0, 3, 1, 2)
             )
         else:
-            raise
+            raise ValueError(f"unknown aggregation op {self.op}")
 
         aggregated_features.append(group_features)
 
@@ -293,4 +313,4 @@
 
         output_features.append(aggregated_features)
 
-        return output_features
+        return FeatureMaps(output_features)
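The SimpleTimeSeries change means the wrapped encoder now runs once over a batch of B*T single-timestep inputs (delivered in a fresh ModelContext) and must return a FeatureMaps, which is then reshaped and aggregated over time. A simplified stand-in for the channel splitting and the op="max" aggregation path follows; the shapes and reshape logic here are illustrative, not copied from _get_batched_images:

```python
import torch

# A time series of T images stacked on the channel axis: (B, T*C, H, W).
b, t, c, h, w = 2, 4, 3, 32, 32
series = torch.randn(b, t * c, h, w)

# Split into one batch of B*T single images, the shape the wrapped
# FeatureExtractor sees.
per_image = series.reshape(b, t, c, h, w).reshape(b * t, c, h, w)

# Pretend encoder output: one feature map per image at 1/4 resolution.
feats = torch.randn(b * t, 64, h // 4, w // 4)

# Temporal max pooling: fold T back out of the batch axis and reduce over it.
pooled = feats.reshape(b, t, 64, h // 4, w // 4).amax(dim=1)
assert pooled.shape == (b, 64, h // 4, w // 4)
```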
rslearn/models/singletask.py
CHANGED
@@ -4,6 +4,10 @@ from typing import Any
 
 import torch
 
+from rslearn.train.model_context import ModelContext, ModelOutput
+
+from .component import FeatureExtractor, IntermediateComponent, Predictor
+
 
 class SingleTaskModel(torch.nn.Module):
     """Standard model wrapper.
@@ -14,38 +18,41 @@ class SingleTaskModel(torch.nn.Module):
     outputs and targets from the last module (which also receives the targets).
     """
 
-    def __init__(
+    def __init__(
+        self,
+        encoder: list[FeatureExtractor | IntermediateComponent],
+        decoder: list[IntermediateComponent | Predictor],
+    ):
         """Initialize a new SingleTaskModel.
 
         Args:
-            encoder: modules to compute intermediate feature representations.
+            encoder: modules to compute intermediate feature representations. The first
+                module must be a FeatureExtractor, and following modules must be
+                IntermediateComponents.
+            decoder: modules to compute outputs and loss. The last module must be a
+                Predictor, while the previous modules must be IntermediateComponents.
         """
         super().__init__()
-        self.encoder = torch.nn.
+        self.encoder = torch.nn.ModuleList(encoder)
         self.decoder = torch.nn.ModuleList(decoder)
 
     def forward(
         self,
-        inputs: list[dict[str, Any]],
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) ->
+    ) -> ModelOutput:
         """Apply the sequence of modules on the inputs.
 
         Args:
+            context: the model context.
             targets: optional list of target dicts
-            info: optional dictionary of info to pass to the last module
 
         Returns:
+            the model output.
         """
+        cur = self.encoder[0](context)
+        for module in self.encoder[1:]:
+            cur = module(cur, context)
         for module in self.decoder[:-1]:
-            cur = module(cur,
-        return {
-            "outputs": outputs,
-            "loss_dict": loss_dict,
-        }
+            cur = module(cur, context)
+        return self.decoder[-1](cur, context, targets)
|