rslearn 0.0.17__py3-none-any.whl → 0.0.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. rslearn/config/__init__.py +2 -0
  2. rslearn/config/dataset.py +49 -4
  3. rslearn/dataset/add_windows.py +1 -1
  4. rslearn/dataset/dataset.py +9 -65
  5. rslearn/dataset/materialize.py +5 -5
  6. rslearn/dataset/storage/__init__.py +1 -0
  7. rslearn/dataset/storage/file.py +202 -0
  8. rslearn/dataset/storage/storage.py +140 -0
  9. rslearn/dataset/window.py +26 -80
  10. rslearn/main.py +11 -36
  11. rslearn/models/anysat.py +11 -9
  12. rslearn/models/clay/clay.py +8 -9
  13. rslearn/models/clip.py +18 -15
  14. rslearn/models/component.py +99 -0
  15. rslearn/models/concatenate_features.py +21 -11
  16. rslearn/models/conv.py +15 -8
  17. rslearn/models/croma.py +13 -8
  18. rslearn/models/detr/detr.py +25 -14
  19. rslearn/models/dinov3.py +11 -6
  20. rslearn/models/faster_rcnn.py +19 -9
  21. rslearn/models/feature_center_crop.py +12 -9
  22. rslearn/models/fpn.py +19 -8
  23. rslearn/models/galileo/galileo.py +23 -18
  24. rslearn/models/module_wrapper.py +26 -57
  25. rslearn/models/molmo.py +16 -14
  26. rslearn/models/multitask.py +102 -73
  27. rslearn/models/olmoearth_pretrain/model.py +18 -12
  28. rslearn/models/panopticon.py +8 -7
  29. rslearn/models/pick_features.py +18 -24
  30. rslearn/models/pooling_decoder.py +22 -14
  31. rslearn/models/presto/presto.py +16 -10
  32. rslearn/models/presto/single_file_presto.py +4 -10
  33. rslearn/models/prithvi.py +12 -8
  34. rslearn/models/resize_features.py +21 -7
  35. rslearn/models/sam2_enc.py +11 -9
  36. rslearn/models/satlaspretrain.py +15 -9
  37. rslearn/models/simple_time_series.py +31 -17
  38. rslearn/models/singletask.py +24 -17
  39. rslearn/models/ssl4eo_s12.py +15 -10
  40. rslearn/models/swin.py +22 -13
  41. rslearn/models/terramind.py +24 -7
  42. rslearn/models/trunk.py +6 -3
  43. rslearn/models/unet.py +18 -9
  44. rslearn/models/upsample.py +22 -9
  45. rslearn/train/all_patches_dataset.py +22 -18
  46. rslearn/train/dataset.py +69 -54
  47. rslearn/train/lightning_module.py +51 -32
  48. rslearn/train/model_context.py +54 -0
  49. rslearn/train/prediction_writer.py +111 -41
  50. rslearn/train/tasks/classification.py +34 -15
  51. rslearn/train/tasks/detection.py +24 -31
  52. rslearn/train/tasks/embedding.py +33 -29
  53. rslearn/train/tasks/multi_task.py +7 -7
  54. rslearn/train/tasks/per_pixel_regression.py +41 -19
  55. rslearn/train/tasks/regression.py +38 -21
  56. rslearn/train/tasks/segmentation.py +33 -15
  57. rslearn/train/tasks/task.py +3 -2
  58. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/METADATA +1 -1
  59. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/RECORD +64 -61
  60. rslearn/dataset/index.py +0 -173
  61. rslearn/models/registry.py +0 -22
  62. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/WHEEL +0 -0
  63. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/entry_points.txt +0 -0
  64. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/licenses/LICENSE +0 -0
  65. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/licenses/NOTICE +0 -0
  66. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/top_level.txt +0 -0
rslearn/dataset/window.py CHANGED
@@ -1,20 +1,16 @@
1
1
  """rslearn windows."""
