rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0

rslearn/models/terramind.py
ADDED

@@ -0,0 +1,256 @@
+"""Terramind models."""
+
+from enum import Enum
+from typing import Any
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from terratorch.registry import BACKBONE_REGISTRY
+
+from rslearn.train.model_context import ModelContext
+from rslearn.train.transforms.transform import Transform
+
+from .component import FeatureExtractor, FeatureMaps
+
+
+# TerraMind v1 provides two sizes: base and large
+class TerramindSize(str, Enum):
+    """Size of the Terramind model."""
+
+    BASE = "base"
+    LARGE = "large"
+
+
+# Pretraining image size for Terramind
+IMAGE_SIZE = 224
+# Default patch size for Terramind
+PATCH_SIZE = 16
+
+# Modalities supported by Terramind
+# S2L1C: Sentinel-2 Level 1C (Top-of-atmosphere reflectance), range: 1000 – 11000 DN
+# S2L2A: Sentinel-2 Level 2A (Bottom-of-atmosphere reflectance), range: 1000 – 11000 DN
+# S1GRD: Sentinel-1 GRD (Calibrated SAR backscatter), range: -50 – +10 dB
+# S1RTC: Sentinel-1 RTC (Radiometrically terrain corrected), range: -50 – +10 dB
+# RGB: Processed RGB images based on S2L2A, range: 0-255
+# DEM: Digital Elevation Model (Copernicus DEM, 30m), range: -400 – 8800 meters
+
+# More details in the TerraMesh paper: https://arxiv.org/pdf/2504.11172v1
+TERRAMIND_MODALITIES = ["S2L1C", "S2L2A", "S1GRD", "S1RTC", "RGB", "DEM"]
+
+# TerraMind band orders and standardization values
+PRETRAINED_BANDS = {
+    "S2L2A": {
+        "B01": [1390.458, 2106.761],
+        "B02": [1503.317, 2141.107],
+        "B03": [1718.197, 2038.973],
+        "B04": [1853.910, 2134.138],
+        "B05": [2199.100, 2085.321],
+        "B06": [2779.975, 1889.926],
+        "B07": [2987.011, 1820.257],
+        "B08": [3083.234, 1871.918],
+        "B8A": [3132.220, 1753.829],
+        "B09": [3162.988, 1797.379],
+        "B11": [2424.884, 1434.261],
+        "B12": [1857.648, 1334.311],
+    },
+    "S2L1C": {
+        "B01": [2357.089, 1624.683],
+        "B02": [2137.385, 1675.806],
+        "B03": [2018.788, 1557.708],
+        "B04": [2082.986, 1833.702],
+        "B05": [2295.651, 1823.738],
+        "B06": [2854.537, 1733.977],
+        "B07": [3122.849, 1732.131],
+        "B08": [3040.560, 1679.732],
+        "B8A": [3306.481, 1727.26],
+        "B09": [1473.847, 1024.687],
+        "B10": [506.070, 442.165],
+        "B11": [2472.825, 1331.411],
+        "B12": [1838.929, 1160.419],
+    },
+    "RGB": {
+        "Red": [87.271, 58.767],
+        "Green": [80.931, 47.663],
+        "Blue": [66.667, 42.631],
+    },
+    "S1GRD": {
+        "vv": [-12.599, 5.195],
+        "vh": [-20.293, 5.890],
+    },
+    "S1RTC": {
+        "vv": [-10.93, 4.391],
+        "vh": [-17.329, 4.459],
+    },
+    "DEM": {
+        "DEM": [670.665, 951.272],
+    },
+}
+
+
+class Terramind(FeatureExtractor):
+    """Terramind backbones."""
+
+    def __init__(
+        self,
+        model_size: TerramindSize,
+        modalities: list[str] = ["S2L2A"],
+        do_resizing: bool = False,
+    ) -> None:
+        """Initialize the Terramind model.
+
+        Args:
+            model_size: The size of the Terramind model.
+            modalities: The modalities to use.
+            do_resizing: Whether to resize the input images to the pretraining resolution.
+        """
+        super().__init__()
+
+        # Check if all modalities are valid
+        for modality in modalities:
+            if modality not in TERRAMIND_MODALITIES:
+                raise ValueError(f"Invalid modality: {modality}")
+
+        if model_size == TerramindSize.BASE:
+            self.model = BACKBONE_REGISTRY.build(
+                "terramind_v1_base", modalities=modalities, pretrained=True
+            )
+        elif model_size == TerramindSize.LARGE:
+            self.model = BACKBONE_REGISTRY.build(
+                "terramind_v1_large", modalities=modalities, pretrained=True
+            )
+        else:
+            raise ValueError(f"Invalid model size: {model_size}")
+
+        self.model_size = model_size
+        self.modalities = modalities
+        self.do_resizing = do_resizing
+
+    def forward(self, context: ModelContext) -> FeatureMaps:
+        """Forward pass for the Terramind model.
+
+        Args:
+            context: the model context. Input dicts must include modalities as keys
+                which are defined in the self.modalities list.
+
+        Returns:
+            a FeatureMaps with one feature map from the encoder, at 1/16 of the input
+                resolution.
+        """
+        model_inputs = {}
+        for modality in self.modalities:
+            # We assume the all the inputs include the same modalities
+            if modality not in context.inputs[0]:
+                continue
+            cur = torch.stack(
+                [inp[modality].single_ts_to_chw_tensor() for inp in context.inputs],
+                dim=0,
+            )  # (B, C, H, W)
+            if self.do_resizing and (
+                cur.shape[2] != IMAGE_SIZE or cur.shape[3] != IMAGE_SIZE
+            ):
+                if cur.shape[2] == 1 and cur.shape[3] == 1:
+                    new_height, new_width = PATCH_SIZE, PATCH_SIZE
+                else:
+                    new_height, new_width = IMAGE_SIZE, IMAGE_SIZE
+                cur = F.interpolate(
+                    cur,
+                    size=(new_height, new_width),
+                    mode="bilinear",
+                    align_corners=False,
+                )
+            model_inputs[modality] = cur
+
+        # By default, the patch embeddings are averaged over all modalities to reduce output tokens
+        # The output is a list of tensors (B, N, D) from each layer of the transformer
+        # We only get the last layer's output
+        image_features = self.model(model_inputs)[-1]
+        batch_size, num_patches, _ = image_features.shape
+        height, width = int(num_patches**0.5), int(num_patches**0.5)
+        return FeatureMaps(
+            [
+                rearrange(
+                    image_features,
+                    "b (h w) d -> b d h w",
+                    b=batch_size,
+                    h=height,
+                    w=width,
+                )
+            ]
+        )
+
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels is a list of (patch_size, depth) that corresponds
+        to the feature maps that the backbone returns.
+
+        Returns:
+            the output channels of the backbone as a list of (patch_size, depth) tuples.
+        """
+        if self.model_size == TerramindSize.BASE:
+            depth = 768
+        elif self.model_size == TerramindSize.LARGE:
+            depth = 1024
+        else:
+            raise ValueError(f"Invalid model size: {self.model_size}")
+        return [(PATCH_SIZE, depth)]
+
+
+class TerramindNormalize(Transform):
+    """Normalize inputs using Terramind normalization.
+
+    It will apply normalization to the modalities that are specified in the model configuration.
+    """
+
+    def __init__(self) -> None:
+        """Initialize a new TerramindNormalize."""
+        super().__init__()
+
+    def apply_image(
+        self, image: torch.Tensor, means: list[float], stds: list[float]
+    ) -> torch.Tensor:
+        """Normalize the specified image with Terramind normalization.
+
+        Args:
+            image: the image to normalize.
+            means: the means to use for the normalization.
+            stds: the standard deviations to use for the normalization.
+
+        Returns:
+            The normalized image.
+        """
+        images = image.float()  # (C, 1, H, W)
+        if images.shape[0] % len(means) != 0:
+            raise ValueError(
+                f"the number of image channels {images.shape[0]} is not multiple of expected number of bands {len(means)}"
+            )
+        for i in range(images.shape[0]):
+            band_idx = i % len(means)
+            images[i] = (images[i] - means[band_idx]) / stds[band_idx]
+        return images
+
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Normalize the specified image with Terramind normalization.
+
+        Args:
+            input_dict: the input dictionary.
+            target_dict: the target dictionary.
+
+        Returns:
+            normalized (input_dicts, target_dicts) tuple
+        """
+        for modality in TERRAMIND_MODALITIES:
+            if modality not in input_dict:
+                continue
+            band_info = PRETRAINED_BANDS[modality]
+            means = [band_info[band][0] for band in band_info]
+            stds = [band_info[band][1] for band in band_info]
+            input_dict[modality].image = self.apply_image(
+                input_dict[modality].image,
+                means,
+                stds,
+            )
+        return input_dict, target_dict
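
For orientation, the snippet below shows how the new backbone and transform might be configured based on the constructor signatures above. It is an illustrative sketch only, not part of the diff; it assumes terratorch is installed, and instantiation downloads pretrained weights.

from rslearn.models.terramind import Terramind, TerramindNormalize, TerramindSize

backbone = Terramind(
    model_size=TerramindSize.BASE,
    modalities=["S2L2A", "S1GRD"],
    do_resizing=True,  # resize inputs to the 224x224 pretraining resolution
)
normalize = TerramindNormalize()  # per-band standardization using PRETRAINED_BANDS
print(backbone.get_backbone_channels())  # [(16, 768)] for the base model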

rslearn/models/trunk.py
ADDED

@@ -0,0 +1,139 @@
+"""Trunk module for decoder."""
+
+from abc import ABC, abstractmethod
+from typing import Any
+
+import torch
+
+from rslearn.log_utils import get_logger
+from rslearn.models.task_embedding import BaseTaskEmbedding
+from rslearn.train.model_context import ModelOutput
+
+logger = get_logger(__name__)
+
+
+class DecoderTrunkLayer(torch.nn.Module, ABC):
+    """Trunk layer for decoder."""
+
+    def __init__(self) -> None:
+        """Initialize the DecoderTrunkLayer module."""
+        super().__init__()
+
+    @abstractmethod
+    def forward(
+        self, x: torch.Tensor, task_embedding: torch.Tensor | None = None
+    ) -> dict[str, torch.Tensor]:
+        """Forward pass.
+
+        Args:
+            x: input tensor of shape (batch_size, seq_len, dim)
+            task_embedding: task embedding tensor of shape (batch_size, dim), or None
+
+        Returns:
+            dict with key "outputs" (output tensor of shape (batch_size, seq_len, dim))
+            and optionally other keys.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply_auxiliary_losses(
+        self, trunk_out: dict[str, Any], outs: ModelOutput
+    ) -> None:
+        """Apply auxiliary losses in-place.
+
+        Args:
+            trunk_out: The output of the trunk.
+            outs: The output of the decoders, with key "loss_dict" containing the losses.
+        """
+        raise NotImplementedError
+
+
+class DecoderTrunk(torch.nn.Module):
+    """Trunk module for decoder, including arbitrary layers plus an optional task embedding."""
+
+    def __init__(
+        self,
+        task_embedding: BaseTaskEmbedding | None = None,
+        layers: list[DecoderTrunkLayer] | None = None,
+    ) -> None:
+        """Initialize the DecoderTrunk module.
+
+        Args:
+            task_embedding: Task-specific embedding module, or None if not using task embedding.
+            layers: List of other shared layers. The first one should expect a
+                B x T x C tensor, and the last should output a B x T x C tensor.
+                All layers must output a dict with key "outputs" (output tensor of shape
+                (B, T, C)) and optionally other keys.
+        """
+        super().__init__()
+        self.layers = torch.nn.ModuleList(layers or [])
+        self.task_embedding = task_embedding
+
+        # If we have multiple instances of the same layer class, output keys will get overwritten
+        if layers is not None:
+            types = [type(layer) for layer in layers]
+            if len(set(types)) != len(types):
+                logger.warning(
+                    "Multiple instances of the same layer class found in trunk. "
+                    "Only the keys from the last instance will be used"
+                )
+
+    def register_tasks(self, task_names: list[str]) -> None:
+        """Register tasks.
+
+        Args:
+            task_names: list of task names
+        """
+        if self.task_embedding is not None:
+            self.task_embedding.register_tasks(task_names)
+
+    def forward(
+        self,
+        features: list[torch.tensor],
+        inputs: list[dict[str, Any]],
+    ) -> dict[str, Any]:
+        """Forward pass.
+
+        Args:
+            features: The encoder features, a 1-list of B x C x H x W features.
+            inputs: The original inputs to the encoder.
+
+        Returns:
+            dict with key "outputs" (output tensor of shape (batch_size, seq_len, dim))
+            and optionally other keys from the other layers.
+        """
+        embeds = None
+        if self.task_embedding is not None:
+            embeds = self.task_embedding.compute_embeds(features, inputs)
+            features = self.task_embedding(features, inputs, embeds=embeds)
+
+        if not self.layers:
+            return {"outputs": features}
+
+        assert len(features) == 1, "DecoderTrunk only supports one feature map"
+        x = torch.einsum("bchw->bhwc", features[0])
+        x = torch.flatten(x, start_dim=1, end_dim=2)  # B x T x C, T = HW
+        out = {}
+        for layer in self.layers:
+            layer_out = layer(x, task_embedding=embeds)
+            x = layer_out.pop("outputs")  # unspecified shape
+            out.update(layer_out)
+        x = torch.einsum("btc->bct", x)  # B x C x T
+        x = x.view(*features[0].shape)  # B x C x H x W
+
+        out["outputs"] = [x]
+        return out
+
+    def apply_auxiliary_losses(
+        self, trunk_out: dict[str, Any], outs: ModelOutput
+    ) -> None:
+        """Apply auxiliary losses in-place.
+
+        Each layer handles its own auxiliary losses, assuming the loss key is `loss_dict`.
+
+        Args:
+            trunk_out: The output of the trunk.
+            outs: The output of the decoders.
+        """
+        for layer in self.layers:
+            layer.apply_auxiliary_losses(trunk_out, outs)
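
As a quick illustration of the pass-through behavior above (not part of the diff): with no task embedding and no shared layers, DecoderTrunk returns its input feature list unchanged. This sketch only assumes the module is importable as rslearn.models.trunk.

import torch
from rslearn.models.trunk import DecoderTrunk

trunk = DecoderTrunk()  # no task embedding, no shared layers
features = [torch.randn(2, 768, 16, 16)]  # one B x C x H x W feature map
out = trunk(features, inputs=[{}, {}])
print(out["outputs"][0].shape)  # torch.Size([2, 768, 16, 16]) -- passed through unchanged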

rslearn/models/unet.py
CHANGED

@@ -3,9 +3,17 @@
 from typing import Any
 
 import torch
+import torch.nn.functional as F
 
+from rslearn.train.model_context import ModelContext
 
-class UNetDecoder(torch.nn.Module):
+from .component import (
+    FeatureMaps,
+    IntermediateComponent,
+)
+
+
+class UNetDecoder(IntermediateComponent):
     """UNet-style decoder.
 
     It inputs multi-scale features. Starting from last (lowest resolution) feature map,
@@ -16,20 +24,30 @@ class UNetDecoder(torch.nn.Module):
     def __init__(
         self,
         in_channels: list[tuple[int, int]],
-        out_channels: int,
+        out_channels: int | None,
         conv_layers_per_resolution: int = 1,
         kernel_size: int = 3,
-    ):
+        num_channels: dict[int, int] = {},
+        target_resolution_factor: int = 1,
+        original_size_to_interpolate: tuple[int, int] | None = None,
+    ) -> None:
         """Initialize a UNetDecoder.
 
         Args:
             in_channels: list of (downsample factor, num channels) indicating the
                 resolution (1/downsample_factor of input resolution) and number of
                 channels in each feature map of the multi-scale features.
-            out_channels: channels to output at each pixel
+            out_channels: channels to output at each pixel, or None to skip the output
+                layer.
             conv_layers_per_resolution: number of convolutional layers to apply after
                 each up-sampling operation
             kernel_size: kernel size to use in convolutional layers
+            num_channels: override number of output channels to use at different
+                downsample factors.
+            target_resolution_factor: output features at 1/target_resolution_factor
+                relative to the input resolution. The default is 1 which outputs pixel
+                level features.
+            original_size_to_interpolate: the original size to interpolate the output to.
         """
         super().__init__()
 
@@ -52,7 +70,7 @@ class UNetDecoder(torch.nn.Module):
             ]
         )
         channels_by_factor = {factor: channels for factor, channels in in_channels}
-        while cur_factor >
+        while cur_factor > target_resolution_factor:
             # Add upsampling layer.
             cur_layers.append(torch.nn.Upsample(scale_factor=2))
             cur_factor //= 2
@@ -62,28 +80,39 @@ class UNetDecoder(torch.nn.Module):
             # concatenating with.
             if cur_factor in channels_by_factor:
                 layers.append(torch.nn.Sequential(*cur_layers))
+                # Number of output channels for this layer can be configured
+                # per-resolution by the user, otherwise we default to the feature map
+                # channels at the corresponding downsample factor.
+                cur_out_channels = num_channels.get(
+                    cur_factor, channels_by_factor[cur_factor]
+                )
                 cur_layers = [
                     torch.nn.Conv2d(
                         in_channels=cur_channels + channels_by_factor[cur_factor],
-                        out_channels=
+                        out_channels=cur_out_channels,
                         kernel_size=kernel_size,
                         padding="same",
                     ),
                     torch.nn.ReLU(inplace=True),
                 ]
-                cur_channels =
+                cur_channels = cur_out_channels
             else:
+                # Since there is no feature map at the next downsample factor, the
+                # default is to keep the same number of channels (but the user can
+                # still override it with num_channels).
+                cur_out_channels = num_channels.get(cur_factor, cur_channels)
                 cur_layers.extend(
                     [
                         torch.nn.Conv2d(
                             in_channels=cur_channels,
-                            out_channels=
+                            out_channels=cur_out_channels,
                             kernel_size=kernel_size,
                             padding="same",
                         ),
                         torch.nn.ReLU(inplace=True),
                     ]
                 )
+                cur_channels = cur_out_channels
 
             # Add remaining conv layers.
             for _ in range(conv_layers_per_resolution - 1):
@@ -99,30 +128,47 @@ class UNetDecoder(torch.nn.Module):
                 ]
             )
 
-
-
-
-
-
-
-
-
+        if out_channels is not None:
+            cur_layers.append(
+                torch.nn.Conv2d(
+                    in_channels=cur_channels,
+                    out_channels=out_channels,
+                    kernel_size=kernel_size,
+                    padding="same",
+                ),
+            )
         layers.append(torch.nn.Sequential(*cur_layers))
         self.layers = torch.nn.ModuleList(layers)
+        self.original_size_to_interpolate = original_size_to_interpolate
 
-    def
+    def _resize(self, features: torch.Tensor) -> torch.Tensor:
+        """Interpolate the features to the original size."""
+        return F.interpolate(
+            features,
+            size=self.original_size_to_interpolate,
+            mode="bilinear",
+            align_corners=False,
+        )
+
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
         """Compute output from multi-scale feature map.
 
         Args:
-
-
+            intermediates: the output from the previous model component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
-            output
+            output FeatureMaps consisting of one map. The embedding size is equal to the
+                configured out_channels.
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to UNetDecoder must be a FeatureMaps")
+
         # Reverse the features since we will pass them in from lowest resolution to highest.
-        in_features = list(reversed(
+        in_features = list(reversed(intermediates.feature_maps))
         cur_features = self.layers[0](in_features[0])
         for in_feat, layer in zip(in_features[1:], self.layers[1:]):
             cur_features = layer(torch.cat([cur_features, in_feat], dim=1))
-
+        if self.original_size_to_interpolate is not None:
+            cur_features = self._resize(cur_features)
+        return FeatureMaps([cur_features])
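
To make the new arguments concrete, here is a hypothetical configuration based solely on the documented parameters above; the values are illustrative and not taken from the package.

from rslearn.models.unet import UNetDecoder

decoder = UNetDecoder(
    in_channels=[(4, 128), (8, 256), (16, 512)],  # (downsample factor, channels) per feature map
    out_channels=None,                 # skip the final per-pixel output conv layer
    num_channels={2: 64, 1: 32},       # hypothetical per-factor channel overrides
    target_resolution_factor=1,        # upsample all the way back to input resolution
)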

rslearn/models/upsample.py
ADDED

@@ -0,0 +1,48 @@
+"""An upsampling layer."""
+
+from typing import Any
+
+import torch
+
+from rslearn.train.model_context import ModelContext
+
+from .component import (
+    FeatureMaps,
+    IntermediateComponent,
+)
+
+
+class Upsample(IntermediateComponent):
+    """Upsamples each input feature map by the same factor."""
+
+    def __init__(
+        self,
+        scale_factor: int,
+        mode: str = "bilinear",
+    ):
+        """Initialize an Upsample.
+
+        Args:
+            scale_factor: the upsampling factor, e.g. 2 to double the size.
+            mode: "nearest" or "bilinear".
+        """
+        super().__init__()
+        self.layer = torch.nn.Upsample(scale_factor=scale_factor, mode=mode)
+
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Upsample each feature map by scale_factor.
+
+        Args:
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context.
+
+        Returns:
+            upsampled feature maps.
+        """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to Upsample must be a FeatureMaps")
+
+        upsampled_feat_maps = [
+            self.layer(feat_map) for feat_map in intermediates.feature_maps
+        ]
+        return FeatureMaps(upsampled_feat_maps)
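
The component simply applies torch.nn.Upsample to every feature map. The standalone snippet below (plain PyTorch, no rslearn objects) shows the effect of the wrapped layer on a single B x C x H x W tensor.

import torch

layer = torch.nn.Upsample(scale_factor=2, mode="bilinear")
x = torch.randn(1, 64, 32, 32)  # B x C x H x W feature map
print(layer(x).shape)  # torch.Size([1, 64, 64, 64])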