rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/resize_features.py
ADDED

@@ -0,0 +1,59 @@
+"""The ResizeFeatures module."""
+
+from typing import Any
+
+import torch
+
+from rslearn.train.model_context import ModelContext
+
+from .component import (
+    FeatureMaps,
+    IntermediateComponent,
+)
+
+
+class ResizeFeatures(IntermediateComponent):
+    """Resize input features to new sizes."""
+
+    def __init__(
+        self,
+        out_sizes: list[tuple[int, int]],
+        mode: str = "bilinear",
+    ):
+        """Initialize a ResizeFeatures.
+
+        Args:
+            out_sizes: the output sizes of the feature maps. There must be one entry
+                for each input feature map.
+            mode: mode to pass to torch.nn.Upsample, e.g. "bilinear" (default) or
+                "nearest".
+        """
+        super().__init__()
+        layers = []
+        for size in out_sizes:
+            layers.append(
+                torch.nn.Upsample(
+                    size=size,
+                    mode=mode,
+                )
+            )
+        self.layers = torch.nn.ModuleList(layers)
+
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Resize the input feature maps to new sizes.
+
+        Args:
+            intermediates: the outputs from the previous component, which must be a FeatureMaps.
+            context: the model context.
+
+        Returns:
+            resized feature maps
+        """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to ResizeFeatures must be a FeatureMaps")
+
+        feat_maps = intermediates.feature_maps
+        resized_feat_maps = [
+            self.layers[idx](feat_map) for idx, feat_map in enumerate(feat_maps)
+        ]
+        return FeatureMaps(resized_feat_maps)
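For reference, the resizing that the new ResizeFeatures component performs is one torch.nn.Upsample per input feature map. A minimal torch-only sketch of that behavior (the shapes and sizes below are hypothetical; the rslearn component additionally takes and returns FeatureMaps):

    import torch

    # Hypothetical backbone outputs: two feature maps at different resolutions.
    feat_maps = [torch.randn(2, 256, 32, 32), torch.randn(2, 512, 16, 16)]

    # out_sizes needs one (height, width) entry per feature map; ResizeFeatures
    # builds one torch.nn.Upsample per entry and applies it positionally.
    out_sizes = [(64, 64), (64, 64)]
    layers = [torch.nn.Upsample(size=size, mode="bilinear") for size in out_sizes]
    resized = [layer(f) for layer, f in zip(layers, feat_maps)]
    print([tuple(r.shape) for r in resized])
    # [(2, 256, 64, 64), (2, 512, 64, 64)]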
rslearn/models/sam2_enc.py
CHANGED

@@ -1,14 +1,15 @@
 """SegmentAnything2 encoders."""

-from typing import Any
-
 import torch
-import torch.nn as nn
 from sam2.build_sam import build_sam2
 from upath import UPath

+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
+

-class SAM2Encoder(nn.Module):
+class SAM2Encoder(FeatureExtractor):
     """SAM2's image encoder."""

     def __init__(self, model_identifier: str) -> None:
@@ -84,18 +85,21 @@ class SAM2Encoder(nn.Module):
         del self.model.obj_ptr_proj
         del self.model.image_encoder.neck

-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Extract multi-scale features from a batch of images.

         Args:
-
+            context: the model context. Input dicts must have a key 'image' containing
+                the input for the SAM2 image encoder.

         Returns:
-
+            feature maps from the encoder.
         """
-        images = torch.stack(
+        images = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
         features = self.encoder(images)
-        return features
+        return FeatureMaps(features)

     def get_backbone_channels(self) -> list[list[int]]:
         """Returns the output channels of the encoder at different scales.
rslearn/models/satlaspretrain.py
CHANGED

@@ -1,19 +1,20 @@
 """SatlasPretrain models."""

-from typing import Any
-
 import satlaspretrain_models
 import torch
+import torch.nn.functional as F
+
+from rslearn.train.model_context import ModelContext

+from .component import FeatureExtractor, FeatureMaps

-class SatlasPretrain(torch.nn.Module):
+
+class SatlasPretrain(FeatureExtractor):
     """SatlasPretrain backbones."""

     def __init__(
-        self,
-
-        fpn: bool = False,
-    ):
+        self, model_identifier: str, fpn: bool = False, resize_to_pretrain: bool = False
+    ) -> None:
         """Instantiate a new SatlasPretrain instance.

         Args:
@@ -21,11 +22,13 @@ class SatlasPretrain(torch.nn.Module):
             https://github.com/allenai/satlaspretrain_models
         fpn: whether to include the feature pyramid network, otherwise only the
             Swin-v2-Transformer is used.
+        resize_to_pretrain: whether to resize inputs to the pretraining input
+            size (512 x 512)
         """
         super().__init__()
         weights_manager = satlaspretrain_models.Weights()
         self.model = weights_manager.get_pretrained_model(
-            model_identifier=model_identifier, fpn=fpn
+            model_identifier=model_identifier, fpn=fpn, device="cpu"
         )

         if "SwinB" in model_identifier:
@@ -49,21 +52,38 @@ class SatlasPretrain(torch.nn.Module):
            [16, 1024],
            [32, 2048],
        ]
+        self.resize_to_pretrain = resize_to_pretrain
+
+    def maybe_resize(self, data: torch.Tensor) -> list[torch.Tensor]:
+        """Resize to pretraining sizes if resize_to_pretrain == True."""
+        if self.resize_to_pretrain:
+            return F.interpolate(
+                data,
+                size=(512, 512),
+                mode="bilinear",
+                align_corners=False,
+            )
+        else:
+            return data

-    def forward(
-        self, inputs: list[dict[str, Any]], targets: list[dict[str, Any]] = None
-    ):
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the SatlasPretrain backbone.

-
-
-
-
+        Args:
+            context: the model context. Input dicts must contain an "image" key
+                containing the image input to the model.
+
+        Returns:
+            multi-resolution feature maps computed by the model.
         """
-
-
+        # take the first (assumed to be only) timestep
+        images = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
+        feature_maps = self.model(self.maybe_resize(images))
+        return FeatureMaps(feature_maps)

-    def get_backbone_channels(self):
+    def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.

         The output channels is a list of (downsample_factor, depth) that corresponds
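The new resize_to_pretrain option amounts to a single F.interpolate call before the backbone runs. A small sketch of that behavior with a hypothetical input (the band count and input size here are made up):

    import torch
    import torch.nn.functional as F

    # Hypothetical input: batch of 1, 9 bands, 128x128 pixels.
    x = torch.randn(1, 9, 128, 128)

    # maybe_resize() upsamples to the 512x512 pretraining size when enabled.
    resized = F.interpolate(x, size=(512, 512), mode="bilinear", align_corners=False)
    print(tuple(resized.shape))  # (1, 9, 512, 512)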
rslearn/models/simple_time_series.py
CHANGED

@@ -3,12 +3,17 @@
 from typing import Any

 import torch
+from einops import rearrange

+from rslearn.train.model_context import ModelContext, RasterImage

-class SimpleTimeSeries(torch.nn.Module):
-    """SimpleTimeSeries wraps another encoder and applies it on an image time series.
+from .component import FeatureExtractor, FeatureMaps

-
+
+class SimpleTimeSeries(FeatureExtractor):
+    """SimpleTimeSeries wraps another FeatureExtractor and applies it on an image time series.
+
+    It independently applies the other FeatureExtractor on each image in the time series to
     extract feature maps. It then provides a few ways to combine the features into one
     final feature map:
     - Temporal max pooling.
@@ -19,17 +24,21 @@ class SimpleTimeSeries(torch.nn.Module):

     def __init__(
         self,
-        encoder:
-        image_channels: int,
+        encoder: FeatureExtractor,
+        image_channels: int | None = None,
         op: str = "max",
         groups: list[list[int]] | None = None,
         num_layers: int | None = None,
-
+        image_key: str = "image",
+        backbone_channels: list[tuple[int, int]] | None = None,
+        image_keys: dict[str, int] | None = None,
+    ) -> None:
         """Create a new SimpleTimeSeries.

         Args:
-            encoder: the underlying
-                function that returns the output channels.
+            encoder: the underlying FeatureExtractor. It must provide get_backbone_channels
+                function that returns the output channels, or backbone_channels must be set.
+                It must output a FeatureMaps.
             image_channels: the number of channels per image of the time series. The
                 input should have multiple images concatenated on the channel axis, so
                 this parameter is used to distinguish the different images.
@@ -42,76 +51,101 @@ class SimpleTimeSeries(torch.nn.Module):
                 combined before features and the combined after features. groups is a
                 list of sets, and each set is a list of image indices.
             num_layers: the number of layers for convrnn, conv3d, and conv1d ops.
+            image_key: the key to access the images.
+            backbone_channels: manually specify the backbone channels. Can be set if
+                the encoder does not provide get_backbone_channels function.
+            image_keys: as an alternative to setting image_channels, map from the key
+                in input dict to the number of channels per timestep for that modality.
+                This way SimpleTimeSeries can be used with multimodal inputs. One of
+                image_channels or image_keys must be specified.
         """
+        if (image_channels is None and image_keys is None) or (
+            image_channels is not None and image_keys is not None
+        ):
+            raise ValueError(
+                "exactly one of image_channels and image_keys must be specified"
+            )
+
         super().__init__()
         self.encoder = encoder
         self.image_channels = image_channels
         self.op = op
         self.groups = groups
+        self.image_key = image_key
+        self.image_keys = image_keys

-
+        if backbone_channels is not None:
+            out_channels = backbone_channels
+        else:
+            out_channels = self.encoder.get_backbone_channels()
         if self.groups:
             self.num_groups = len(self.groups)
         else:
             self.num_groups = 1

-        if self.op
-
-
-
-
-
-
-
-
-                    torch.nn.ReLU(inplace=True),
-                )
-            ]
-            for _ in range(num_layers - 1):
-                cur_layer.append(
+        if self.op in ["convrnn", "conv3d", "conv1d"]:
+            if num_layers is None:
+                raise ValueError(f"num_layers must be specified for {self.op} op")
+
+            if self.op == "convrnn":
+                rnn_kernel_size = 3
+                rnn_layers = []
+                for _, count in out_channels:
+                    cur_layer = [
                         torch.nn.Sequential(
                             torch.nn.Conv2d(
-                                count, count, rnn_kernel_size, padding="same"
+                                2 * count, count, rnn_kernel_size, padding="same"
                             ),
                             torch.nn.ReLU(inplace=True),
                         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            )
-
-
-
-
-
+                    ]
+                    for _ in range(num_layers - 1):
+                        cur_layer.append(
+                            torch.nn.Sequential(
+                                torch.nn.Conv2d(
+                                    count, count, rnn_kernel_size, padding="same"
+                                ),
+                                torch.nn.ReLU(inplace=True),
+                            )
+                        )
+                    cur_layer = torch.nn.Sequential(*cur_layer)
+                    rnn_layers.append(cur_layer)
+                self.rnn_layers = torch.nn.ModuleList(rnn_layers)
+
+            elif self.op == "conv3d":
+                conv3d_layers = []
+                for _, count in out_channels:
+                    cur_layer = [
+                        torch.nn.Sequential(
+                            torch.nn.Conv3d(
+                                count, count, 3, padding=1, stride=(2, 1, 1)
+                            ),
+                            torch.nn.ReLU(inplace=True),
+                        )
+                        for _ in range(num_layers)
+                    ]
+                    cur_layer = torch.nn.Sequential(*cur_layer)
+                    conv3d_layers.append(cur_layer)
+                self.conv3d_layers = torch.nn.ModuleList(conv3d_layers)
+
+            elif self.op == "conv1d":
+                conv1d_layers = []
+                for _, count in out_channels:
+                    cur_layer = [
+                        torch.nn.Sequential(
+                            torch.nn.Conv1d(count, count, 3, padding=1, stride=2),
+                            torch.nn.ReLU(inplace=True),
+                        )
+                        for _ in range(num_layers)
+                    ]
+                    cur_layer = torch.nn.Sequential(*cur_layer)
+                    conv1d_layers.append(cur_layer)
+                self.conv1d_layers = torch.nn.ModuleList(conv1d_layers)

         else:
             assert self.op in ["max", "mean"]

-    def get_backbone_channels(self):
+    def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.

         The output channels is a list of (downsample_factor, depth) that corresponds
@@ -128,27 +162,105 @@ class SimpleTimeSeries(torch.nn.Module):
             out_channels.append((downsample_factor, depth * self.num_groups))
         return out_channels

+    def _get_batched_images(
+        self, input_dicts: list[dict[str, Any]], image_key: str, image_channels: int
+    ) -> list[RasterImage]:
+        """Collect and reshape images across input dicts.
+
+        The BTCHW image time series are reshaped to (B*T)CHW so they can be passed to
+        the forward pass of a per-image (unitemporal) model.
+        """
+        images = torch.stack(
+            [input_dict[image_key].image for input_dict in input_dicts], dim=0
+        )  # B, C, T, H, W
+        timestamps = [input_dict[image_key].timestamps for input_dict in input_dicts]
+        # if image channels is not equal to the actual number of channels, then
+        # then every N images should be batched together. For example, if the
+        # number of input channels c == 2, and image_channels == 4, then we
+        # want to pass 2 timesteps to the model.
+        # TODO is probably to make this behaviour clearer but lets leave it like
+        # this for now to not break things.
+        num_timesteps = images.shape[1] // image_channels
+        batched_timesteps = images.shape[2] // num_timesteps
+        images = rearrange(
+            images,
+            "b c (b_t k_t) h w -> (b b_t) c k_t h w",
+            b_t=batched_timesteps,
+            k_t=num_timesteps,
+        )
+        if timestamps[0] is None:
+            new_timestamps = [None] * images.shape[0]
+        else:
+            # we also need to split the timestamps
+            new_timestamps = []
+            for t in timestamps:
+                for i in range(batched_timesteps):
+                    new_timestamps.append(
+                        t[i * num_timesteps : (i + 1) * num_timesteps]
+                    )
+        return [
+            RasterImage(image=image, timestamps=timestamps)
+            for image, timestamps in zip(images, new_timestamps)
+        ]  # C, T, H, W
+
     def forward(
-        self,
-
+        self,
+        context: ModelContext,
+    ) -> FeatureMaps:
         """Compute outputs from the backbone.

-
-
+        Args:
+            context: the model context. Input dicts must include "image" key containing the image time
                 series to process (with images concatenated on the channel dimension).
-
+
+        Returns:
+            the FeatureMaps aggregated temporally.
         """
         # First get features of each image.
         # To do so, we need to split up each grouped image into its component images (which have had their channels stacked).
-
-        n_batch =
-        n_images
-
-
-
-
+        batched_inputs: list[dict[str, Any]] | None = None
+        n_batch = len(context.inputs)
+        n_images: int | None = None
+
+        if self.image_keys is not None:
+            for image_key, image_channels in self.image_keys.items():
+                batched_images = self._get_batched_images(
+                    context.inputs, image_key, image_channels
+                )
+
+                if batched_inputs is None:
+                    batched_inputs = [{} for _ in batched_images]
+                    n_images = len(batched_images) // n_batch
+                elif n_images != len(batched_images) // n_batch:
+                    raise ValueError(
+                        "expected all modalities to have the same number of timesteps"
+                    )
+
+                for i, image in enumerate(batched_images):
+                    batched_inputs[i][image_key] = image
+
+        else:
+            assert self.image_channels is not None
+            batched_images = self._get_batched_images(
+                context.inputs, self.image_key, self.image_channels
+            )
+            batched_inputs = [{self.image_key: image} for image in batched_images]
+            n_images = len(batched_images) // n_batch
+
+        assert n_images is not None
+        # Now we can apply the underlying FeatureExtractor.
+        # Its output must be a FeatureMaps.
+        assert batched_inputs is not None
+        encoder_output = self.encoder(
+            ModelContext(
+                inputs=batched_inputs,
+                metadatas=context.metadatas,
+            )
         )
-
+        if not isinstance(encoder_output, FeatureMaps):
+            raise ValueError(
+                "output of underlying FeatureExtractor in SimpleTimeSeries must be a FeatureMaps"
+            )
         all_features = [
             feat_map.reshape(
                 n_batch,
@@ -157,9 +269,8 @@ class SimpleTimeSeries(torch.nn.Module):
                 feat_map.shape[2],
                 feat_map.shape[3],
             )
-            for feat_map in
+            for feat_map in encoder_output.feature_maps
         ]
-
         # Groups defaults to flattening all the feature maps.
         groups = self.groups
         if not groups:
@@ -171,13 +282,13 @@ class SimpleTimeSeries(torch.nn.Module):
         for feature_idx in range(len(all_features)):
             aggregated_features = []
             for group in groups:
-
+                group_features_list = []
                 for image_idx in group:
-
+                    group_features_list.append(
                         all_features[feature_idx][:, image_idx, :, :, :]
                     )
                 # Resulting group features are (depth, batch, C, height, width).
-                group_features = torch.stack(
+                group_features = torch.stack(group_features_list, dim=0)

                 if self.op == "max":
                     group_features = torch.amax(group_features, dim=0)
@@ -213,7 +324,7 @@ class SimpleTimeSeries(torch.nn.Module):
                     .permute(0, 3, 1, 2)
                 )
             else:
-                raise
+                raise ValueError(f"unknown aggregation op {self.op}")

                 aggregated_features.append(group_features)

@@ -222,4 +333,4 @@ class SimpleTimeSeries(torch.nn.Module):

         output_features.append(aggregated_features)

-        return output_features
+        return FeatureMaps(output_features)
rslearn/models/singletask.py
CHANGED

@@ -4,6 +4,10 @@ from typing import Any

 import torch

+from rslearn.train.model_context import ModelContext, ModelOutput
+
+from .component import FeatureExtractor, IntermediateComponent, Predictor
+

 class SingleTaskModel(torch.nn.Module):
     """Standard model wrapper.
@@ -14,34 +18,41 @@ class SingleTaskModel(torch.nn.Module):
     outputs and targets from the last module (which also receives the targets).
     """

-    def __init__(
+    def __init__(
+        self,
+        encoder: list[FeatureExtractor | IntermediateComponent],
+        decoder: list[IntermediateComponent | Predictor],
+    ):
         """Initialize a new SingleTaskModel.

         Args:
-            encoder: modules to compute intermediate feature representations.
-
+            encoder: modules to compute intermediate feature representations. The first
+                module must be a FeatureExtractor, and following modules must be
+                IntermediateComponents.
+            decoder: modules to compute outputs and loss. The last module must be a
+                Predictor, while the previous modules must be IntermediateComponents.
         """
         super().__init__()
-        self.encoder = torch.nn.
+        self.encoder = torch.nn.ModuleList(encoder)
         self.decoder = torch.nn.ModuleList(decoder)

     def forward(
         self,
-
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) ->
+    ) -> ModelOutput:
         """Apply the sequence of modules on the inputs.

         Args:
-
+            context: the model context.
             targets: optional list of target dicts

         Returns:
-
+            the model output.
         """
-
-
+        cur = self.encoder[0](context)
+        for module in self.encoder[1:]:
+            cur = module(cur, context)
         for module in self.decoder[:-1]:
-            cur = module(cur,
-
-        return self.decoder[-1](cur, inputs, targets)
+            cur = module(cur, context)
+        return self.decoder[-1](cur, context, targets)