2
2
 
3
- import json
4
3
  from datetime import datetime
5
- from typing import TYPE_CHECKING, Any
4
+ from typing import Any
6
5
 
7
6
  import shapely
8
7
  from upath import UPath
9
8
 
9
+ from rslearn.dataset.storage.storage import WindowStorage
10
10
  from rslearn.log_utils import get_logger
11
11
  from rslearn.utils import Projection, STGeometry
12
- from rslearn.utils.fsspec import open_atomic
13
12
  from rslearn.utils.raster_format import get_bandset_dirname
14
13
 
15
- if TYPE_CHECKING:
16
- from .index import DatasetIndex
17
-
18
14
  logger = get_logger(__name__)
19
15
 
20
16
  LAYERS_DIRECTORY_NAME = "layers"
@@ -138,14 +134,13 @@ class Window:
138
134
 
139
135
  def __init__(
140
136
  self,
141
- path: UPath,
137
+ storage: WindowStorage,
142
138
  group: str,
143
139
  name: str,
144
140
  projection: Projection,
145
141
  bounds: tuple[int, int, int, int],
146
142
  time_range: tuple[datetime, datetime] | None,
147
143
  options: dict[str, Any] = {},
148
- index: "DatasetIndex | None" = None,
149
144
  ) -> None:
150
145
  """Creates a new Window instance.
151
146
 
@@ -153,23 +148,21 @@ class Window:
153
148
  stored in metadata.json.
154
149
 
155
150
  Args:
156
- path: the directory of this window
151
+ storage: the dataset storage for the underlying rslearn dataset.
157
152
  group: the group the window belongs to
158
153
  name: the unique name for this window
159
154
  projection: the projection of the window
160
155
  bounds: the bounds of the window in pixel coordinates
161
156
  time_range: optional time range of the window
162
157
  options: additional options (?)
163
- index: DatasetIndex if it is available
164
158
  """
165
- self.path = path
159
+ self.storage = storage
166
160
  self.group = group
167
161
  self.name = name
168
162
  self.projection = projection
169
163
  self.bounds = bounds
170
164
  self.time_range = time_range
171
165
  self.options = options
172
- self.index = index
173
166
 
174
167
  def get_geometry(self) -> STGeometry:
175
168
  """Computes the STGeometry corresponding to this window."""
@@ -181,29 +174,11 @@ class Window:
181
174
 
182
175
  def load_layer_datas(self) -> dict[str, WindowLayerData]:
183
176
  """Load layer datas describing items in retrieved layers from items.json."""
184
- # Load from index if it is available.
185
- if self.index is not None:
186
- layer_datas = self.index.layer_datas.get(self.name, [])
187
-
188
- else:
189
- items_fname = self.path / "items.json"
190
- if not items_fname.exists():
191
- return {}
192
- with items_fname.open("r") as f:
193
- layer_datas = [
194
- WindowLayerData.deserialize(layer_data)
195
- for layer_data in json.load(f)
196
- ]
197
-
198
- return {layer_data.layer_name: layer_data for layer_data in layer_datas}
177
+ return self.storage.get_layer_datas(self.group, self.name)
199
178
 
200
179
  def save_layer_datas(self, layer_datas: dict[str, WindowLayerData]) -> None:
201
180
  """Save layer datas to items.json."""
202
- json_data = [layer_data.serialize() for layer_data in layer_datas.values()]
203
- items_fname = self.path / "items.json"
204
- logger.info(f"Saving window items to {items_fname}")
205
- with open_atomic(items_fname, "w") as f:
206
- json.dump(json_data, f)
181
+ self.storage.save_layer_datas(self.group, self.name, layer_datas)
207
182
 
208
183
  def list_completed_layers(self) -> list[tuple[str, int]]:
