rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/terramind.py ADDED
@@ -0,0 +1,256 @@
+ """Terramind models."""
+
+ from enum import Enum
+ from typing import Any
+
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange
+ from terratorch.registry import BACKBONE_REGISTRY
+
+ from rslearn.train.model_context import ModelContext
+ from rslearn.train.transforms.transform import Transform
+
+ from .component import FeatureExtractor, FeatureMaps
+
+
+ # TerraMind v1 provides two sizes: base and large
+ class TerramindSize(str, Enum):
+     """Size of the Terramind model."""
+
+     BASE = "base"
+     LARGE = "large"
+
+
+ # Pretraining image size for Terramind
+ IMAGE_SIZE = 224
+ # Default patch size for Terramind
+ PATCH_SIZE = 16
+
+ # Modalities supported by Terramind
+ # S2L1C: Sentinel-2 Level 1C (Top-of-atmosphere reflectance), range: 1000 – 11000 DN
+ # S2L2A: Sentinel-2 Level 2A (Bottom-of-atmosphere reflectance), range: 1000 – 11000 DN
+ # S1GRD: Sentinel-1 GRD (Calibrated SAR backscatter), range: -50 – +10 dB
+ # S1RTC: Sentinel-1 RTC (Radiometrically terrain corrected), range: -50 – +10 dB
+ # RGB: Processed RGB images based on S2L2A, range: 0-255
+ # DEM: Digital Elevation Model (Copernicus DEM, 30m), range: -400 – 8800 meters
+
+ # More details in the TerraMesh paper: https://arxiv.org/pdf/2504.11172v1
+ TERRAMIND_MODALITIES = ["S2L1C", "S2L2A", "S1GRD", "S1RTC", "RGB", "DEM"]
+
+ # TerraMind band orders and standardization values
+ PRETRAINED_BANDS = {
+     "S2L2A": {
+         "B01": [1390.458, 2106.761],
+         "B02": [1503.317, 2141.107],
+         "B03": [1718.197, 2038.973],
+         "B04": [1853.910, 2134.138],
+         "B05": [2199.100, 2085.321],
+         "B06": [2779.975, 1889.926],
+         "B07": [2987.011, 1820.257],
+         "B08": [3083.234, 1871.918],
+         "B8A": [3132.220, 1753.829],
+         "B09": [3162.988, 1797.379],
+         "B11": [2424.884, 1434.261],
+         "B12": [1857.648, 1334.311],
+     },
+     "S2L1C": {
+         "B01": [2357.089, 1624.683],
+         "B02": [2137.385, 1675.806],
+         "B03": [2018.788, 1557.708],
+         "B04": [2082.986, 1833.702],
+         "B05": [2295.651, 1823.738],
+         "B06": [2854.537, 1733.977],
+         "B07": [3122.849, 1732.131],
+         "B08": [3040.560, 1679.732],
+         "B8A": [3306.481, 1727.26],
+         "B09": [1473.847, 1024.687],
+         "B10": [506.070, 442.165],
+         "B11": [2472.825, 1331.411],
+         "B12": [1838.929, 1160.419],
+     },
+     "RGB": {
+         "Red": [87.271, 58.767],
+         "Green": [80.931, 47.663],
+         "Blue": [66.667, 42.631],
+     },
+     "S1GRD": {
+         "vv": [-12.599, 5.195],
+         "vh": [-20.293, 5.890],
+     },
+     "S1RTC": {
+         "vv": [-10.93, 4.391],
+         "vh": [-17.329, 4.459],
+     },
+     "DEM": {
+         "DEM": [670.665, 951.272],
+     },
+ }
+
+
+ class Terramind(FeatureExtractor):
+     """Terramind backbones."""
+
+     def __init__(
+         self,
+         model_size: TerramindSize,
+         modalities: list[str] = ["S2L2A"],
+         do_resizing: bool = False,
+     ) -> None:
+         """Initialize the Terramind model.
+
+         Args:
+             model_size: The size of the Terramind model.
+             modalities: The modalities to use.
+             do_resizing: Whether to resize the input images to the pretraining resolution.
+         """
+         super().__init__()
+
+         # Check that all modalities are valid.
+         for modality in modalities:
+             if modality not in TERRAMIND_MODALITIES:
+                 raise ValueError(f"Invalid modality: {modality}")
+
+         if model_size == TerramindSize.BASE:
+             self.model = BACKBONE_REGISTRY.build(
+                 "terramind_v1_base", modalities=modalities, pretrained=True
+             )
+         elif model_size == TerramindSize.LARGE:
+             self.model = BACKBONE_REGISTRY.build(
+                 "terramind_v1_large", modalities=modalities, pretrained=True
+             )
+         else:
+             raise ValueError(f"Invalid model size: {model_size}")
+
+         self.model_size = model_size
+         self.modalities = modalities
+         self.do_resizing = do_resizing
+
+     def forward(self, context: ModelContext) -> FeatureMaps:
+         """Forward pass for the Terramind model.
+
+         Args:
+             context: the model context. Input dicts must include modalities as keys
+                 which are defined in the self.modalities list.
+
+         Returns:
+             a FeatureMaps with one feature map from the encoder, at 1/16 of the input
+                 resolution.
+         """
+         model_inputs = {}
+         for modality in self.modalities:
+             # We assume that all the inputs include the same modalities.
+             if modality not in context.inputs[0]:
+                 continue
+             cur = torch.stack(
+                 [inp[modality].single_ts_to_chw_tensor() for inp in context.inputs],
+                 dim=0,
+             )  # (B, C, H, W)
+             if self.do_resizing and (
+                 cur.shape[2] != IMAGE_SIZE or cur.shape[3] != IMAGE_SIZE
+             ):
+                 if cur.shape[2] == 1 and cur.shape[3] == 1:
+                     new_height, new_width = PATCH_SIZE, PATCH_SIZE
+                 else:
+                     new_height, new_width = IMAGE_SIZE, IMAGE_SIZE
+                 cur = F.interpolate(
+                     cur,
+                     size=(new_height, new_width),
+                     mode="bilinear",
+                     align_corners=False,
+                 )
+             model_inputs[modality] = cur
+
+         # By default, the patch embeddings are averaged over all modalities to reduce output tokens.
+         # The output is a list of tensors (B, N, D) from each layer of the transformer;
+         # we only take the last layer's output.
+         image_features = self.model(model_inputs)[-1]
+         batch_size, num_patches, _ = image_features.shape
+         height, width = int(num_patches**0.5), int(num_patches**0.5)
+         return FeatureMaps(
+             [
+                 rearrange(
+                     image_features,
+                     "b (h w) d -> b d h w",
+                     b=batch_size,
+                     h=height,
+                     w=width,
+                 )
+             ]
+         )
+
+     def get_backbone_channels(self) -> list:
+         """Returns the output channels of this model when used as a backbone.
+
+         The output channels are a list of (patch_size, depth) tuples that correspond
+         to the feature maps that the backbone returns.
+
+         Returns:
+             the output channels of the backbone as a list of (patch_size, depth) tuples.
+         """
+         if self.model_size == TerramindSize.BASE:
+             depth = 768
+         elif self.model_size == TerramindSize.LARGE:
+             depth = 1024
+         else:
+             raise ValueError(f"Invalid model size: {self.model_size}")
+         return [(PATCH_SIZE, depth)]
+
+
+ class TerramindNormalize(Transform):
+     """Normalize inputs using Terramind normalization.
+
+     It will apply normalization to the modalities that are specified in the model configuration.
+     """
+
+     def __init__(self) -> None:
+         """Initialize a new TerramindNormalize."""
+         super().__init__()
+
+     def apply_image(
+         self, image: torch.Tensor, means: list[float], stds: list[float]
+     ) -> torch.Tensor:
+         """Normalize the specified image with Terramind normalization.
+
+         Args:
+             image: the image to normalize.
+             means: the means to use for the normalization.
+             stds: the standard deviations to use for the normalization.
+
+         Returns:
+             The normalized image.
+         """
+         images = image.float()  # (C, 1, H, W)
+         if images.shape[0] % len(means) != 0:
+             raise ValueError(
+                 f"the number of image channels {images.shape[0]} is not a multiple of the expected number of bands {len(means)}"
+             )
+         for i in range(images.shape[0]):
+             band_idx = i % len(means)
+             images[i] = (images[i] - means[band_idx]) / stds[band_idx]
+         return images
+
+     def forward(
+         self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+     ) -> tuple[dict[str, Any], dict[str, Any]]:
+         """Normalize the specified image with Terramind normalization.
+
+         Args:
+             input_dict: the input dictionary.
+             target_dict: the target dictionary.
+
+         Returns:
+             normalized (input_dict, target_dict) tuple
+         """
+         for modality in TERRAMIND_MODALITIES:
+             if modality not in input_dict:
+                 continue
+             band_info = PRETRAINED_BANDS[modality]
+             means = [band_info[band][0] for band in band_info]
+             stds = [band_info[band][1] for band in band_info]
+             input_dict[modality].image = self.apply_image(
+                 input_dict[modality].image,
+                 means,
+                 stds,
+             )
+         return input_dict, target_dict
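The backbone and transform above are meant to be used together. A minimal usage sketch (not from the package itself), assuming terratorch is installed for the pretrained weights and that the rslearn training pipeline supplies the ModelContext:

from rslearn.models.terramind import Terramind, TerramindNormalize, TerramindSize

# Normalize Sentinel-2 L2A inputs with the pretrained band statistics defined above.
normalize = TerramindNormalize()

# Build the base-size backbone; do_resizing interpolates inputs to the 224x224
# pretraining size when they do not already match it.
backbone = Terramind(
    model_size=TerramindSize.BASE,
    modalities=["S2L2A"],
    do_resizing=True,
)
# backbone(context) returns a FeatureMaps with a single map at 1/16 of the
# (possibly resized) input resolution; depth is 768 for base and 1024 for large.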
rslearn/models/trunk.py ADDED
@@ -0,0 +1,139 @@
+ """Trunk module for decoder."""
+
+ from abc import ABC, abstractmethod
+ from typing import Any
+
+ import torch
+
+ from rslearn.log_utils import get_logger
+ from rslearn.models.task_embedding import BaseTaskEmbedding
+ from rslearn.train.model_context import ModelOutput
+
+ logger = get_logger(__name__)
+
+
+ class DecoderTrunkLayer(torch.nn.Module, ABC):
+     """Trunk layer for decoder."""
+
+     def __init__(self) -> None:
+         """Initialize the DecoderTrunkLayer module."""
+         super().__init__()
+
+     @abstractmethod
+     def forward(
+         self, x: torch.Tensor, task_embedding: torch.Tensor | None = None
+     ) -> dict[str, torch.Tensor]:
+         """Forward pass.
+
+         Args:
+             x: input tensor of shape (batch_size, seq_len, dim)
+             task_embedding: task embedding tensor of shape (batch_size, dim), or None
+
+         Returns:
+             dict with key "outputs" (output tensor of shape (batch_size, seq_len, dim))
+             and optionally other keys.
+         """
+         raise NotImplementedError
+
+     @abstractmethod
+     def apply_auxiliary_losses(
+         self, trunk_out: dict[str, Any], outs: ModelOutput
+     ) -> None:
+         """Apply auxiliary losses in-place.
+
+         Args:
+             trunk_out: The output of the trunk.
+             outs: The output of the decoders, with key "loss_dict" containing the losses.
+         """
+         raise NotImplementedError
+
+
+ class DecoderTrunk(torch.nn.Module):
+     """Trunk module for decoder, including arbitrary layers plus an optional task embedding."""
+
+     def __init__(
+         self,
+         task_embedding: BaseTaskEmbedding | None = None,
+         layers: list[DecoderTrunkLayer] | None = None,
+     ) -> None:
+         """Initialize the DecoderTrunk module.
+
+         Args:
+             task_embedding: Task-specific embedding module, or None if not using task embedding.
+             layers: List of other shared layers. The first one should expect a
+                 B x T x C tensor, and the last should output a B x T x C tensor.
+                 All layers must output a dict with key "outputs" (output tensor of shape
+                 (B, T, C)) and optionally other keys.
+         """
+         super().__init__()
+         self.layers = torch.nn.ModuleList(layers or [])
+         self.task_embedding = task_embedding
+
+         # If we have multiple instances of the same layer class, output keys will get overwritten.
+         if layers is not None:
+             types = [type(layer) for layer in layers]
+             if len(set(types)) != len(types):
+                 logger.warning(
+                     "Multiple instances of the same layer class found in trunk. "
+                     "Only the keys from the last instance will be used"
+                 )
+
+     def register_tasks(self, task_names: list[str]) -> None:
+         """Register tasks.
+
+         Args:
+             task_names: list of task names
+         """
+         if self.task_embedding is not None:
+             self.task_embedding.register_tasks(task_names)
+
+     def forward(
+         self,
+         features: list[torch.Tensor],
+         inputs: list[dict[str, Any]],
+     ) -> dict[str, Any]:
+         """Forward pass.
+
+         Args:
+             features: The encoder features, a 1-list of B x C x H x W features.
+             inputs: The original inputs to the encoder.
+
+         Returns:
+             dict with key "outputs" (output tensor of shape (batch_size, seq_len, dim))
+             and optionally other keys from the other layers.
+         """
+         embeds = None
+         if self.task_embedding is not None:
+             embeds = self.task_embedding.compute_embeds(features, inputs)
+             features = self.task_embedding(features, inputs, embeds=embeds)
+
+         if not self.layers:
+             return {"outputs": features}
+
+         assert len(features) == 1, "DecoderTrunk only supports one feature map"
+         x = torch.einsum("bchw->bhwc", features[0])
+         x = torch.flatten(x, start_dim=1, end_dim=2)  # B x T x C, T = HW
+         out = {}
+         for layer in self.layers:
+             layer_out = layer(x, task_embedding=embeds)
+             x = layer_out.pop("outputs")  # unspecified shape
+             out.update(layer_out)
+         x = torch.einsum("btc->bct", x)  # B x C x T
+         x = x.view(*features[0].shape)  # B x C x H x W
+
+         out["outputs"] = [x]
+         return out
+
+     def apply_auxiliary_losses(
+         self, trunk_out: dict[str, Any], outs: ModelOutput
+     ) -> None:
+         """Apply auxiliary losses in-place.
+
+         Each layer handles its own auxiliary losses, assuming the loss key is `loss_dict`.
+
+         Args:
+             trunk_out: The output of the trunk.
+             outs: The output of the decoders.
+         """
+         for layer in self.layers:
+             layer.apply_auxiliary_losses(trunk_out, outs)
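To make the layer contract concrete, here is a hedged sketch with a hypothetical pass-through DecoderTrunkLayer; shapes follow the docstrings above (a 1-list of B x C x H x W features in, a dict whose "outputs" key is again a 1-list):

import torch

from rslearn.models.trunk import DecoderTrunk, DecoderTrunkLayer


class IdentityTrunkLayer(DecoderTrunkLayer):
    """Hypothetical layer that returns its input unchanged (illustration only)."""

    def forward(self, x, task_embedding=None):
        # x is B x T x C; it must be returned under the required "outputs" key.
        return {"outputs": x}

    def apply_auxiliary_losses(self, trunk_out, outs):
        # This layer contributes no auxiliary losses.
        pass


trunk = DecoderTrunk(layers=[IdentityTrunkLayer()])
features = [torch.randn(2, 768, 16, 16)]  # 1-list of B x C x H x W features
out = trunk(features, inputs=[{}, {}])
# out["outputs"] is a 1-list holding a tensor of shape (2, 768, 16, 16).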
rslearn/models/unet.py CHANGED
@@ -3,9 +3,17 @@
  from typing import Any
 
  import torch
+ import torch.nn.functional as F
 
+ from rslearn.train.model_context import ModelContext
 
- class UNetDecoder(torch.nn.Module):
+ from .component import (
+     FeatureMaps,
+     IntermediateComponent,
+ )
+
+
+ class UNetDecoder(IntermediateComponent):
      """UNet-style decoder.
 
      It inputs multi-scale features. Starting from last (lowest resolution) feature map,
@@ -16,20 +24,30 @@ class UNetDecoder(torch.nn.Module):
      def __init__(
          self,
          in_channels: list[tuple[int, int]],
-         out_channels: int,
+         out_channels: int | None,
          conv_layers_per_resolution: int = 1,
          kernel_size: int = 3,
-     ):
+         num_channels: dict[int, int] = {},
+         target_resolution_factor: int = 1,
+         original_size_to_interpolate: tuple[int, int] | None = None,
+     ) -> None:
          """Initialize a UNetDecoder.
 
          Args:
              in_channels: list of (downsample factor, num channels) indicating the
                  resolution (1/downsample_factor of input resolution) and number of
                  channels in each feature map of the multi-scale features.
-             out_channels: channels to output at each pixel.
+             out_channels: channels to output at each pixel, or None to skip the output
+                 layer.
              conv_layers_per_resolution: number of convolutional layers to apply after
                  each up-sampling operation
              kernel_size: kernel size to use in convolutional layers
+             num_channels: override number of output channels to use at different
+                 downsample factors.
+             target_resolution_factor: output features at 1/target_resolution_factor
+                 relative to the input resolution. The default is 1 which outputs pixel
+                 level features.
+             original_size_to_interpolate: the original size to interpolate the output to.
          """
          super().__init__()
 
@@ -52,7 +70,7 @@ class UNetDecoder(torch.nn.Module):
              ]
          )
          channels_by_factor = {factor: channels for factor, channels in in_channels}
-         while cur_factor > 1:
+         while cur_factor > target_resolution_factor:
              # Add upsampling layer.
              cur_layers.append(torch.nn.Upsample(scale_factor=2))
              cur_factor //= 2
@@ -62,28 +80,39 @@ class UNetDecoder(torch.nn.Module):
              # concatenating with.
              if cur_factor in channels_by_factor:
                  layers.append(torch.nn.Sequential(*cur_layers))
+                 # Number of output channels for this layer can be configured
+                 # per-resolution by the user, otherwise we default to the feature map
+                 # channels at the corresponding downsample factor.
+                 cur_out_channels = num_channels.get(
+                     cur_factor, channels_by_factor[cur_factor]
+                 )
                  cur_layers = [
                      torch.nn.Conv2d(
                          in_channels=cur_channels + channels_by_factor[cur_factor],
-                         out_channels=channels_by_factor[cur_factor],
+                         out_channels=cur_out_channels,
                          kernel_size=kernel_size,
                          padding="same",
                      ),
                      torch.nn.ReLU(inplace=True),
                  ]
-                 cur_channels = channels_by_factor[cur_factor]
+                 cur_channels = cur_out_channels
              else:
+                 # Since there is no feature map at the next downsample factor, the
+                 # default is to keep the same number of channels (but the user can
+                 # still override it with num_channels).
+                 cur_out_channels = num_channels.get(cur_factor, cur_channels)
                  cur_layers.extend(
                      [
                          torch.nn.Conv2d(
                              in_channels=cur_channels,
-                             out_channels=cur_channels,
+                             out_channels=cur_out_channels,
                              kernel_size=kernel_size,
                              padding="same",
                          ),
                          torch.nn.ReLU(inplace=True),
                      ]
                  )
+                 cur_channels = cur_out_channels
 
              # Add remaining conv layers.
              for _ in range(conv_layers_per_resolution - 1):
@@ -99,30 +128,47 @@ class UNetDecoder(torch.nn.Module):
                      ]
                  )
 
-         cur_layers.append(
-             torch.nn.Conv2d(
-                 in_channels=cur_channels,
-                 out_channels=out_channels,
-                 kernel_size=kernel_size,
-                 padding="same",
-             ),
-         )
+         if out_channels is not None:
+             cur_layers.append(
+                 torch.nn.Conv2d(
+                     in_channels=cur_channels,
+                     out_channels=out_channels,
+                     kernel_size=kernel_size,
+                     padding="same",
+                 ),
+             )
          layers.append(torch.nn.Sequential(*cur_layers))
          self.layers = torch.nn.ModuleList(layers)
+         self.original_size_to_interpolate = original_size_to_interpolate
 
-     def forward(self, in_features: list[torch.Tensor], inputs: list[dict[str, Any]]):
+     def _resize(self, features: torch.Tensor) -> torch.Tensor:
+         """Interpolate the features to the original size."""
+         return F.interpolate(
+             features,
+             size=self.original_size_to_interpolate,
+             mode="bilinear",
+             align_corners=False,
+         )
+
+     def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
          """Compute output from multi-scale feature map.
 
          Args:
-             in_features: list of feature maps at different resolutions.
-             inputs: original inputs (ignored).
+             intermediates: the output from the previous model component, which must be a FeatureMaps.
+             context: the model context.
 
          Returns:
-             output image
+             output FeatureMaps consisting of one map. The embedding size is equal to the
+                 configured out_channels.
          """
+         if not isinstance(intermediates, FeatureMaps):
+             raise ValueError("input to UNetDecoder must be a FeatureMaps")
+
          # Reverse the features since we will pass them in from lowest resolution to highest.
-         in_features = list(reversed(in_features))
+         in_features = list(reversed(intermediates.feature_maps))
          cur_features = self.layers[0](in_features[0])
          for in_feat, layer in zip(in_features[1:], self.layers[1:]):
              cur_features = layer(torch.cat([cur_features, in_feat], dim=1))
-         return cur_features
+         if self.original_size_to_interpolate is not None:
+             cur_features = self._resize(cur_features)
+         return FeatureMaps([cur_features])
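The new constructor options compose as in the hedged sketch below; the in_channels values are hypothetical (e.g. a Swin-style feature pyramid). This configuration stops upsampling at 1/4 of the input resolution, overrides the channel count used at that factor, and skips the final projection layer:

from rslearn.models.unet import UNetDecoder

decoder = UNetDecoder(
    in_channels=[(4, 128), (8, 256), (16, 512), (32, 1024)],
    out_channels=None,  # None skips the final output Conv2d
    target_resolution_factor=4,  # stop at 1/4 of the input resolution
    num_channels={4: 256},  # override the channel count used at factor 4
)

Separately, original_size_to_interpolate can be set to a fixed (height, width) to bilinearly resize the single output feature map before it is returned.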
rslearn/models/upsample.py ADDED
@@ -0,0 +1,48 @@
+ """An upsampling layer."""
+
+ from typing import Any
+
+ import torch
+
+ from rslearn.train.model_context import ModelContext
+
+ from .component import (
+     FeatureMaps,
+     IntermediateComponent,
+ )
+
+
+ class Upsample(IntermediateComponent):
+     """Upsamples each input feature map by the same factor."""
+
+     def __init__(
+         self,
+         scale_factor: int,
+         mode: str = "bilinear",
+     ):
+         """Initialize an Upsample.
+
+         Args:
+             scale_factor: the upsampling factor, e.g. 2 to double the size.
+             mode: "nearest" or "bilinear".
+         """
+         super().__init__()
+         self.layer = torch.nn.Upsample(scale_factor=scale_factor, mode=mode)
+
+     def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+         """Upsample each feature map by scale_factor.
+
+         Args:
+             intermediates: the output from the previous component, which must be a FeatureMaps.
+             context: the model context.
+
+         Returns:
+             upsampled feature maps.
+         """
+         if not isinstance(intermediates, FeatureMaps):
+             raise ValueError("input to Upsample must be a FeatureMaps")
+
+         upsampled_feat_maps = [
+             self.layer(feat_map) for feat_map in intermediates.feature_maps
+         ]
+         return FeatureMaps(upsampled_feat_maps)
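A short sketch of the component on synthetic feature maps (the shapes are hypothetical); context is unused by Upsample, so None is passed here for brevity:

import torch

from rslearn.models.component import FeatureMaps
from rslearn.models.upsample import Upsample

upsample = Upsample(scale_factor=2, mode="nearest")
feats = FeatureMaps([torch.randn(1, 64, 32, 32), torch.randn(1, 128, 16, 16)])
out = upsample(feats, context=None)
# out.feature_maps now have spatial sizes 64x64 and 32x32.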