rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. rslearn/arg_parser.py +2 -9
  2. rslearn/config/__init__.py +2 -0
  3. rslearn/config/dataset.py +64 -20
  4. rslearn/dataset/add_windows.py +1 -1
  5. rslearn/dataset/dataset.py +34 -84
  6. rslearn/dataset/materialize.py +5 -5
  7. rslearn/dataset/storage/__init__.py +1 -0
  8. rslearn/dataset/storage/file.py +202 -0
  9. rslearn/dataset/storage/storage.py +140 -0
  10. rslearn/dataset/window.py +26 -80
  11. rslearn/lightning_cli.py +22 -11
  12. rslearn/main.py +12 -37
  13. rslearn/models/anysat.py +11 -9
  14. rslearn/models/attention_pooling.py +177 -0
  15. rslearn/models/clay/clay.py +8 -9
  16. rslearn/models/clip.py +18 -15
  17. rslearn/models/component.py +111 -0
  18. rslearn/models/concatenate_features.py +21 -11
  19. rslearn/models/conv.py +15 -8
  20. rslearn/models/croma.py +13 -8
  21. rslearn/models/detr/detr.py +25 -14
  22. rslearn/models/dinov3.py +11 -6
  23. rslearn/models/faster_rcnn.py +19 -9
  24. rslearn/models/feature_center_crop.py +12 -9
  25. rslearn/models/fpn.py +19 -8
  26. rslearn/models/galileo/galileo.py +23 -18
  27. rslearn/models/module_wrapper.py +26 -57
  28. rslearn/models/molmo.py +16 -14
  29. rslearn/models/multitask.py +102 -73
  30. rslearn/models/olmoearth_pretrain/model.py +135 -38
  31. rslearn/models/panopticon.py +8 -7
  32. rslearn/models/pick_features.py +18 -24
  33. rslearn/models/pooling_decoder.py +22 -14
  34. rslearn/models/presto/presto.py +16 -10
  35. rslearn/models/presto/single_file_presto.py +4 -10
  36. rslearn/models/prithvi.py +12 -8
  37. rslearn/models/resize_features.py +21 -7
  38. rslearn/models/sam2_enc.py +11 -9
  39. rslearn/models/satlaspretrain.py +15 -9
  40. rslearn/models/simple_time_series.py +37 -17
  41. rslearn/models/singletask.py +24 -17
  42. rslearn/models/ssl4eo_s12.py +15 -10
  43. rslearn/models/swin.py +22 -13
  44. rslearn/models/terramind.py +24 -7
  45. rslearn/models/trunk.py +6 -3
  46. rslearn/models/unet.py +18 -9
  47. rslearn/models/upsample.py +22 -9
  48. rslearn/train/all_patches_dataset.py +89 -37
  49. rslearn/train/dataset.py +105 -97
  50. rslearn/train/lightning_module.py +51 -32
  51. rslearn/train/model_context.py +54 -0
  52. rslearn/train/prediction_writer.py +111 -41
  53. rslearn/train/scheduler.py +15 -0
  54. rslearn/train/tasks/classification.py +34 -15
  55. rslearn/train/tasks/detection.py +24 -31
  56. rslearn/train/tasks/embedding.py +33 -29
  57. rslearn/train/tasks/multi_task.py +7 -7
  58. rslearn/train/tasks/per_pixel_regression.py +41 -19
  59. rslearn/train/tasks/regression.py +38 -21
  60. rslearn/train/tasks/segmentation.py +33 -15
  61. rslearn/train/tasks/task.py +3 -2
  62. rslearn/train/transforms/resize.py +74 -0
  63. rslearn/utils/geometry.py +73 -0
  64. rslearn/utils/jsonargparse.py +66 -0
  65. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
  66. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
  67. rslearn/dataset/index.py +0 -173
  68. rslearn/models/registry.py +0 -22
  69. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
  70. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
  71. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
  72. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
  73. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
rslearn/models/pick_features.py CHANGED
@@ -2,45 +2,39 @@
 
 from typing import Any
 
-import torch
+from rslearn.train.model_context import ModelContext
 
+from .component import (
+    FeatureMaps,
+    IntermediateComponent,
+)
 
-class PickFeatures(torch.nn.Module):
+
+class PickFeatures(IntermediateComponent):
     """Picks a subset of feature maps in a multi-scale feature map list."""
 
-    def __init__(self, indexes: list[int], collapse: bool = False):
+    def __init__(self, indexes: list[int]):
         """Create a new PickFeatures.
 
         Args:
             indexes: the indexes of the input feature map list to select.
-            collapse: return one feature map instead of list. If enabled, indexes must
-                consist of one index. This is mainly useful for using PickFeatures as
-                the final module in the decoder, since the final prediction is expected
-                to be one feature map for most tasks like segmentation.
         """
         super().__init__()
         self.indexes = indexes
-        self.collapse = collapse
-
-        if self.collapse and len(self.indexes) != 1:
-            raise ValueError("if collapse is enabled, must get exactly one index")
 
     def forward(
         self,
-        features: list[torch.Tensor],
-        inputs: list[dict[str, Any]] | None = None,
-        targets: list[dict[str, Any]] | None = None,
-    ) -> list[torch.Tensor]:
+        intermediates: Any,
+        context: ModelContext,
+    ) -> FeatureMaps:
         """Pick a subset of the features.
 
         Args:
-            features: input features
-            inputs: raw inputs, not used
-            targets: targets, not used
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context.
         """
-        new_features = [features[idx] for idx in self.indexes]
-        if self.collapse:
-            assert len(new_features) == 1
-            return new_features[0]
-        else:
-            return new_features
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to PickFeatures must be FeatureMaps")
+
+        new_features = [intermediates.feature_maps[idx] for idx in self.indexes]
+        return FeatureMaps(new_features)
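
Editor's note: a minimal sketch of how the reworked PickFeatures could be exercised in isolation under the 0.0.19 component API. It assumes ModelContext(inputs=..., metadatas=...) and FeatureMaps([...]) can be constructed as shown in the hunks above; the tensor shapes are arbitrary examples.

import torch

from rslearn.models.component import FeatureMaps
from rslearn.models.pick_features import PickFeatures
from rslearn.train.model_context import ModelContext

# Two example feature maps at different scales (shapes are made up).
feature_maps = FeatureMaps([
    torch.randn(2, 64, 32, 32),
    torch.randn(2, 128, 16, 16),
])
# PickFeatures does not read the context contents, so empty input dicts suffice here.
context = ModelContext(inputs=[{}, {}], metadatas=[])
picked = PickFeatures(indexes=[1])(feature_maps, context)
assert isinstance(picked, FeatureMaps) and len(picked.feature_maps) == 1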
rslearn/models/pooling_decoder.py CHANGED
@@ -4,8 +4,16 @@ from typing import Any
 
 import torch
 
+from rslearn.train.model_context import ModelContext
 
-class PoolingDecoder(torch.nn.Module):
+from .component import (
+    FeatureMaps,
+    FeatureVector,
+    IntermediateComponent,
+)
+
+
+class PoolingDecoder(IntermediateComponent):
     """Decoder that computes flat vector from a 2D feature map.
 
     It inputs multi-scale features, but only uses the last feature map. Then applies a
@@ -57,25 +65,26 @@ class PoolingDecoder(torch.nn.Module):
 
         self.output_layer = torch.nn.Linear(prev_channels, out_channels)
 
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
-    ) -> torch.Tensor:
+    def forward(self, intermediates: Any, context: ModelContext) -> Any:
         """Compute flat output vector from multi-scale feature map.
 
         Args:
-            features: list of feature maps at different resolutions.
-            inputs: original inputs (ignored).
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
            flat feature vector
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to PoolingDecoder must be a FeatureMaps")
+
         # Only use last feature map.
-        features = features[-1]
+        features = intermediates.feature_maps[-1]
 
         features = self.conv_layers(features)
         features = torch.amax(features, dim=(2, 3))
         features = self.fc_layers(features)
-        return self.output_layer(features)
+        return FeatureVector(self.output_layer(features))
 
 
 class SegmentationPoolingDecoder(PoolingDecoder):
@@ -108,14 +117,13 @@ class SegmentationPoolingDecoder(PoolingDecoder):
         super().__init__(in_channels=in_channels, out_channels=out_channels, **kwargs)
         self.image_key = image_key
 
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
-    ) -> torch.Tensor:
+    def forward(self, intermediates: Any, context: ModelContext) -> Any:
         """Extend PoolingDecoder forward to upsample the output to a segmentation mask.
 
         This only works when all of the pixels have the same segmentation target.
         """
-        output_probs = super().forward(features, inputs)
+        output_probs = super().forward(intermediates, context)
         # BC -> BCHW
-        h, w = inputs[0][self.image_key].shape[1:3]
-        return output_probs[:, :, None, None].repeat([1, 1, h, w])
+        h, w = context.inputs[0][self.image_key].shape[1:3]
+        feat_map = output_probs.feature_vector[:, :, None, None].repeat([1, 1, h, w])
+        return FeatureMaps([feat_map])
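
Editor's note: the BC -> BCHW broadcast that SegmentationPoolingDecoder now applies to the FeatureVector output can be checked in isolation with plain torch; the shapes below are illustrative only.

import torch

output_probs = torch.randn(4, 3)  # B=4 examples, C=3 classes from the pooling decoder
h, w = 32, 32                     # in the real code these come from context.inputs[0][self.image_key]
feat_map = output_probs[:, :, None, None].repeat([1, 1, h, w])
assert feat_map.shape == (4, 3, 32, 32)  # every pixel carries the same per-class scores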
rslearn/models/presto/presto.py CHANGED
@@ -2,14 +2,13 @@
 
 import logging
 import tempfile
-from typing import Any
 
 import torch
 from einops import rearrange, repeat
 from huggingface_hub import hf_hub_download
-from torch import nn
 from upath import UPath
 
+from rslearn.models.component import FeatureExtractor, FeatureMaps
 from rslearn.models.presto.single_file_presto import (
     ERA5_BANDS,
     NUM_DYNAMIC_WORLD_CLASSES,
@@ -21,6 +20,7 @@ from rslearn.models.presto.single_file_presto import (
     SRTM_BANDS,
 )
 from rslearn.models.presto.single_file_presto import Presto as SFPresto
+from rslearn.train.model_context import ModelContext
 
 logger = logging.getLogger(__name__)
 
@@ -36,7 +36,7 @@ HF_HUB_ID = "nasaharvest/presto"
 MODEL_FILENAME = "default_model.pt"
 
 
-class Presto(nn.Module):
+class Presto(FeatureExtractor):
     """Presto."""
 
     input_keys = [
@@ -184,22 +184,26 @@ class Presto(nn.Module):
         x = (x + PRESTO_ADD_BY.to(device=device)) / PRESTO_DIV_BY.to(device=device)
         return x, mask, dynamic_world.long(), months.long()
 
-    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Presto backbone.
 
-        Inputs:
-            inputs
+        Args:
+            context: the model context. Input dicts should have some subset of Presto.input_keys.
+
+        Returns:
+            a FeatureMaps with one feature map that is at the same resolution as the
+            input (since Presto operates per-pixel).
         """
         stacked_inputs = {}
         latlons: torch.Tensor | None = None
-        for key in inputs[0].keys():
+        for key in context.inputs[0].keys():
             # assume all the keys in an input are consistent
             if key in self.input_keys:
                 if key == "latlon":
-                    latlons = torch.stack([inp[key] for inp in inputs], dim=0)
+                    latlons = torch.stack([inp[key] for inp in context.inputs], dim=0)
                 else:
                     stacked_inputs[key] = torch.stack(
-                        [inp[key] for inp in inputs], dim=0
+                        [inp[key] for inp in context.inputs], dim=0
                     )
 
         (
@@ -247,7 +251,9 @@
             )
             output_features[batch_idx : batch_idx + self.pixel_batch_size] = output_b
 
-        return [rearrange(output_features, "(b h w) d -> b d h w", h=h, w=w, b=b)]
+        return FeatureMaps(
+            [rearrange(output_features, "(b h w) d -> b d h w", h=h, w=w, b=b)]
+        )
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
rslearn/models/presto/single_file_presto.py CHANGED
@@ -281,10 +281,7 @@ def get_sinusoid_encoding_table(
     sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
     sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
 
-    if torch.cuda.is_available():
-        return torch.FloatTensor(sinusoid_table).cuda()
-    else:
-        return torch.FloatTensor(sinusoid_table)
+    return torch.FloatTensor(sinusoid_table)
 
 
 def get_month_encoding_table(d_hid: int) -> torch.Tensor:
@@ -296,10 +293,7 @@ def get_month_encoding_table(d_hid: int) -> torch.Tensor:
     cos_table = np.cos(np.stack([angles for _ in range(d_hid // 2)], axis=-1))
     month_table = np.concatenate([sin_table[:-1], cos_table[:-1]], axis=-1)
 
-    if torch.cuda.is_available():
-        return torch.FloatTensor(month_table).cuda()
-    else:
-        return torch.FloatTensor(month_table)
+    return torch.FloatTensor(month_table)
 
 
 def month_to_tensor(
@@ -405,7 +399,7 @@ class Encoder(nn.Module):
         """initialize_weights."""
         pos_embed = get_sinusoid_encoding_table(
             self.pos_embed.shape[1], self.pos_embed.shape[-1]
-        )
+        ).to(device=self.pos_embed.device)
         self.pos_embed.data.copy_(pos_embed)
 
         # initialize nn.Linear and nn.LayerNorm
@@ -640,7 +634,7 @@ class Decoder(nn.Module):
         """initialize_weights."""
         pos_embed = get_sinusoid_encoding_table(
             self.pos_embed.shape[1], self.pos_embed.shape[-1]
-        )
+        ).to(device=self.pos_embed.device)
         self.pos_embed.data.copy_(pos_embed)
 
         # initialize nn.Linear and nn.LayerNorm
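
Editor's note: the device-handling change above replaces unconditional .cuda() calls at table-creation time with a .to(device=...) at the point where the table is copied into the positional-embedding parameter. A minimal, stand-alone illustration of that pattern (the shapes and the zero table are placeholders, not the real sinusoid table):

import torch

pos_embed = torch.nn.Parameter(torch.zeros(1, 16, 8))   # may live on CPU or GPU
table = torch.zeros(1, 16, 8)                            # placeholder for get_sinusoid_encoding_table(...)
pos_embed.data.copy_(table.to(device=pos_embed.device))  # follows the parameter's device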
rslearn/models/prithvi.py CHANGED
@@ -25,9 +25,12 @@ from timm.layers import to_2tuple
 from timm.models.vision_transformer import Block
 from torch.nn import functional as F
 
+from rslearn.train.model_context import ModelContext
 from rslearn.train.transforms.normalize import Normalize
 from rslearn.train.transforms.transform import Transform
 
+from .component import FeatureExtractor, FeatureMaps
+
 logger = logging.getLogger(__name__)
 
 
@@ -77,7 +80,7 @@ def get_config(cache_dir: Path, hf_hub_id: str, hf_hub_revision: str) -> dict[st
     return json.load(f)["pretrained_cfg"]
 
 
-class PrithviV2(nn.Module):
+class PrithviV2(FeatureExtractor):
     """An Rslearn wrapper for Prithvi 2.0."""
 
     INPUT_KEY = "image"
@@ -157,18 +160,18 @@ class PrithviV2(nn.Module):
         )
         return data
 
-    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Prithvi V2 backbone.
 
         Args:
-            inputs: input dicts that must include "image" key containing HLS
-                (Harmonized Landsat-Sentinel) data.
+            context: the model context. Input dicts must include "image" key containing
+                HLS (Harmonized Landsat-Sentinel) data.
 
         Returns:
-            11 feature maps (one per transformer block in the Prithvi model),
-            of shape [B, H/p_s, W/p_s, D=1024] where p_s=16 is the patch size.
+            a FeatureMaps with one map of shape [B, H/p_s, W/p_s, 11*1024] that contains stacked
+            feature maps across the 11 transformer blocks.
         """
-        x = torch.stack([inp[self.INPUT_KEY] for inp in inputs], dim=0)
+        x = torch.stack([inp[self.INPUT_KEY] for inp in context.inputs], dim=0)
         x = self._resize_data(x)
         num_timesteps = x.shape[1] // len(self.bands)
         x = rearrange(x, "b (t c) h w -> b c t h w", t=num_timesteps)
@@ -177,9 +180,10 @@
         # know the number of timesteps and don't need to recompute it.
         # in addition we average along the time dimension (instead of concatenating)
        # to keep the embeddings reasonably sized.
-        return self.model.encoder.prepare_features_for_image_model(
+        result = self.model.encoder.prepare_features_for_image_model(
             features, num_timesteps
         )
+        return FeatureMaps([torch.cat(result, dim=1)])
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
rslearn/models/resize_features.py CHANGED
@@ -1,9 +1,18 @@
 """The ResizeFeatures module."""
 
+from typing import Any
+
 import torch
 
+from rslearn.train.model_context import ModelContext
+
+from .component import (
+    FeatureMaps,
+    IntermediateComponent,
+)
+
 
-class ResizeFeatures(torch.nn.Module):
+class ResizeFeatures(IntermediateComponent):
     """Resize input features to new sizes."""
 
     def __init__(
@@ -30,16 +39,21 @@ class ResizeFeatures(torch.nn.Module):
         )
         self.layers = torch.nn.ModuleList(layers)
 
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[torch.Tensor]
-    ) -> list[torch.Tensor]:
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
         """Resize the input feature maps to new sizes.
 
         Args:
-            features: list of feature maps at different resolutions.
-            inputs: original inputs (ignored).
+            intermediates: the outputs from the previous component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
             resized feature maps
         """
-        return [self.layers[idx](feat_map) for idx, feat_map in enumerate(features)]
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to ResizeFeatures must be a FeatureMaps")
+
+        feat_maps = intermediates.feature_maps
+        resized_feat_maps = [
+            self.layers[idx](feat_map) for idx, feat_map in enumerate(feat_maps)
+        ]
+        return FeatureMaps(resized_feat_maps)
rslearn/models/sam2_enc.py CHANGED
@@ -1,14 +1,15 @@
 """SegmentAnything2 encoders."""
 
-from typing import Any
-
 import torch
-import torch.nn as nn
 from sam2.build_sam import build_sam2
 from upath import UPath
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
+
 
-class SAM2Encoder(nn.Module):
+class SAM2Encoder(FeatureExtractor):
     """SAM2's image encoder."""
 
     def __init__(self, model_identifier: str) -> None:
@@ -84,18 +85,19 @@ class SAM2Encoder(nn.Module):
         del self.model.obj_ptr_proj
         del self.model.image_encoder.neck
 
-    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Extract multi-scale features from a batch of images.
 
         Args:
-            inputs: List of dictionaries, each containing the input image under the key 'image'.
+            context: the model context. Input dicts must have a key 'image' containing
+                the input for the SAM2 image encoder.
 
         Returns:
-            List[torch.Tensor]: Multi-scale feature tensors from the encoder.
+            feature maps from the encoder.
         """
-        images = torch.stack([inp["image"] for inp in inputs], dim=0)
+        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
         features = self.encoder(images)
-        return features
+        return FeatureMaps(features)
 
     def get_backbone_channels(self) -> list[list[int]]:
         """Returns the output channels of the encoder at different scales.
rslearn/models/satlaspretrain.py CHANGED
@@ -1,13 +1,15 @@
 """SatlasPretrain models."""
 
-from typing import Any
-
 import satlaspretrain_models
 import torch
 import torch.nn.functional as F
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
+
 
-class SatlasPretrain(torch.nn.Module):
+class SatlasPretrain(FeatureExtractor):
     """SatlasPretrain backbones."""
 
     def __init__(
@@ -64,15 +66,19 @@ class SatlasPretrain(torch.nn.Module):
         else:
             return data
 
-    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the SatlasPretrain backbone.
 
-        Inputs:
-            inputs: input dicts that must include "image" key containing the image to
-                process.
+        Args:
+            context: the model context. Input dicts must contain an "image" key
+                containing the image input to the model.
+
+        Returns:
+            multi-resolution feature maps computed by the model.
         """
-        images = torch.stack([inp["image"] for inp in inputs], dim=0)
-        return self.model(self.maybe_resize(images))
+        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+        feature_maps = self.model(self.maybe_resize(images))
+        return FeatureMaps(feature_maps)
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
rslearn/models/simple_time_series.py CHANGED
@@ -4,11 +4,15 @@ from typing import Any
 
 import torch
 
+from rslearn.train.model_context import ModelContext
 
-class SimpleTimeSeries(torch.nn.Module):
-    """SimpleTimeSeries wraps another encoder and applies it on an image time series.
+from .component import FeatureExtractor, FeatureMaps
 
-    It independently applies the other encoder on each image in the time series to
+
+class SimpleTimeSeries(FeatureExtractor):
+    """SimpleTimeSeries wraps another FeatureExtractor and applies it on an image time series.
+
+    It independently applies the other FeatureExtractor on each image in the time series to
     extract feature maps. It then provides a few ways to combine the features into one
     final feature map:
     - Temporal max pooling.
@@ -19,7 +23,7 @@ class SimpleTimeSeries(torch.nn.Module):
 
     def __init__(
         self,
-        encoder: torch.nn.Module,
+        encoder: FeatureExtractor,
         image_channels: int | None = None,
         op: str = "max",
         groups: list[list[int]] | None = None,
@@ -31,9 +35,9 @@
         """Create a new SimpleTimeSeries.
 
         Args:
-            encoder: the underlying encoder. It must provide get_backbone_channels
-                function that returns the output channels, or backbone_channels must be
-                set.
+            encoder: the underlying FeatureExtractor. It must provide get_backbone_channels
+                function that returns the output channels, or backbone_channels must be set.
+                It must output a FeatureMaps.
             image_channels: the number of channels per image of the time series. The
                 input should have multiple images concatenated on the channel axis, so
                 this parameter is used to distinguish the different images.
@@ -179,24 +183,27 @@
 
     def forward(
         self,
-        inputs: list[dict[str, Any]],
-    ) -> list[torch.Tensor]:
+        context: ModelContext,
+    ) -> FeatureMaps:
         """Compute outputs from the backbone.
 
-        Inputs:
-            inputs: input dicts that must include "image" key containing the image time
+        Args:
+            context: the model context. Input dicts must include "image" key containing the image time
                 series to process (with images concatenated on the channel dimension).
+
+        Returns:
+            the FeatureMaps aggregated temporally.
         """
         # First get features of each image.
         # To do so, we need to split up each grouped image into its component images (which have had their channels stacked).
         batched_inputs: list[dict[str, Any]] | None = None
-        n_batch = len(inputs)
+        n_batch = len(context.inputs)
         n_images: int | None = None
 
         if self.image_keys is not None:
             for image_key, image_channels in self.image_keys.items():
                 batched_images = self._get_batched_images(
-                    inputs, image_key, image_channels
+                    context.inputs, image_key, image_channels
                 )
 
                 if batched_inputs is None:
@@ -213,13 +220,26 @@
         else:
             assert self.image_channels is not None
             batched_images = self._get_batched_images(
-                inputs, self.image_key, self.image_channels
+                context.inputs, self.image_key, self.image_channels
             )
             batched_inputs = [{self.image_key: image} for image in batched_images]
             n_images = batched_images.shape[0] // n_batch
 
         assert n_images is not None
 
+        # Now we can apply the underlying FeatureExtractor.
+        # Its output must be a FeatureMaps.
+        assert batched_inputs is not None
+        encoder_output = self.encoder(
+            ModelContext(
+                inputs=batched_inputs,
+                metadatas=context.metadatas,
+            )
+        )
+        if not isinstance(encoder_output, FeatureMaps):
+            raise ValueError(
+                "output of underlying FeatureExtractor in SimpleTimeSeries must be a FeatureMaps"
+            )
         all_features = [
             feat_map.reshape(
                 n_batch,
@@ -228,7 +248,7 @@
                 feat_map.shape[2],
                 feat_map.shape[3],
             )
-            for feat_map in self.encoder(batched_inputs)
+            for feat_map in encoder_output.feature_maps
         ]
 
         # Groups defaults to flattening all the feature maps.
@@ -284,7 +304,7 @@
                     .permute(0, 3, 1, 2)
                 )
             else:
-                raise Exception(f"unknown aggregation op {self.op}")
+                raise ValueError(f"unknown aggregation op {self.op}")
 
             aggregated_features.append(group_features)
 
@@ -293,4 +313,4 @@
 
         output_features.append(aggregated_features)
 
-        return output_features
+        return FeatureMaps(output_features)
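
Editor's note: for the op="max" path described in the SimpleTimeSeries docstring, the temporal aggregation amounts to an elementwise max over the time dimension. A stand-alone illustration, not the library's exact code, with made-up shapes:

import torch

n_batch, n_images, c, h, w = 2, 4, 8, 16, 16
per_image = torch.randn(n_batch, n_images, c, h, w)  # one feature map per timestep
pooled = per_image.amax(dim=1)                       # temporal max pooling
assert pooled.shape == (n_batch, c, h, w)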
rslearn/models/singletask.py CHANGED
@@ -4,6 +4,10 @@ from typing import Any
 
 import torch
 
+from rslearn.train.model_context import ModelContext, ModelOutput
+
+from .component import FeatureExtractor, IntermediateComponent, Predictor
+
 
 class SingleTaskModel(torch.nn.Module):
     """Standard model wrapper.
@@ -14,38 +18,41 @@ class SingleTaskModel(torch.nn.Module):
     outputs and targets from the last module (which also receives the targets).
     """
 
-    def __init__(self, encoder: list[torch.nn.Module], decoder: list[torch.nn.Module]):
+    def __init__(
+        self,
+        encoder: list[FeatureExtractor | IntermediateComponent],
+        decoder: list[IntermediateComponent | Predictor],
+    ):
         """Initialize a new SingleTaskModel.
 
         Args:
-            encoder: modules to compute intermediate feature representations.
-            decoder: modules to compute outputs and loss.
+            encoder: modules to compute intermediate feature representations. The first
+                module must be a FeatureExtractor, and following modules must be
+                IntermediateComponents.
+            decoder: modules to compute outputs and loss. The last module must be a
+                Predictor, while the previous modules must be IntermediateComponents.
         """
         super().__init__()
-        self.encoder = torch.nn.Sequential(*encoder)
+        self.encoder = torch.nn.ModuleList(encoder)
         self.decoder = torch.nn.ModuleList(decoder)
 
     def forward(
         self,
-        inputs: list[dict[str, Any]],
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) -> dict[str, Any]:
+    ) -> ModelOutput:
         """Apply the sequence of modules on the inputs.
 
         Args:
-            inputs: list of input dicts
+            context: the model context.
             targets: optional list of target dicts
-            info: optional dictionary of info to pass to the last module
 
         Returns:
-            dict with keys "outputs" and "loss_dict".
+            the model output.
         """
-        features = self.encoder(inputs)
-        cur = features
+        cur = self.encoder[0](context)
+        for module in self.encoder[1:]:
+            cur = module(cur, context)
         for module in self.decoder[:-1]:
-            cur = module(cur, inputs)
-        outputs, loss_dict = self.decoder[-1](cur, inputs, targets)
-        return {
-            "outputs": outputs,
-            "loss_dict": loss_dict,
-        }
+            cur = module(cur, context)
+        return self.decoder[-1](cur, context, targets)
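
Editor's note: a self-contained sketch of the control flow that the new SingleTaskModel.forward implements, using stand-in callables rather than rslearn classes. The first encoder entry consumes the ModelContext, every later encoder/decoder entry consumes (intermediates, context), and only the final decoder entry also receives targets and produces the model output.

from dataclasses import dataclass, field
from typing import Any

@dataclass
class Ctx:  # stand-in for rslearn.train.model_context.ModelContext
    inputs: list
    metadatas: list = field(default_factory=list)

def run_single_task(encoder: list, decoder: list, context: Ctx, targets: Any = None):
    cur = encoder[0](context)                  # FeatureExtractor step
    for module in encoder[1:]:
        cur = module(cur, context)             # IntermediateComponent steps
    for module in decoder[:-1]:
        cur = module(cur, context)
    return decoder[-1](cur, context, targets)  # Predictor produces the final output

# Toy pipeline: count the inputs, double the count, then wrap the "prediction".
encoder = [lambda ctx: len(ctx.inputs), lambda cur, ctx: cur * 2]
decoder = [lambda cur, ctx, targets: {"outputs": cur, "loss_dict": {}}]
print(run_single_task(encoder, decoder, Ctx(inputs=[{}, {}])))  # {'outputs': 4, 'loss_dict': {}}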