209
184
  """List the layers available for this window that are completed.
@@ -211,18 +186,7 @@ class Window:
211
186
  Returns:
212
187
  a list of (layer_name, group_idx) completed layers.
213
188
  """
214
- layers_directory = self.path / LAYERS_DIRECTORY_NAME
215
- if not layers_directory.exists():
216
- return []
217
-
218
- completed_layers = []
219
- for layer_dir in layers_directory.iterdir():
220
- layer_name, group_idx = get_layer_and_group_from_dir_name(layer_dir.name)
221
- if not self.is_layer_completed(layer_name, group_idx):
222
- continue
223
- completed_layers.append((layer_name, group_idx))
224
-
225
- return completed_layers
189
+ return self.storage.list_completed_layers(self.group, self.name)
226
190
 
227
191
  def get_layer_dir(self, layer_name: str, group_idx: int = 0) -> UPath:
228
192
  """Get the directory containing materialized data for the specified layer.
@@ -235,7 +199,9 @@ class Window:
235
199
  Returns:
236
200
  the path where data is or should be materialized.
237
201
  """
238
- return get_window_layer_dir(self.path, layer_name, group_idx)
202
+ return get_window_layer_dir(
203
+ self.storage.get_window_root(self.group, self.name), layer_name, group_idx
204
+ )
239
205
 
240
206
  def is_layer_completed(self, layer_name: str, group_idx: int = 0) -> bool:
241
207
  """Check whether the specified layer is completed.
@@ -250,14 +216,9 @@ class Window:
250
216
  Returns:
251
217
  whether the layer is completed
252
218
  """
253
- # Use the index to speed up the completed check if it is available.
254
- if self.index is not None:
255
- return (layer_name, group_idx) in self.index.completed_layers.get(
256
- self.name, []
257
- )
258
-
259
- layer_dir = self.get_layer_dir(layer_name, group_idx)
260
- return (layer_dir / "completed").exists()
219
+ return self.storage.is_layer_completed(
220
+ self.group, self.name, layer_name, group_idx
221
+ )
261
222
 
262
223
  def mark_layer_completed(self, layer_name: str, group_idx: int = 0) -> None:
263
224
  """Mark the specified layer completed.
@@ -272,8 +233,7 @@ class Window:
272
233
  layer_name: the layer name.
273
234
  group_idx: the index of the group within the layer.
274
235
  """
275
- layer_dir = self.get_layer_dir(layer_name, group_idx)
276
- (layer_dir / "completed").touch()
236
+ self.storage.mark_layer_completed(self.group, self.name, layer_name, group_idx)
277
237
 
