rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/resize_features.py
@@ -0,0 +1,59 @@
+ """The ResizeFeatures module."""
+
+ from typing import Any
+
+ import torch
+
+ from rslearn.train.model_context import ModelContext
+
+ from .component import (
+     FeatureMaps,
+     IntermediateComponent,
+ )
+
+
+ class ResizeFeatures(IntermediateComponent):
+     """Resize input features to new sizes."""
+
+     def __init__(
+         self,
+         out_sizes: list[tuple[int, int]],
+         mode: str = "bilinear",
+     ):
+         """Initialize a ResizeFeatures.
+
+         Args:
+             out_sizes: the output sizes of the feature maps. There must be one entry
+                 for each input feature map.
+             mode: mode to pass to torch.nn.Upsample, e.g. "bilinear" (default) or
+                 "nearest".
+         """
+         super().__init__()
+         layers = []
+         for size in out_sizes:
+             layers.append(
+                 torch.nn.Upsample(
+                     size=size,
+                     mode=mode,
+                 )
+             )
+         self.layers = torch.nn.ModuleList(layers)
+
+     def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+         """Resize the input feature maps to new sizes.
+
+         Args:
+             intermediates: the outputs from the previous component, which must be a FeatureMaps.
+             context: the model context.
+
+         Returns:
+             resized feature maps
+         """
+         if not isinstance(intermediates, FeatureMaps):
+             raise ValueError("input to ResizeFeatures must be a FeatureMaps")
+
+         feat_maps = intermediates.feature_maps
+         resized_feat_maps = [
+             self.layers[idx](feat_map) for idx, feat_map in enumerate(feat_maps)
+         ]
+         return FeatureMaps(resized_feat_maps)
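
A minimal usage sketch of the new ResizeFeatures component (tensor shapes and sizes here are illustrative, not taken from the diff; FeatureMaps is the container class imported above from rslearn.models.component):

import torch

from rslearn.models.component import FeatureMaps
from rslearn.models.resize_features import ResizeFeatures

# Resize two backbone feature maps to fixed spatial sizes so a later decoder
# sees consistent shapes regardless of input resolution.
resize = ResizeFeatures(out_sizes=[(128, 128), (64, 64)], mode="bilinear")
feats = FeatureMaps([torch.randn(2, 96, 112, 112), torch.randn(2, 192, 56, 56)])

# forward() does not read the context, so None stands in here; in a real pipeline
# the ModelContext flowing through the encoder/decoder chain is passed instead.
resized = resize(feats, None)
print([fm.shape for fm in resized.feature_maps])  # (2, 96, 128, 128), (2, 192, 64, 64)
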
rslearn/models/sam2_enc.py
@@ -1,14 +1,15 @@
  """SegmentAnything2 encoders."""
 
- from typing import Any
-
  import torch
- import torch.nn as nn
  from sam2.build_sam import build_sam2
  from upath import UPath
 
+ from rslearn.train.model_context import ModelContext
+
+ from .component import FeatureExtractor, FeatureMaps
+
 
- class SAM2Encoder(nn.Module):
+ class SAM2Encoder(FeatureExtractor):
      """SAM2's image encoder."""
 
      def __init__(self, model_identifier: str) -> None:
@@ -84,18 +85,21 @@ class SAM2Encoder(nn.Module):
          del self.model.obj_ptr_proj
          del self.model.image_encoder.neck
 
-     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+     def forward(self, context: ModelContext) -> FeatureMaps:
          """Extract multi-scale features from a batch of images.
 
          Args:
-             inputs: List of dictionaries, each containing the input image under the key 'image'.
+             context: the model context. Input dicts must have a key 'image' containing
+                 the input for the SAM2 image encoder.
 
          Returns:
-             List[torch.Tensor]: Multi-scale feature tensors from the encoder.
+             feature maps from the encoder.
          """
-         images = torch.stack([inp["image"] for inp in inputs], dim=0)
+         images = torch.stack(
+             [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+         )
          features = self.encoder(images)
-         return features
+         return FeatureMaps(features)
 
      def get_backbone_channels(self) -> list[list[int]]:
          """Returns the output channels of the encoder at different scales.
rslearn/models/satlaspretrain.py
@@ -1,19 +1,20 @@
  """SatlasPretrain models."""
 
- from typing import Any
-
  import satlaspretrain_models
  import torch
+ import torch.nn.functional as F
+
+ from rslearn.train.model_context import ModelContext
 
+ from .component import FeatureExtractor, FeatureMaps
 
- class SatlasPretrain(torch.nn.Module):
+
+ class SatlasPretrain(FeatureExtractor):
      """SatlasPretrain backbones."""
 
      def __init__(
-         self,
-         model_identifier: str,
-         fpn: bool = False,
-     ):
+         self, model_identifier: str, fpn: bool = False, resize_to_pretrain: bool = False
+     ) -> None:
          """Instantiate a new SatlasPretrain instance.
 
          Args:
@@ -21,11 +22,13 @@ class SatlasPretrain(torch.nn.Module):
                  https://github.com/allenai/satlaspretrain_models
              fpn: whether to include the feature pyramid network, otherwise only the
                  Swin-v2-Transformer is used.
+             resize_to_pretrain: whether to resize inputs to the pretraining input
+                 size (512 x 512)
          """
          super().__init__()
          weights_manager = satlaspretrain_models.Weights()
          self.model = weights_manager.get_pretrained_model(
-             model_identifier=model_identifier, fpn=fpn
+             model_identifier=model_identifier, fpn=fpn, device="cpu"
          )
 
          if "SwinB" in model_identifier:
@@ -49,21 +52,38 @@ class SatlasPretrain(torch.nn.Module):
              [16, 1024],
              [32, 2048],
          ]
+         self.resize_to_pretrain = resize_to_pretrain
+
+     def maybe_resize(self, data: torch.Tensor) -> list[torch.Tensor]:
+         """Resize to pretraining sizes if resize_to_pretrain == True."""
+         if self.resize_to_pretrain:
+             return F.interpolate(
+                 data,
+                 size=(512, 512),
+                 mode="bilinear",
+                 align_corners=False,
+             )
+         else:
+             return data
 
-     def forward(
-         self, inputs: list[dict[str, Any]], targets: list[dict[str, Any]] = None
-     ):
+     def forward(self, context: ModelContext) -> FeatureMaps:
          """Compute feature maps from the SatlasPretrain backbone.
 
-         Inputs:
-             inputs: input dicts that must include "image" key containing the image to
-                 process.
-             targets: target dicts that are ignored
+         Args:
+             context: the model context. Input dicts must contain an "image" key
+                 containing the image input to the model.
+
+         Returns:
+             multi-resolution feature maps computed by the model.
          """
-         images = torch.stack([inp["image"] for inp in inputs], dim=0)
-         return self.model(images)
+         # take the first (assumed to be only) timestep
+         images = torch.stack(
+             [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+         )
+         feature_maps = self.model(self.maybe_resize(images))
+         return FeatureMaps(feature_maps)
 
-     def get_backbone_channels(self):
+     def get_backbone_channels(self) -> list:
          """Returns the output channels of this model when used as a backbone.
 
          The output channels is a list of (downsample_factor, depth) that corresponds
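
A sketch of the new resize_to_pretrain behavior (the model identifier is an illustrative satlaspretrain_models ID, and the RasterImage/ModelContext wiring follows the same assumed pattern as the SAM2 example above):

import torch

from rslearn.models.satlaspretrain import SatlasPretrain
from rslearn.train.model_context import ModelContext, RasterImage

# With resize_to_pretrain=True, inputs are bilinearly interpolated to 512 x 512
# (the SatlasPretrain training resolution) before the Swin backbone runs.
model = SatlasPretrain(
    model_identifier="Sentinel2_SwinB_SI_RGB",  # illustrative identifier
    fpn=False,
    resize_to_pretrain=True,
)

inputs = [{"image": RasterImage(image=torch.randn(3, 1, 128, 128), timestamps=None)}]
feats = model(ModelContext(inputs=inputs, metadatas=[{}]))
print([fm.shape for fm in feats.feature_maps])
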
rslearn/models/simple_time_series.py
@@ -3,12 +3,17 @@
  from typing import Any
 
  import torch
+ from einops import rearrange
 
+ from rslearn.train.model_context import ModelContext, RasterImage
 
- class SimpleTimeSeries(torch.nn.Module):
-     """SimpleTimeSeries wraps another encoder and applies it on an image time series.
+ from .component import FeatureExtractor, FeatureMaps
 
-     It independently applies the other encoder on each image in the time series to
+
+ class SimpleTimeSeries(FeatureExtractor):
+     """SimpleTimeSeries wraps another FeatureExtractor and applies it on an image time series.
+
+     It independently applies the other FeatureExtractor on each image in the time series to
      extract feature maps. It then provides a few ways to combine the features into one
      final feature map:
      - Temporal max pooling.
@@ -19,17 +24,21 @@ class SimpleTimeSeries(torch.nn.Module):
 
      def __init__(
          self,
-         encoder: torch.nn.Module,
-         image_channels: int,
+         encoder: FeatureExtractor,
+         image_channels: int | None = None,
          op: str = "max",
         groups: list[list[int]] | None = None,
         num_layers: int | None = None,
-     ):
+         image_key: str = "image",
+         backbone_channels: list[tuple[int, int]] | None = None,
+         image_keys: dict[str, int] | None = None,
+     ) -> None:
          """Create a new SimpleTimeSeries.
 
          Args:
-             encoder: the underlying encoder. It must provide get_backbone_channels
-                 function that returns the output channels.
+             encoder: the underlying FeatureExtractor. It must provide get_backbone_channels
+                 function that returns the output channels, or backbone_channels must be set.
+                 It must output a FeatureMaps.
              image_channels: the number of channels per image of the time series. The
                  input should have multiple images concatenated on the channel axis, so
                  this parameter is used to distinguish the different images.
@@ -42,76 +51,101 @@ class SimpleTimeSeries(torch.nn.Module):
                  combined before features and the combined after features. groups is a
                  list of sets, and each set is a list of image indices.
              num_layers: the number of layers for convrnn, conv3d, and conv1d ops.
+             image_key: the key to access the images.
+             backbone_channels: manually specify the backbone channels. Can be set if
+                 the encoder does not provide get_backbone_channels function.
+             image_keys: as an alternative to setting image_channels, map from the key
+                 in input dict to the number of channels per timestep for that modality.
+                 This way SimpleTimeSeries can be used with multimodal inputs. One of
+                 image_channels or image_keys must be specified.
          """
+         if (image_channels is None and image_keys is None) or (
+             image_channels is not None and image_keys is not None
+         ):
+             raise ValueError(
+                 "exactly one of image_channels and image_keys must be specified"
+             )
+
          super().__init__()
          self.encoder = encoder
          self.image_channels = image_channels
          self.op = op
          self.groups = groups
+         self.image_key = image_key
+         self.image_keys = image_keys
 
-         out_channels = self.encoder.get_backbone_channels()
+         if backbone_channels is not None:
+             out_channels = backbone_channels
+         else:
+             out_channels = self.encoder.get_backbone_channels()
          if self.groups:
              self.num_groups = len(self.groups)
          else:
              self.num_groups = 1
 
-         if self.op == "convrnn":
-             rnn_kernel_size = 3
-             self.rnn_layers = []
-             for _, count in out_channels:
-                 cur_layer = [
-                     torch.nn.Sequential(
-                         torch.nn.Conv2d(
-                             2 * count, count, rnn_kernel_size, padding="same"
-                         ),
-                         torch.nn.ReLU(inplace=True),
-                     )
-                 ]
-                 for _ in range(num_layers - 1):
-                     cur_layer.append(
+         if self.op in ["convrnn", "conv3d", "conv1d"]:
+             if num_layers is None:
+                 raise ValueError(f"num_layers must be specified for {self.op} op")
+
+             if self.op == "convrnn":
+                 rnn_kernel_size = 3
+                 rnn_layers = []
+                 for _, count in out_channels:
+                     cur_layer = [
                          torch.nn.Sequential(
                              torch.nn.Conv2d(
-                                 count, count, rnn_kernel_size, padding="same"
+                                 2 * count, count, rnn_kernel_size, padding="same"
                              ),
                              torch.nn.ReLU(inplace=True),
                          )
-                     )
-                 cur_layer = torch.nn.Sequential(*cur_layer)
-                 self.rnn_layers.append(cur_layer)
-             self.rnn_layers = torch.nn.ModuleList(self.rnn_layers)
-
-         elif self.op == "conv3d":
-             self.conv3d_layers = []
-             for _, count in out_channels:
-                 cur_layer = [
-                     torch.nn.Sequential(
-                         torch.nn.Conv3d(count, count, 3, padding=1, stride=(2, 1, 1)),
-                         torch.nn.ReLU(inplace=True),
-                     )
-                     for _ in range(num_layers)
-                 ]
-                 cur_layer = torch.nn.Sequential(*cur_layer)
-                 self.conv3d_layers.append(cur_layer)
-             self.conv3d_layers = torch.nn.ModuleList(self.conv3d_layers)
-
-         elif self.op == "conv1d":
-             self.conv1d_layers = []
-             for _, count in out_channels:
-                 cur_layer = [
-                     torch.nn.Sequential(
-                         torch.nn.Conv1d(count, count, 3, padding=1, stride=2),
-                         torch.nn.ReLU(inplace=True),
-                     )
-                     for _ in range(num_layers)
-                 ]
-                 cur_layer = torch.nn.Sequential(*cur_layer)
-                 self.conv1d_layers.append(cur_layer)
-             self.conv1d_layers = torch.nn.ModuleList(self.conv1d_layers)
+                     ]
+                     for _ in range(num_layers - 1):
+                         cur_layer.append(
+                             torch.nn.Sequential(
+                                 torch.nn.Conv2d(
+                                     count, count, rnn_kernel_size, padding="same"
+                                 ),
+                                 torch.nn.ReLU(inplace=True),
+                             )
+                         )
+                     cur_layer = torch.nn.Sequential(*cur_layer)
+                     rnn_layers.append(cur_layer)
+                 self.rnn_layers = torch.nn.ModuleList(rnn_layers)
+
+             elif self.op == "conv3d":
+                 conv3d_layers = []
+                 for _, count in out_channels:
+                     cur_layer = [
+                         torch.nn.Sequential(
+                             torch.nn.Conv3d(
+                                 count, count, 3, padding=1, stride=(2, 1, 1)
+                             ),
+                             torch.nn.ReLU(inplace=True),
+                         )
+                         for _ in range(num_layers)
+                     ]
+                     cur_layer = torch.nn.Sequential(*cur_layer)
+                     conv3d_layers.append(cur_layer)
+                 self.conv3d_layers = torch.nn.ModuleList(conv3d_layers)
+
+             elif self.op == "conv1d":
+                 conv1d_layers = []
+                 for _, count in out_channels:
+                     cur_layer = [
+                         torch.nn.Sequential(
+                             torch.nn.Conv1d(count, count, 3, padding=1, stride=2),
+                             torch.nn.ReLU(inplace=True),
+                         )
+                         for _ in range(num_layers)
+                     ]
+                     cur_layer = torch.nn.Sequential(*cur_layer)
+                     conv1d_layers.append(cur_layer)
+                 self.conv1d_layers = torch.nn.ModuleList(conv1d_layers)
 
          else:
              assert self.op in ["max", "mean"]
 
-     def get_backbone_channels(self):
+     def get_backbone_channels(self) -> list:
          """Returns the output channels of this model when used as a backbone.
 
          The output channels is a list of (downsample_factor, depth) that corresponds
@@ -128,27 +162,105 @@ class SimpleTimeSeries(torch.nn.Module):
              out_channels.append((downsample_factor, depth * self.num_groups))
          return out_channels
 
+     def _get_batched_images(
+         self, input_dicts: list[dict[str, Any]], image_key: str, image_channels: int
+     ) -> list[RasterImage]:
+         """Collect and reshape images across input dicts.
+
+         The BTCHW image time series are reshaped to (B*T)CHW so they can be passed to
+         the forward pass of a per-image (unitemporal) model.
+         """
+         images = torch.stack(
+             [input_dict[image_key].image for input_dict in input_dicts], dim=0
+         )  # B, C, T, H, W
+         timestamps = [input_dict[image_key].timestamps for input_dict in input_dicts]
+         # if image channels is not equal to the actual number of channels, then
+         # then every N images should be batched together. For example, if the
+         # number of input channels c == 2, and image_channels == 4, then we
+         # want to pass 2 timesteps to the model.
+         # TODO is probably to make this behaviour clearer but lets leave it like
+         # this for now to not break things.
+         num_timesteps = images.shape[1] // image_channels
+         batched_timesteps = images.shape[2] // num_timesteps
+         images = rearrange(
+             images,
+             "b c (b_t k_t) h w -> (b b_t) c k_t h w",
+             b_t=batched_timesteps,
+             k_t=num_timesteps,
+         )
+         if timestamps[0] is None:
+             new_timestamps = [None] * images.shape[0]
+         else:
+             # we also need to split the timestamps
+             new_timestamps = []
+             for t in timestamps:
+                 for i in range(batched_timesteps):
+                     new_timestamps.append(
+                         t[i * num_timesteps : (i + 1) * num_timesteps]
+                     )
+         return [
+             RasterImage(image=image, timestamps=timestamps)
+             for image, timestamps in zip(images, new_timestamps)
+         ]  # C, T, H, W
+
      def forward(
-         self, inputs: list[dict[str, Any]], targets: list[dict[str, Any]] = None
-     ):
+         self,
+         context: ModelContext,
+     ) -> FeatureMaps:
          """Compute outputs from the backbone.
 
-         Inputs:
-             inputs: input dicts that must include "image" key containing the image time
+         Args:
+             context: the model context. Input dicts must include "image" key containing the image time
                  series to process (with images concatenated on the channel dimension).
-             targets: target dicts that are ignored unless
+
+         Returns:
+             the FeatureMaps aggregated temporally.
          """
          # First get features of each image.
          # To do so, we need to split up each grouped image into its component images (which have had their channels stacked).
-         images = torch.stack([inp["image"] for inp in inputs], dim=0)
-         n_batch = images.shape[0]
-         n_images = images.shape[1] // self.image_channels
-         n_height = images.shape[2]
-         n_width = images.shape[3]
-         batched_images = images.reshape(
-             n_batch * n_images, self.image_channels, n_height, n_width
+         batched_inputs: list[dict[str, Any]] | None = None
+         n_batch = len(context.inputs)
+         n_images: int | None = None
+
+         if self.image_keys is not None:
+             for image_key, image_channels in self.image_keys.items():
+                 batched_images = self._get_batched_images(
+                     context.inputs, image_key, image_channels
+                 )
+
+                 if batched_inputs is None:
+                     batched_inputs = [{} for _ in batched_images]
+                     n_images = len(batched_images) // n_batch
+                 elif n_images != len(batched_images) // n_batch:
+                     raise ValueError(
+                         "expected all modalities to have the same number of timesteps"
+                     )
+
+                 for i, image in enumerate(batched_images):
+                     batched_inputs[i][image_key] = image
+
+         else:
+             assert self.image_channels is not None
+             batched_images = self._get_batched_images(
+                 context.inputs, self.image_key, self.image_channels
+             )
+             batched_inputs = [{self.image_key: image} for image in batched_images]
+             n_images = len(batched_images) // n_batch
+
+         assert n_images is not None
+         # Now we can apply the underlying FeatureExtractor.
+         # Its output must be a FeatureMaps.
+         assert batched_inputs is not None
+         encoder_output = self.encoder(
+             ModelContext(
+                 inputs=batched_inputs,
+                 metadatas=context.metadatas,
+             )
          )
-         batched_inputs = [{"image": image} for image in batched_images]
+         if not isinstance(encoder_output, FeatureMaps):
+             raise ValueError(
+                 "output of underlying FeatureExtractor in SimpleTimeSeries must be a FeatureMaps"
+             )
          all_features = [
             feat_map.reshape(
                 n_batch,
@@ -157,9 +269,8 @@ class SimpleTimeSeries(torch.nn.Module):
                 feat_map.shape[2],
                 feat_map.shape[3],
             )
-             for feat_map in self.encoder(batched_inputs)
+             for feat_map in encoder_output.feature_maps
          ]
-
          # Groups defaults to flattening all the feature maps.
          groups = self.groups
          if not groups:
@@ -171,13 +282,13 @@ class SimpleTimeSeries(torch.nn.Module):
          for feature_idx in range(len(all_features)):
              aggregated_features = []
              for group in groups:
-                 group_features = []
+                 group_features_list = []
                  for image_idx in group:
-                     group_features.append(
+                     group_features_list.append(
                          all_features[feature_idx][:, image_idx, :, :, :]
                      )
                  # Resulting group features are (depth, batch, C, height, width).
-                 group_features = torch.stack(group_features, dim=0)
+                 group_features = torch.stack(group_features_list, dim=0)
 
                  if self.op == "max":
                      group_features = torch.amax(group_features, dim=0)
@@ -213,7 +324,7 @@ class SimpleTimeSeries(torch.nn.Module):
                      .permute(0, 3, 1, 2)
                  )
                  else:
-                     raise Exception(f"unknown aggregation op {self.op}")
+                     raise ValueError(f"unknown aggregation op {self.op}")
 
                  aggregated_features.append(group_features)
 
@@ -222,4 +333,4 @@ class SimpleTimeSeries(torch.nn.Module):
 
              output_features.append(aggregated_features)
 
-         return output_features
+         return FeatureMaps(output_features)
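
A configuration sketch for the updated wrapper (the wrapped backbone, channel counts, and backbone_channels values are illustrative):

from rslearn.models.satlaspretrain import SatlasPretrain
from rslearn.models.simple_time_series import SimpleTimeSeries

# Classic single-modality use: wrap a per-image backbone, declare how many channels
# each timestep contributes, and merge timesteps by temporal max pooling.
single_modality = SimpleTimeSeries(
    encoder=SatlasPretrain(model_identifier="Sentinel2_SwinB_SI_RGB"),
    image_channels=3,
    op="max",
)

# New in this version: image_keys handles multimodal inputs (one entry per input-dict
# key with its per-timestep channel count), and backbone_channels stands in for an
# encoder without get_backbone_channels(). Exactly one of image_channels / image_keys
# may be given, otherwise __init__ raises ValueError.
multimodal = SimpleTimeSeries(
    encoder=SatlasPretrain(model_identifier="Sentinel2_SwinB_SI_RGB"),  # stand-in encoder
    image_keys={"sentinel2": 3, "sentinel1": 2},  # illustrative channel counts
    backbone_channels=[(4, 128), (8, 256), (16, 1024), (32, 2048)],  # illustrative
    op="mean",
)
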
rslearn/models/singletask.py
@@ -4,6 +4,10 @@ from typing import Any
 
  import torch
 
+ from rslearn.train.model_context import ModelContext, ModelOutput
+
+ from .component import FeatureExtractor, IntermediateComponent, Predictor
+
 
  class SingleTaskModel(torch.nn.Module):
      """Standard model wrapper.
@@ -14,34 +18,41 @@ class SingleTaskModel(torch.nn.Module):
      outputs and targets from the last module (which also receives the targets).
      """
 
-     def __init__(self, encoder: list[torch.nn.Module], decoder: list[torch.nn.Module]):
+     def __init__(
+         self,
+         encoder: list[FeatureExtractor | IntermediateComponent],
+         decoder: list[IntermediateComponent | Predictor],
+     ):
          """Initialize a new SingleTaskModel.
 
          Args:
-             encoder: modules to compute intermediate feature representations.
-             decoder: modules to compute outputs and loss.
+             encoder: modules to compute intermediate feature representations. The first
+                 module must be a FeatureExtractor, and following modules must be
+                 IntermediateComponents.
+             decoder: modules to compute outputs and loss. The last module must be a
+                 Predictor, while the previous modules must be IntermediateComponents.
          """
          super().__init__()
-         self.encoder = torch.nn.Sequential(*encoder)
+         self.encoder = torch.nn.ModuleList(encoder)
          self.decoder = torch.nn.ModuleList(decoder)
 
      def forward(
          self,
-         inputs: list[dict[str, Any]],
+         context: ModelContext,
          targets: list[dict[str, Any]] | None = None,
-     ) -> tuple[list[Any], dict[str, torch.Tensor]]:
+     ) -> ModelOutput:
          """Apply the sequence of modules on the inputs.
 
          Args:
-             inputs: list of input dicts
+             context: the model context.
              targets: optional list of target dicts
 
          Returns:
-             tuple (outputs, loss_dict) from the last module.
+             the model output.
          """
-         features = self.encoder(inputs)
-         cur = features
+         cur = self.encoder[0](context)
+         for module in self.encoder[1:]:
+             cur = module(cur, context)
          for module in self.decoder[:-1]:
-             cur = module(cur, inputs)
-
-         return self.decoder[-1](cur, inputs, targets)
+             cur = module(cur, context)
+         return self.decoder[-1](cur, context, targets)
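
A wiring sketch of the new component protocol. The toy modules below only mimic the FeatureExtractor and Predictor call signatures (real components return FeatureMaps and ModelOutput objects), and the ModelContext construction is assumed to work with inputs and metadatas alone, as in the SimpleTimeSeries hunk above:

from typing import Any

import torch

from rslearn.models.singletask import SingleTaskModel
from rslearn.train.model_context import ModelContext


class ToyExtractor(torch.nn.Module):
    """Stands in for a FeatureExtractor: called as module(context)."""

    def forward(self, context: ModelContext) -> torch.Tensor:
        return torch.stack([inp["image"] for inp in context.inputs], dim=0)


class ToyHead(torch.nn.Module):
    """Stands in for a Predictor: called as module(features, context, targets)."""

    def forward(
        self, features: torch.Tensor, context: ModelContext, targets: Any = None
    ) -> dict[str, Any]:
        return {"outputs": features.mean(dim=(2, 3)), "loss": None}


model = SingleTaskModel(encoder=[ToyExtractor()], decoder=[ToyHead()])
context = ModelContext(inputs=[{"image": torch.randn(3, 32, 32)}], metadatas=[{}])
out = model(context)  # encoder[0](context), then each remaining module in turn
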