rslearn 0.0.17__py3-none-any.whl → 0.0.18__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (66)
  1. rslearn/config/__init__.py +2 -0
  2. rslearn/config/dataset.py +49 -4
  3. rslearn/dataset/add_windows.py +1 -1
  4. rslearn/dataset/dataset.py +9 -65
  5. rslearn/dataset/materialize.py +5 -5
  6. rslearn/dataset/storage/__init__.py +1 -0
  7. rslearn/dataset/storage/file.py +202 -0
  8. rslearn/dataset/storage/storage.py +140 -0
  9. rslearn/dataset/window.py +26 -80
  10. rslearn/main.py +11 -36
  11. rslearn/models/anysat.py +11 -9
  12. rslearn/models/clay/clay.py +8 -9
  13. rslearn/models/clip.py +18 -15
  14. rslearn/models/component.py +99 -0
  15. rslearn/models/concatenate_features.py +21 -11
  16. rslearn/models/conv.py +15 -8
  17. rslearn/models/croma.py +13 -8
  18. rslearn/models/detr/detr.py +25 -14
  19. rslearn/models/dinov3.py +11 -6
  20. rslearn/models/faster_rcnn.py +19 -9
  21. rslearn/models/feature_center_crop.py +12 -9
  22. rslearn/models/fpn.py +19 -8
  23. rslearn/models/galileo/galileo.py +23 -18
  24. rslearn/models/module_wrapper.py +26 -57
  25. rslearn/models/molmo.py +16 -14
  26. rslearn/models/multitask.py +102 -73
  27. rslearn/models/olmoearth_pretrain/model.py +18 -12
  28. rslearn/models/panopticon.py +8 -7
  29. rslearn/models/pick_features.py +18 -24
  30. rslearn/models/pooling_decoder.py +22 -14
  31. rslearn/models/presto/presto.py +16 -10
  32. rslearn/models/presto/single_file_presto.py +4 -10
  33. rslearn/models/prithvi.py +12 -8
  34. rslearn/models/resize_features.py +21 -7
  35. rslearn/models/sam2_enc.py +11 -9
  36. rslearn/models/satlaspretrain.py +15 -9
  37. rslearn/models/simple_time_series.py +31 -17
  38. rslearn/models/singletask.py +24 -17
  39. rslearn/models/ssl4eo_s12.py +15 -10
  40. rslearn/models/swin.py +22 -13
  41. rslearn/models/terramind.py +24 -7
  42. rslearn/models/trunk.py +6 -3
  43. rslearn/models/unet.py +18 -9
  44. rslearn/models/upsample.py +22 -9
  45. rslearn/train/all_patches_dataset.py +22 -18
  46. rslearn/train/dataset.py +69 -54
  47. rslearn/train/lightning_module.py +51 -32
  48. rslearn/train/model_context.py +54 -0
  49. rslearn/train/prediction_writer.py +111 -41
  50. rslearn/train/tasks/classification.py +34 -15
  51. rslearn/train/tasks/detection.py +24 -31
  52. rslearn/train/tasks/embedding.py +33 -29
  53. rslearn/train/tasks/multi_task.py +7 -7
  54. rslearn/train/tasks/per_pixel_regression.py +41 -19
  55. rslearn/train/tasks/regression.py +38 -21
  56. rslearn/train/tasks/segmentation.py +33 -15
  57. rslearn/train/tasks/task.py +3 -2
  58. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/METADATA +1 -1
  59. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/RECORD +64 -61
  60. rslearn/dataset/index.py +0 -173
  61. rslearn/models/registry.py +0 -22
  62. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/WHEEL +0 -0
  63. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/entry_points.txt +0 -0
  64. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/licenses/LICENSE +0 -0
  65. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/licenses/NOTICE +0 -0
  66. {rslearn-0.0.17.dist-info → rslearn-0.0.18.dist-info}/top_level.txt +0 -0

rslearn/models/multitask.py
@@ -1,5 +1,6 @@
 """MultiTaskModel for rslearn."""
 
+from collections.abc import Iterable
 from copy import deepcopy
 from typing import Any
 
@@ -7,6 +8,9 @@ import torch
 
 from rslearn.log_utils import get_logger
 from rslearn.models.trunk import DecoderTrunk
+from rslearn.train.model_context import ModelContext, ModelOutput
+
+from .component import FeatureExtractor, IntermediateComponent, Predictor
 
 logger = get_logger(__name__)
 
@@ -58,8 +62,8 @@ class MultiTaskModel(torch.nn.Module):
 
     def __init__(
         self,
-        encoder: list[torch.nn.Module],
-        decoders: dict[str, list[torch.nn.Module]],
+        encoder: list[FeatureExtractor | IntermediateComponent],
+        decoders: dict[str, list[IntermediateComponent | Predictor]],
         lazy_decode: bool = False,
         loss_weights: dict[str, float] | None = None,
         trunk: DecoderTrunk | None = None,
@@ -67,8 +71,12 @@ class MultiTaskModel(torch.nn.Module):
         """Initialize a new MultiTaskModel.
 
         Args:
-            encoder: modules to compute intermediate feature representations.
+            encoder: modules to compute intermediate feature representations. The first
+                module must be a FeatureExtractor, and following modules must be
+                IntermediateComponents.
             decoders: modules to compute outputs and loss, should match number of tasks.
+                The last module must be a Predictor, while the previous modules must be
+                IntermediateComponents.
             lazy_decode: if True, only decode the outputs specified in the batch.
             loss_weights: weights for each task's loss (default: None = equal weights).
             trunk: if provided, use this trunk module to postprocess the features
@@ -76,7 +84,7 @@ class MultiTaskModel(torch.nn.Module):
         """
         super().__init__()
         self.lazy_decode = lazy_decode
-        self.encoder = torch.nn.Sequential(*encoder)
+        self.encoder = torch.nn.ModuleList(encoder)
         self.decoders = torch.nn.ModuleDict(
             sort_keys(
                 {
@@ -120,32 +128,28 @@ class MultiTaskModel(torch.nn.Module):
 
     def apply_decoder(
         self,
-        features: list[torch.Tensor],
-        inputs: list[dict[str, Any]],
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None,
-        decoder: list[torch.nn.Module],
+        decoder: list[IntermediateComponent | Predictor],
         task_name: str,
-        outputs: list[dict[str, Any]],
-        losses: dict[str, torch.Tensor],
-    ) -> tuple[list[dict[str, Any]], dict[str, torch.Tensor]]:
+    ) -> ModelOutput:
         """Apply a decoder to a list of inputs and targets.
 
         Args:
-            features: list of features
-            inputs: list of input dicts
+            intermediates: the intermediate output from the encoder.
+            context: the model context.
             targets: list of target dicts
             decoder: list of decoder modules
             task_name: the name of the task
-            outputs: list of output dicts
-            losses: dictionary of loss values
 
         Returns:
-            tuple of (outputs, losses)
+            a ModelOutput containing outputs across all the decoders.
         """
         # First, apply all but the last module in the decoder to the features
-        cur = features
+        cur = intermediates
         for module in decoder[:-1]:
-            cur = module(cur, inputs)
+            cur = module(cur, context)
 
         if targets is None:
             cur_targets = None
@@ -153,14 +157,7 @@ class MultiTaskModel(torch.nn.Module):
             cur_targets = [target[task_name] for target in targets]
 
         # Then, apply the last module to the features and targets
-        cur_output, cur_loss_dict = decoder[-1](cur, inputs, cur_targets)
-        for idx, entry in enumerate(cur_output):
-            outputs[idx][task_name] = entry
-        for loss_name, loss_value in cur_loss_dict.items():
-            losses[f"{task_name}_{loss_name}"] = (
-                loss_value * self.loss_weights[task_name]
-            )
-        return outputs, losses
+        return decoder[-1](cur, context, cur_targets)
 
     def _get_tasks_from_decoder(self, decoder: str) -> list[str]:
         """Get the tasks corresponding to this decoder.
@@ -172,66 +169,84 @@ class MultiTaskModel(torch.nn.Module):
 
     def apply_decoders(
         self,
-        features: list[torch.Tensor],
-        inputs: list[dict[str, Any]],
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None,
-    ) -> dict[str, Any]:
+    ) -> ModelOutput:
         """Apply all the decoders to the features and targets.
 
         Args:
-            features: list of features
-            inputs: list of input dicts
+            intermediates: the intermediates from the encoder.
+            context: the model context
            targets: list of target dicts
 
         Returns:
-            dict of outputs and losses
+            combined ModelOutput. The outputs is a list of output dicts, one per example,
+            where the dict maps from task name to the corresponding task output. The
+            losses is a flat dict but the task name is prepended to the loss names.
         """
-        outputs: list[dict[str, torch.Tensor | dict]] = [{} for _ in inputs]
+        outputs: list[dict[str, torch.Tensor | dict]] = [{} for _ in context.inputs]
         losses: dict[str, torch.Tensor] = {}
 
         if self.lazy_decode:
             # Assume that all inputs have the same dataset_source
-            dataset_source = inputs[0]["dataset_source"]
-            decoder = self.decoders[
-                self.target_to_decoder.get(dataset_source, dataset_source)
-            ]
-            self.apply_decoder(
-                features, inputs, targets, decoder, dataset_source, outputs, losses
+            task_name = context.metadatas[0].dataset_source
+
+            if task_name is None:
+                raise ValueError("dataset_source must be set for lazy decoding")
+
+            decoder = self.decoders[self.target_to_decoder.get(task_name, task_name)]
+            model_output = self.apply_decoder(
+                intermediates, context, targets, decoder, task_name
             )
+            for idx, entry in enumerate(model_output.outputs):
+                outputs[idx][task_name] = entry
+            for loss_name, loss_value in model_output.loss_dict.items():
+                losses[f"{task_name}_{loss_name}"] = (
+                    loss_value * self.loss_weights[task_name]
+                )
         else:
             for decoder_name, decoder in self.decoders.items():
                 for task_name in self._get_tasks_from_decoder(decoder_name):
-                    self.apply_decoder(
-                        features, inputs, targets, decoder, task_name, outputs, losses
+                    model_output = self.apply_decoder(
+                        intermediates, context, targets, decoder, task_name
                     )
-
-        return {
-            "outputs": outputs,
-            "loss_dict": losses,
-        }
+                    for idx, entry in enumerate(model_output.outputs):
+                        outputs[idx][task_name] = entry
+                    for loss_name, loss_value in model_output.loss_dict.items():
+                        losses[f"{task_name}_{loss_name}"] = (
+                            loss_value * self.loss_weights[task_name]
+                        )
+
+        return ModelOutput(
+            outputs=outputs,
+            loss_dict=losses,
+        )
 
     def forward(
         self,
-        inputs: list[dict[str, Any]],
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) -> dict[str, Any]:
+    ) -> ModelOutput:
         """Apply the sequence of modules on the inputs, including shared trunk.
 
         Args:
-            inputs: list of input dicts
+            context: the model context.
             targets: optional list of target dicts
 
         Returns:
-            dict with keys "outputs" and "loss_dict".
+            the model output from apply_decoders.
         """
-        features = self.encoder(inputs)
+        cur = self.encoder[0](context)
+        for module in self.encoder[1:]:
+            cur = module(cur, context)
         if self.trunk is not None:
-            trunk_out = self.trunk(features, inputs)
-            outs = self.apply_decoders(trunk_out.pop("outputs"), inputs, targets)
+            trunk_out = self.trunk(cur, context)
+            outs = self.apply_decoders(trunk_out.pop("outputs"), context, targets)
             self.trunk.apply_auxiliary_losses(trunk_out, outs)
             return outs | trunk_out
         else:
-            return self.apply_decoders(features, inputs, targets)
+            return self.apply_decoders(cur, context, targets)
 
 
 class MultiTaskMergedModel(MultiTaskModel):
@@ -247,8 +262,8 @@ class MultiTaskMergedModel(MultiTaskModel):
 
     def __init__(
         self,
-        encoder: list[torch.nn.Module],
-        decoders: dict[str, list[torch.nn.Module]],
+        encoder: list[FeatureExtractor | IntermediateComponent],
+        decoders: dict[str, list[IntermediateComponent | Predictor]],
         decoder_to_target: dict[str, list[str]],
         task_label_offsets: dict[str, dict[str, Any]],
         lazy_decode: bool = False,
@@ -273,7 +288,7 @@ class MultiTaskMergedModel(MultiTaskModel):
         torch.nn.Module.__init__(self)
 
         self.lazy_decode = lazy_decode
-        self.encoder = torch.nn.Sequential(*encoder)
+        self.encoder = torch.nn.ModuleList(encoder)
         self.decoders = torch.nn.ModuleDict(
             sort_keys(
                 {
@@ -329,9 +344,9 @@ class MultiTaskMergedModel(MultiTaskModel):
         return offset_targets
 
     def unmerge_output_labels(
-        self, outputs: list[dict[str, torch.Tensor | dict]], task_name: str
-    ) -> None:
-        """Unmerge the task labels in place.
+        self, outputs: Iterable[Any], task_name: str
+    ) -> list[dict[str, torch.Tensor | dict]]:
+        """Unmerge the task outputs.
 
         For most tasks, this means chopping off the corresponding label dimensions.
         For some, we might just need to subtract an offset from the target (ex: segmentation).
@@ -340,10 +355,15 @@ class MultiTaskMergedModel(MultiTaskModel):
         Args:
             outputs: the predictions
             task_name: the name of the task
+
+        Returns:
+            the unmerged outputs.
         """
         offset = self.task_label_offsets[task_name]["offset"]
         num_outputs = self.task_label_offsets[task_name]["num_outputs"]
         output_key = self.task_label_offsets[task_name]["outputs_key"]
+
+        unmerged_outputs: list[dict[str, torch.Tensor | dict]] = [{} for _ in outputs]
         with torch.no_grad():
             for i, output in enumerate(outputs):
                 if not output:
@@ -353,35 +373,44 @@ class MultiTaskMergedModel(MultiTaskModel):
                 if isinstance(output, dict):
                     # For some tasks (eg object detection), we have discrete label
                     # predictions instead of a distribution over labels
-                    output[output_key] -= offset
+                    unmerged_output = output.copy()
+                    unmerged_output[output_key] = unmerged_output[output_key] - offset
+                    unmerged_outputs[i][task_name] = unmerged_output
                 elif isinstance(output, torch.Tensor):
                     # For classification/segmentation tasks, we have a distribution
                     # over labels, so we need to scale the predictions so that they
                     # sum to 1 since we chop off some of the probability densities
-                    tensor: torch.Tensor = output[offset : offset + num_outputs, ...]
-                    tensor /= tensor.sum(dim=0, keepdim=True).type(torch.float32)
-                    outputs[i][task_name] = tensor
+                    unmerged_output = output[offset : offset + num_outputs, ...]
+                    unmerged_output /= unmerged_output.sum(dim=0, keepdim=True).type(
+                        torch.float32
+                    )
+                    unmerged_outputs[i][task_name] = unmerged_output
+
+        return unmerged_outputs
 
     def forward(
         self,
-        inputs: list[dict[str, Any]],
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) -> dict[str, Any]:
+    ) -> ModelOutput:
         """Apply the sequence of modules on the inputs.
 
         Args:
-            inputs: list of input dicts
+            context: the model context.
             targets: optional list of target dicts
 
         Returns:
-            dict with keys "outputs" and "loss_dict", and possibly other keys.
+            the model output.
         """
-        dataset_source = inputs[0].get("dataset_source", None)
-        assert isinstance(dataset_source, str)
-        targets = self.merge_task_labels(targets, dataset_source)
-        outs = super().forward(inputs, targets)
-        self.unmerge_output_labels(outs["outputs"], dataset_source)
-        return outs
+        dataset_source = context.metadatas[0].dataset_source
+        assert dataset_source is not None
+        merged_targets = self.merge_task_labels(targets, dataset_source)
+        outs = super().forward(context, merged_targets)
+        unmerged_outputs = self.unmerge_output_labels(outs.outputs, dataset_source)
+        return ModelOutput(
+            outputs=unmerged_outputs,
+            loss_dict=outs.loss_dict,
+        )
 
     def _get_tasks_from_decoder(self, decoder: str) -> list[str]:
         """Get the tasks corresponding to this decoder.

rslearn/models/olmoearth_pretrain/model.py
@@ -19,6 +19,8 @@ from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
 from upath import UPath
 
 from rslearn.log_utils import get_logger
+from rslearn.models.component import FeatureExtractor, FeatureMaps
+from rslearn.train.model_context import ModelContext
 
 logger = get_logger(__name__)
 
@@ -44,7 +46,7 @@ EMBEDDING_SIZES = {
 }
 
 
-class OlmoEarth(torch.nn.Module):
+class OlmoEarth(FeatureExtractor):
     """A wrapper to support the OlmoEarth model."""
 
     def __init__(
@@ -158,12 +160,16 @@ class OlmoEarth(torch.nn.Module):
 
         return model
 
-    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the OlmoEarth backbone.
 
-        Inputs:
-            inputs: input dicts. It should include keys corresponding to the modalities
-                that should be passed to the OlmoEarth model.
+        Args:
+            context: the model context. Input dicts should include keys corresponding
+                to the modalities that should be passed to the OlmoEarth model.
+
+        Returns:
+            a FeatureMaps consisting of one feature map, at 1/patch_size of the input
+            resolution. Embeddings will be pooled across modalities and timesteps.
         """
         kwargs = {}
         present_modalities = []
@@ -172,10 +178,10 @@ class OlmoEarth(torch.nn.Module):
         # We assume all multitemporal modalities have the same number of timesteps.
         max_timesteps = 1
         for modality in MODALITY_NAMES:
-            if modality not in inputs[0]:
+            if modality not in context.inputs[0]:
                 continue
             present_modalities.append(modality)
-            cur = torch.stack([inp[modality] for inp in inputs], dim=0)
+            cur = torch.stack([inp[modality] for inp in context.inputs], dim=0)
             device = cur.device
             # Check if it's single or multitemporal, and reshape accordingly
             num_bands = Modality.get(modality).num_bands
@@ -196,7 +202,7 @@ class OlmoEarth(torch.nn.Module):
         # Note that only months (0 to 11) are used in OlmoEarth position encoding.
         # For now, we assign same timestamps to all inputs, but later we should handle varying timestamps per input.
         timestamps = torch.zeros(
-            (len(inputs), max_timesteps, 3), dtype=torch.int32, device=device
+            (len(context.inputs), max_timesteps, 3), dtype=torch.int32, device=device
         )
         timestamps[:, :, 0] = 1  # day
         timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
@@ -209,14 +215,14 @@ class OlmoEarth(torch.nn.Module):
 
         # Decide context based on self.autocast_dtype.
         if self.autocast_dtype is None:
-            context = nullcontext()
+            torch_context = nullcontext()
         else:
            assert device is not None
-            context = torch.amp.autocast(
+            torch_context = torch.amp.autocast(
                 device_type=device.type, dtype=self.autocast_dtype
             )
 
-        with context:
+        with torch_context:
             # Currently we assume the provided model always returns a TokensAndMasks object.
             tokens_and_masks: TokensAndMasks
             if isinstance(self.model, Encoder):
@@ -244,7 +250,7 @@ class OlmoEarth(torch.nn.Module):
             features.append(pooled)
         # Pool over the modalities, so we get one BCHW feature map.
         pooled = torch.stack(features, dim=0).mean(dim=0)
-        return [pooled]
+        return FeatureMaps([pooled])
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.

rslearn/models/panopticon.py
@@ -3,15 +3,16 @@
 import math
 from enum import StrEnum
 from importlib import resources
-from typing import Any
 
 import torch
 import torch.nn.functional as F
 import yaml
 from einops import rearrange, repeat
-from torch import nn
 
 from rslearn.log_utils import get_logger
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
 
 logger = get_logger(__name__)
 
@@ -28,7 +29,7 @@ class PanopticonModalities(StrEnum):
     # Add more modalities as needed
 
 
-class Panopticon(nn.Module):
+class Panopticon(FeatureExtractor):
     """Class containing the Panopticon model that can ingest MaskedHeliosSample objects."""
 
     patch_size: int = 14
@@ -138,11 +139,11 @@ class Panopticon(nn.Module):
             "chn_ids": chn_ids,
         }
 
-    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Forward pass through the panopticon model."""
         batch_inputs = {
-            key: torch.stack([inp[key] for inp in inputs], dim=0)
-            for key in inputs[0].keys()
+            key: torch.stack([inp[key] for inp in context.inputs], dim=0)
+            for key in context.inputs[0].keys()
         }
         panopticon_inputs = self.prepare_input(batch_inputs)
         output_features = self.model.forward_features(panopticon_inputs)[
@@ -154,7 +155,7 @@ class Panopticon(nn.Module):
         output_features = rearrange(
             output_features, "b (h w) d -> b d h w", h=height, w=height
         )
-        return [output_features]
+        return FeatureMaps([output_features])
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.

rslearn/models/pick_features.py
@@ -2,45 +2,39 @@
 
 from typing import Any
 
-import torch
+from rslearn.train.model_context import ModelContext
 
+from .component import (
+    FeatureMaps,
+    IntermediateComponent,
+)
 
-class PickFeatures(torch.nn.Module):
+
+class PickFeatures(IntermediateComponent):
     """Picks a subset of feature maps in a multi-scale feature map list."""
 
-    def __init__(self, indexes: list[int], collapse: bool = False):
+    def __init__(self, indexes: list[int]):
         """Create a new PickFeatures.
 
         Args:
             indexes: the indexes of the input feature map list to select.
-            collapse: return one feature map instead of list. If enabled, indexes must
-                consist of one index. This is mainly useful for using PickFeatures as
-                the final module in the decoder, since the final prediction is expected
-                to be one feature map for most tasks like segmentation.
         """
         super().__init__()
         self.indexes = indexes
-        self.collapse = collapse
-
-        if self.collapse and len(self.indexes) != 1:
-            raise ValueError("if collapse is enabled, must get exactly one index")
 
     def forward(
         self,
-        features: list[torch.Tensor],
-        inputs: list[dict[str, Any]] | None = None,
-        targets: list[dict[str, Any]] | None = None,
-    ) -> list[torch.Tensor]:
+        intermediates: Any,
+        context: ModelContext,
+    ) -> FeatureMaps:
         """Pick a subset of the features.
 
         Args:
-            features: input features
-            inputs: raw inputs, not used
-            targets: targets, not used
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context.
         """
-        new_features = [features[idx] for idx in self.indexes]
-        if self.collapse:
-            assert len(new_features) == 1
-            return new_features[0]
-        else:
-            return new_features
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to PickFeatures must be FeatureMaps")
+
+        new_features = [intermediates.feature_maps[idx] for idx in self.indexes]
+        return FeatureMaps(new_features)
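
The pick_features.py hunk above is a compact illustration of the new IntermediateComponent contract: forward(intermediates, context) receives the previous component's output (here required to be a FeatureMaps) and returns a new FeatureMaps. A minimal sketch of a custom component following the same pattern; ScaleFeatures and its factor argument are hypothetical and exist only to show the shape of the interface.

    from typing import Any

    from rslearn.models.component import FeatureMaps, IntermediateComponent
    from rslearn.train.model_context import ModelContext


    class ScaleFeatures(IntermediateComponent):
        """Multiply every feature map by a constant factor (toy example)."""

        def __init__(self, factor: float = 2.0):
            super().__init__()
            self.factor = factor

        def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
            if not isinstance(intermediates, FeatureMaps):
                raise ValueError("input to ScaleFeatures must be FeatureMaps")
            return FeatureMaps([fm * self.factor for fm in intermediates.feature_maps])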

rslearn/models/pooling_decoder.py
@@ -4,8 +4,16 @@ from typing import Any
 
 import torch
 
+from rslearn.train.model_context import ModelContext
 
-class PoolingDecoder(torch.nn.Module):
+from .component import (
+    FeatureMaps,
+    FeatureVector,
+    IntermediateComponent,
+)
+
+
+class PoolingDecoder(IntermediateComponent):
     """Decoder that computes flat vector from a 2D feature map.
 
     It inputs multi-scale features, but only uses the last feature map. Then applies a
@@ -57,25 +65,26 @@ class PoolingDecoder(torch.nn.Module):
 
         self.output_layer = torch.nn.Linear(prev_channels, out_channels)
 
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
-    ) -> torch.Tensor:
+    def forward(self, intermediates: Any, context: ModelContext) -> Any:
         """Compute flat output vector from multi-scale feature map.
 
         Args:
-            features: list of feature maps at different resolutions.
-            inputs: original inputs (ignored).
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
             flat feature vector
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to PoolingDecoder must be a FeatureMaps")
+
         # Only use last feature map.
-        features = features[-1]
+        features = intermediates.feature_maps[-1]
 
         features = self.conv_layers(features)
         features = torch.amax(features, dim=(2, 3))
         features = self.fc_layers(features)
-        return self.output_layer(features)
+        return FeatureVector(self.output_layer(features))
 
 
 class SegmentationPoolingDecoder(PoolingDecoder):
@@ -108,14 +117,13 @@ class SegmentationPoolingDecoder(PoolingDecoder):
         super().__init__(in_channels=in_channels, out_channels=out_channels, **kwargs)
         self.image_key = image_key
 
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
-    ) -> torch.Tensor:
+    def forward(self, intermediates: Any, context: ModelContext) -> Any:
         """Extend PoolingDecoder forward to upsample the output to a segmentation mask.
 
         This only works when all of the pixels have the same segmentation target.
         """
-        output_probs = super().forward(features, inputs)
+        output_probs = super().forward(intermediates, context)
         # BC -> BCHW
-        h, w = inputs[0][self.image_key].shape[1:3]
-        return output_probs[:, :, None, None].repeat([1, 1, h, w])
+        h, w = context.inputs[0][self.image_key].shape[1:3]
+        feat_map = output_probs.feature_vector[:, :, None, None].repeat([1, 1, h, w])
+        return FeatureMaps([feat_map])
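
The pooling_decoder.py hunks also show the two wrapper types that components now pass between each other: FeatureMaps holds a list of BCHW maps exposed as .feature_maps, while FeatureVector wraps a single BC tensor exposed as .feature_vector (SegmentationPoolingDecoder converts the latter back into a one-map FeatureMaps). A small sketch of the wrappers as used above, assuming both accept their payload as a single positional argument, as the calls in this diff do:

    import torch

    from rslearn.models.component import FeatureMaps, FeatureVector

    # A batch of two examples: one 16-channel 8x8 feature map, and a 10-dim vector.
    maps = FeatureMaps([torch.randn(2, 16, 8, 8)])
    vec = FeatureVector(torch.randn(2, 10))

    print(maps.feature_maps[0].shape)  # torch.Size([2, 16, 8, 8])
    print(vec.feature_vector.shape)    # torch.Size([2, 10])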