278
238
  def get_raster_dir(
279
239
  self, layer_name: str, bands: list[str], group_idx: int = 0
@@ -289,7 +249,12 @@ class Window:
289
249
  Returns:
290
250
  the directory containing the raster.
291
251
  """
292
- return get_window_raster_dir(self.path, layer_name, bands, group_idx)
252
+ return get_window_raster_dir(
253
+ self.storage.get_window_root(self.group, self.name),
254
+ layer_name,
255
+ bands,
256
+ group_idx,
257
+ )
293
258
 
294
259
  def get_metadata(self) -> dict[str, Any]:
295
260
  """Returns the window metadata dictionary."""
@@ -308,18 +273,14 @@ class Window:
308
273
 
309
274
  def save(self) -> None:
310
275
  """Save the window metadata to its root directory."""
311
- self.path.mkdir(parents=True, exist_ok=True)
312
- metadata_path = self.path / "metadata.json"
313
- logger.debug(f"Saving window metadata to {metadata_path}")
314
- with open_atomic(metadata_path, "w") as f:
315
- json.dump(self.get_metadata(), f)
276
+ self.storage.create_or_update_window(self)
316
277
 
317
278
  @staticmethod
318
- def from_metadata(path: UPath, metadata: dict[str, Any]) -> "Window":
319
- """Create a Window from its path and metadata dictionary.
279
+ def from_metadata(storage: WindowStorage, metadata: dict[str, Any]) -> "Window":
280
+ """Create a Window from the WindowStorage and the window's metadata dictionary.
320
281
 
321
282
  Args:
322
- path: the root directory of the window.
283
+ storage: the WindowStorage for the underlying dataset.
323
284
  metadata: the window metadata.
324
285
 
325
286
  Returns:
@@ -334,7 +295,7 @@ class Window:
334
295
  )
335
296
 
336
297
  return Window(
337
- path=path,
298
+ storage=storage,
338
299
  group=metadata["group"],
339
300
  name=metadata["name"],
340
301
  projection=Projection.deserialize(metadata["projection"]),
@@ -350,21 +311,6 @@ class Window:
350
311
  options=metadata["options"],
351
312
  )
352
313
 
353
- @staticmethod
354
- def load(path: UPath) -> "Window":
355
- """Load a Window from a UPath.
356
-
357
- Args:
358
- path: the root directory of the window
359
-
360
- Returns:
361
- the Window
362
- """
363
- metadata_fname = path / "metadata.json"
364
- with metadata_fname.open("r") as f:
365
- metadata = json.load(f)
366
- return Window.from_metadata(path, metadata)
367
-
368
314
  @staticmethod
369
315
  def get_window_root(ds_path: UPath, group: str, name: str) -> UPath:
370
316
  """Gets the root directory of a window.
rslearn/main.py CHANGED
@@ -27,13 +27,13 @@ from rslearn.dataset.handler_summaries import (
27
27
  PrepareDatasetWindowsSummary,
28
28
  UnknownIngestCounts,
29
29
  )
30
- from rslearn.dataset.index import DatasetIndex
31
30
  from rslearn.dataset.manage import (
32
31
  AttemptsCounter,
33
32
  materialize_dataset_windows,
34
33
  prepare_dataset_windows,
35
34
  retry,
36
35
  )
36
+ from rslearn.dataset.storage.file import FileWindowStorage
37
37
  from rslearn.log_utils import get_logger
38
38
  from rslearn.tile_stores import get_tile_store_with_layer
39
39
  from rslearn.utils import Projection, STGeometry
@@ -315,7 +315,8 @@ def apply_on_windows(
315
315
  load_workers: optional different number of workers to use for loading the
316
316
  windows. If set, workers controls the number of workers to process the
317
317
  jobs, while load_workers controls the number of workers to use for reading
318
- windows from the rslearn dataset.
318
+ windows from the rslearn dataset. Workers is only passed if the window
319
+ storage is FileWindowStorage.
319
320
  batch_size: if workers > 0, the maximum number of windows to pass to the
320
321
  function.
321
322
  jobs_per_process: optional, terminate processes after they have handled this
@@ -336,11 +337,14 @@ def apply_on_windows(
336
337
  else:
337
338
  groups = group
338
339
 
339
- if load_workers is None:
340
- load_workers = workers
341
- windows = dataset.load_windows(
342
- groups=groups, names=names, workers=load_workers, show_progress=True
343
- )
340
+ # Load the windows. We pass workers and show_progress if it is FileWindowStorage.
341
+ kwargs: dict[str, Any] = {}
342
+ if isinstance(dataset.storage, FileWindowStorage):
343
+ if load_workers is None:
344
+ load_workers = workers
345
+ kwargs["workers"] = load_workers
346
+ kwargs["show_progress"] = True
347
+ windows = dataset.load_windows(groups=groups, names=names, **kwargs)
344
348
  logger.info(f"found {len(windows)} windows")
345
349
 
346
350
  if hasattr(f, "get_jobs"):
@@ -798,35 +802,6 @@ def dataset_materialize() -> None:
798
802
  apply_on_windows_args(fn, args)
799
803
 
800
804
 
801
- @register_handler("dataset", "build_index")
802
- def dataset_build_index() -> None:
803
- """Handler for the rslearn dataset build_index command."""
804
- parser = argparse.ArgumentParser(
805
- prog="rslearn dataset build_index",
806
- description=("rslearn dataset build_index: " + "create a dataset index file"),
807
- )
808
- parser.add_argument(
809
- "--root",
810
- type=str,
811
- required=True,
812
- help="Dataset path",
813
- )
814
- parser.add_argument(
815
- "--workers",
816
- type=int,
817
- default=16,
818
- help="Number of workers",
819
- )
820
- args = parser.parse_args(args=sys.argv[3:])
821
- ds_path = UPath(args.root)
822
- dataset = Dataset(ds_path)
823
- index = DatasetIndex.build_index(
824
- dataset=dataset,
825
- workers=args.workers,
826
- )
827
- index.save_index(ds_path)
828
-
829
-
830
805
  @register_handler("model", "fit")
831
806
  def model_fit() -> None:
832
807
  """Handler for rslearn model fit."""
rslearn/models/anysat.py CHANGED
@@ -4,11 +4,13 @@ This code loads the AnySat model from torch hub. See
4
4
  https://github.com/gastruc/AnySat for applicable license and copyright information.
5
5
  """
6
6
 
7
- from typing import Any
8
-
9
7
  import torch
10
8
  from einops import rearrange
11
9
 
10
+ from rslearn.train.model_context import ModelContext
11
+
12
+ from .component import FeatureExtractor, FeatureMaps
13
+
12
14
  # AnySat github: https://github.com/gastruc/AnySat
13
15
  # Modalities and expected resolutions (meters)
14
16
  MODALITY_RESOLUTIONS: dict[str, float] = {
@@ -44,7 +46,7 @@ MODALITY_BANDS: dict[str, list[str]] = {
44
46
  TIME_SERIES_MODALITIES = {"s2", "s1-asc", "s1", "alos", "l7", "l8", "modis"}
45
47
 
46
48
 
47
- class AnySat(torch.nn.Module):
49
+ class AnySat(FeatureExtractor):
48
50
  """AnySat backbone (outputs one feature map)."""
49
51
 
50
52
  def __init__(
@@ -117,17 +119,17 @@ class AnySat(torch.nn.Module):
117
119
  )
118
120
  self._embed_dim = 768 # base width, 'dense' returns 2x
119
121
 
120
- def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
122
+ def forward(self, context: ModelContext) -> FeatureMaps:
121
123
  """Forward pass for the AnySat model.
122
124
 
123
125
  Args:
124
- inputs: input dicts that must include modalities as keys which are defined in the self.modalities list
126
+ context: the model context. Input dicts must include modalities as keys
127
+ which are defined in the self.modalities list
125
128
 
126
129
  Returns:
127
- List[torch.Tensor]: Single-scale feature tensors from the encoder.
130
+ a FeatureMaps with one feature map at the configured patch size.
128
131
  """
129
- if not inputs:
130
- raise ValueError("empty inputs")
132
+ inputs = context.inputs
131
133
 
132
134
  batch: dict[str, torch.Tensor] = {}
133
135
  spatial_extent: tuple[float, float] | None = None
@@ -192,7 +194,7 @@ class AnySat(torch.nn.Module):
192
194
  kwargs["output_modality"] = self.output_modality
193
195
 
194
196
  features = self.model(batch, **kwargs)
195
- return [rearrange(features, "b h w d -> b d h w")]
197
+ return FeatureMaps([rearrange(features, "b h w d -> b d h w")])
196
198
 
197
199
  def get_backbone_channels(self) -> list:
198
200
  """Returns the output channels of this model when used as a backbone.
@@ -16,6 +16,8 @@ from huggingface_hub import hf_hub_download
16
16
  # from claymodel.module import ClayMAEModule
17
17
  from terratorch.models.backbones.clay_v15.module import ClayMAEModule
18
18
 
19
+ from rslearn.models.component import FeatureExtractor, FeatureMaps
20
+ from rslearn.train.model_context import ModelContext
19
21
  from rslearn.train.transforms.normalize import Normalize
20
22
  from rslearn.train.transforms.transform import Transform
21
23
 
@@ -42,7 +44,7 @@ def get_clay_checkpoint_path(
42
44
  return hf_hub_download(repo_id=repo_id, filename=filename) # nosec B615
43
45
 
44
46
 
45
- class Clay(torch.nn.Module):
47
+ class Clay(FeatureExtractor):
46
48
  """Clay backbones."""
47
49
 
48
50
  def __init__(
@@ -108,23 +110,20 @@ class Clay(torch.nn.Module):
108
110
  image, size=(new_hw, new_hw), mode="bilinear", align_corners=False
109
111
  )
110
112
 
111
- def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
113
+ def forward(self, context: ModelContext) -> FeatureMaps:
112
114
  """Forward pass for the Clay model.
113
115
 
114
116
  Args:
115
- inputs: input dicts that must include `self.modality` as a key
117
+ context: the model context. Input dicts must include `self.modality` as a key
116
118
 
117
119
  Returns:
118
- List[torch.Tensor]: Single-scale feature tensors from the encoder.
120
+ a FeatureMaps consisting of one feature map, computed by Clay.
119
121
  """
120
- if self.modality not in inputs[0]:
121
- raise ValueError(f"Missing modality {self.modality} in inputs.")
122
-
123
122
  param = next(self.model.parameters())
124
123
  device = param.device
125
124
 
126
125
  chips = torch.stack(
127
- [inp[self.modality] for inp in inputs], dim=0
126
+ [inp[self.modality] for inp in context.inputs], dim=0
128
127
  ) # (B, C, H, W)
129
128
  if self.do_resizing:
130
129
  chips = self._resize_image(chips, chips.shape[2])
@@ -163,7 +162,7 @@ class Clay(torch.nn.Module):
163
162
  )
164
163
 
165
164
  features = rearrange(spatial, "b (h w) d -> b d h w", h=side, w=side)
166
- return [features]
165
+ return FeatureMaps([features])
167
166
 
168
167
  def get_backbone_channels(self) -> list:
169
168
  """Return output channels of this model when used as a backbone."""
rslearn/models/clip.py CHANGED
@@ -1,12 +1,13 @@
1
1
  """OpenAI CLIP models."""
2
2
 
3
- from typing import Any
4
-
5
- import torch
6
3
  from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
7
4
 
5
+ from rslearn.train.model_context import ModelContext
6
+
7
+ from .component import FeatureExtractor, FeatureMaps
8
+
8
9
 
9
- class CLIP(torch.nn.Module):
10
+ class CLIP(FeatureExtractor):
10
11
  """CLIP image encoder."""
11
12
 
12
13
  def __init__(
@@ -31,17 +32,17 @@ class CLIP(torch.nn.Module):
31
32
  self.height = crop_size["height"] // stride[0]
32
33
  self.width = crop_size["width"] // stride[1]
33
34
 
34
- def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
35
+ def forward(self, context: ModelContext) -> FeatureMaps:
35
36
  """Compute outputs from the backbone.
36
37
 
37
- Inputs:
38
- inputs: input dicts that must include "image" key containing the image to
39
- process. The images should have values 0-255.
38
+ Args:
39
+ context: the model context. Input dicts must include "image" key containing
40
+ the image to process. The images should have values 0-255.
40
41
 
41
42
  Returns:
42
- list of feature maps. The ViT produces features at one scale, so the list
43
- contains a single Bx24x24x1024 feature map.
43
+ a FeatureMaps with one feature map from the ViT, which is always Bx24x24x1024.
44
44
  """
45
+ inputs = context.inputs
45
46
  device = inputs[0]["image"].device
46
47
  clip_inputs = self.processor(
47
48
  images=[inp["image"].cpu().numpy().transpose(1, 2, 0) for inp in inputs],
@@ -55,8 +56,10 @@ class CLIP(torch.nn.Module):
55
56
  batch_size = image_features.shape[0]
56
57
 
57
58
  # 576x1024 -> HxWxC
58
- return [
59
- image_features.reshape(
60
- batch_size, self.height, self.width, self.num_features
61
- ).permute(0, 3, 1, 2)
62
- ]
59
+ return FeatureMaps(
60
+ [
61
+ image_features.reshape(
62
+ batch_size, self.height, self.width, self.num_features
63
+ ).permute(0, 3, 1, 2)
64
+ ]
65
+ )
@@ -0,0 +1,99 @@
1
+ """Model component API."""
2
+
3
+ import abc
4
+ from dataclasses import dataclass
5
+ from typing import Any
6
+
7
+ import torch
8
+
9
+ from rslearn.train.model_context import ModelContext, ModelOutput
10
+
11
+
12
class FeatureExtractor(torch.nn.Module, abc.ABC):
    """First stage of a model: turns the raw model context into an intermediate.

    In SingleTaskModel and MultiTaskModel, the first entry of the encoders list
    must be a FeatureExtractor; its output is fed to the remaining components.
    """

    @abc.abstractmethod
    def forward(self, context: ModelContext) -> Any:
        """Compute the initial intermediate from the model context.

        Args:
            context: the model context.

        Returns:
            an arbitrary intermediate consumed by downstream components,
            commonly a FeatureMaps.
        """
        raise NotImplementedError
31
+
32
+
33
class IntermediateComponent(torch.nn.Module, abc.ABC):
    """A middle stage of the model pipeline.

    In SingleTaskModel and MultiTaskModel, every encoder after the first, and
    every decoder except the last, is an IntermediateComponent.
    """

    @abc.abstractmethod
    def forward(self, intermediates: Any, context: ModelContext) -> Any:
        """Transform the previous component's intermediate into a new one.

        Args:
            intermediates: output of the preceding component (a
                FeatureExtractor or another IntermediateComponent).
            context: the model context.

        Returns:
            an arbitrary intermediate consumed by downstream components.
        """
        raise NotImplementedError
54
+
55
+
56
class Predictor(torch.nn.Module, abc.ABC):
    """Final stage producing task-specific outputs together with a loss dict.

    In SingleTaskModel and MultiTaskModel, the last module of each decoders
    list is a Predictor.
    """

    @abc.abstractmethod
    def forward(
        self,
        intermediates: Any,
        context: ModelContext,
        targets: list[dict[str, torch.Tensor]] | None = None,
    ) -> ModelOutput:
        """Produce the task-specific outputs and the corresponding losses.

        Args:
            intermediates: output of the preceding component.
            context: the model context.
            targets: training targets, or None at prediction time.

        Returns:
            the task-specific outputs (compatible with the configured Task)
            paired with a loss dict mapping each loss name to a scalar tensor.
        """
        raise NotImplementedError
83
+
84
+
85
@dataclass
class FeatureMaps:
    """Intermediate carrying multi-resolution feature maps."""

    # BxCxHxW tensors, ordered from the finest (highest-resolution) scale down
    # to the coarsest (lowest-resolution) scale.
    feature_maps: list[torch.Tensor]
92
+
93
+
94
@dataclass
class FeatureVector:
    """Intermediate carrying a single flat feature vector."""

    # Flat BxC feature tensor.
    feature_vector: torch.Tensor
@@ -4,8 +4,12 @@ from typing import Any
4
4
 
5
5
  import torch
6
6
 
7
+ from rslearn.train.model_context import ModelContext
7
8
 
8
- class ConcatenateFeatures(torch.nn.Module):
9
+ from .component import FeatureMaps, IntermediateComponent
10
+
11
+
12
+ class ConcatenateFeatures(IntermediateComponent):
9
13
  """Concatenate feature map with additional raw data inputs."""
10
14
 
11
15
  def __init__(
@@ -55,26 +59,32 @@ class ConcatenateFeatures(torch.nn.Module):
55
59
 
56
60
  self.conv_layers = torch.nn.Sequential(*conv_layers)
57
61
 
58
- def forward(
59
- self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
60
- ) -> list[torch.Tensor]:
62
+ def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
61
63
  """Concatenate the feature map with the raw data inputs.
62
64
 
63
65
  Args:
64
- features: list of feature maps at different resolutions.
65
- inputs: original inputs.
66
+ intermediates: the previous output, which must be a FeatureMaps.
67
+ context: the model context. The input dicts must have a key matching the
68
+ configured key.
66
69
 
67
70
  Returns:
68
71
  concatenated feature maps.
69
72
  """
70
- if not features:
71
- raise ValueError("Expected at least one feature map, got none.")
73
+ if (
74
+ not isinstance(intermediates, FeatureMaps)
75
+ or len(intermediates.feature_maps) == 0
76
+ ):
77
+ raise ValueError(
78
+ "Expected input to be FeatureMaps with at least one feature map"
79
+ )
72
80
 
73
- add_data = torch.stack([input_data[self.key] for input_data in inputs], dim=0)
81
+ add_data = torch.stack(
82
+ [input_data[self.key] for input_data in context.inputs], dim=0
83
+ )
74
84
  add_features = self.conv_layers(add_data)
75
85
 
76
86
  new_features: list[torch.Tensor] = []
77
- for feature_map in features:
87
+ for feature_map in intermediates.feature_maps:
78
88
  # Shape of feature map: BCHW
79
89
  feat_h, feat_w = feature_map.shape[2], feature_map.shape[3]
80
90
 
@@ -90,4 +100,4 @@ class ConcatenateFeatures(torch.nn.Module):
90
100
 
91
101
  new_features.append(torch.cat([feature_map, resized_add_features], dim=1))
92
102
 
93
- return new_features
103
+ return FeatureMaps(new_features)
rslearn/models/conv.py CHANGED
@@ -4,8 +4,12 @@ from typing import Any
4
4
 
5
5
  import torch
6
6
 
7
+ from rslearn.train.model_context import ModelContext
7
8
 
8
- class Conv(torch.nn.Module):
9
+ from .component import FeatureMaps, IntermediateComponent
10
+
11
+
12
+ class Conv(IntermediateComponent):
9
13
  """A single convolutional layer.
10
14
 
11
15
  It inputs a set of feature maps; the conv layer is applied to each feature map
@@ -38,19 +42,22 @@ class Conv(torch.nn.Module):
38
42
  )
39
43
  self.activation = activation
40
44
 
41
- def forward(self, features: list[torch.Tensor], inputs: Any) -> list[torch.Tensor]:
42
- """Compute flat output vector from multi-scale feature map.
45
+ def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
46
+ """Apply conv layer on each feature map.
43
47
 
44
48
  Args:
45
- features: list of feature maps at different resolutions.
46
- inputs: original inputs (ignored).
49
+ intermediates: the previous output, which must be a FeatureMaps.
50
+ context: the model context.
47
51
 
48
52
  Returns:
49
- flat feature vector
53
+ the resulting feature maps after applying the same Conv2d on each one.
50
54
  """
55
+ if not isinstance(intermediates, FeatureMaps):
56
+ raise ValueError("input to Conv must be FeatureMaps")
57
+
51
58
  new_features = []
52
- for feat_map in features:
59
+ for feat_map in intermediates.feature_maps:
53
60
  feat_map = self.layer(feat_map)
54
61
  feat_map = self.activation(feat_map)
55
62
  new_features.append(feat_map)
56
- return new_features
63
+ return FeatureMaps(new_features)