rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. rslearn/arg_parser.py +2 -9
  2. rslearn/config/__init__.py +2 -0
  3. rslearn/config/dataset.py +64 -20
  4. rslearn/dataset/add_windows.py +1 -1
  5. rslearn/dataset/dataset.py +34 -84
  6. rslearn/dataset/materialize.py +5 -5
  7. rslearn/dataset/storage/__init__.py +1 -0
  8. rslearn/dataset/storage/file.py +202 -0
  9. rslearn/dataset/storage/storage.py +140 -0
  10. rslearn/dataset/window.py +26 -80
  11. rslearn/lightning_cli.py +22 -11
  12. rslearn/main.py +12 -37
  13. rslearn/models/anysat.py +11 -9
  14. rslearn/models/attention_pooling.py +177 -0
  15. rslearn/models/clay/clay.py +8 -9
  16. rslearn/models/clip.py +18 -15
  17. rslearn/models/component.py +111 -0
  18. rslearn/models/concatenate_features.py +21 -11
  19. rslearn/models/conv.py +15 -8
  20. rslearn/models/croma.py +13 -8
  21. rslearn/models/detr/detr.py +25 -14
  22. rslearn/models/dinov3.py +11 -6
  23. rslearn/models/faster_rcnn.py +19 -9
  24. rslearn/models/feature_center_crop.py +12 -9
  25. rslearn/models/fpn.py +19 -8
  26. rslearn/models/galileo/galileo.py +23 -18
  27. rslearn/models/module_wrapper.py +26 -57
  28. rslearn/models/molmo.py +16 -14
  29. rslearn/models/multitask.py +102 -73
  30. rslearn/models/olmoearth_pretrain/model.py +135 -38
  31. rslearn/models/panopticon.py +8 -7
  32. rslearn/models/pick_features.py +18 -24
  33. rslearn/models/pooling_decoder.py +22 -14
  34. rslearn/models/presto/presto.py +16 -10
  35. rslearn/models/presto/single_file_presto.py +4 -10
  36. rslearn/models/prithvi.py +12 -8
  37. rslearn/models/resize_features.py +21 -7
  38. rslearn/models/sam2_enc.py +11 -9
  39. rslearn/models/satlaspretrain.py +15 -9
  40. rslearn/models/simple_time_series.py +37 -17
  41. rslearn/models/singletask.py +24 -17
  42. rslearn/models/ssl4eo_s12.py +15 -10
  43. rslearn/models/swin.py +22 -13
  44. rslearn/models/terramind.py +24 -7
  45. rslearn/models/trunk.py +6 -3
  46. rslearn/models/unet.py +18 -9
  47. rslearn/models/upsample.py +22 -9
  48. rslearn/train/all_patches_dataset.py +89 -37
  49. rslearn/train/dataset.py +105 -97
  50. rslearn/train/lightning_module.py +51 -32
  51. rslearn/train/model_context.py +54 -0
  52. rslearn/train/prediction_writer.py +111 -41
  53. rslearn/train/scheduler.py +15 -0
  54. rslearn/train/tasks/classification.py +34 -15
  55. rslearn/train/tasks/detection.py +24 -31
  56. rslearn/train/tasks/embedding.py +33 -29
  57. rslearn/train/tasks/multi_task.py +7 -7
  58. rslearn/train/tasks/per_pixel_regression.py +41 -19
  59. rslearn/train/tasks/regression.py +38 -21
  60. rslearn/train/tasks/segmentation.py +33 -15
  61. rslearn/train/tasks/task.py +3 -2
  62. rslearn/train/transforms/resize.py +74 -0
  63. rslearn/utils/geometry.py +73 -0
  64. rslearn/utils/jsonargparse.py +66 -0
  65. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
  66. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
  67. rslearn/dataset/index.py +0 -173
  68. rslearn/models/registry.py +0 -22
  69. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
  70. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
  71. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
  72. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
  73. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
rslearn/models/multitask.py
@@ -1,5 +1,6 @@
 """MultiTaskModel for rslearn."""

+from collections.abc import Iterable
 from copy import deepcopy
 from typing import Any

@@ -7,6 +8,9 @@ import torch

 from rslearn.log_utils import get_logger
 from rslearn.models.trunk import DecoderTrunk
+from rslearn.train.model_context import ModelContext, ModelOutput
+
+from .component import FeatureExtractor, IntermediateComponent, Predictor

 logger = get_logger(__name__)

@@ -58,8 +62,8 @@ class MultiTaskModel(torch.nn.Module):

     def __init__(
         self,
-        encoder: list[torch.nn.Module],
-        decoders: dict[str, list[torch.nn.Module]],
+        encoder: list[FeatureExtractor | IntermediateComponent],
+        decoders: dict[str, list[IntermediateComponent | Predictor]],
         lazy_decode: bool = False,
         loss_weights: dict[str, float] | None = None,
         trunk: DecoderTrunk | None = None,
@@ -67,8 +71,12 @@ class MultiTaskModel(torch.nn.Module):
        """Initialize a new MultiTaskModel.

        Args:
-            encoder: modules to compute intermediate feature representations.
+            encoder: modules to compute intermediate feature representations. The first
+                module must be a FeatureExtractor, and following modules must be
+                IntermediateComponents.
            decoders: modules to compute outputs and loss, should match number of tasks.
+                The last module must be a Predictor, while the previous modules must be
+                IntermediateComponents.
            lazy_decode: if True, only decode the outputs specified in the batch.
            loss_weights: weights for each task's loss (default: None = equal weights).
            trunk: if provided, use this trunk module to postprocess the features
@@ -76,7 +84,7 @@ class MultiTaskModel(torch.nn.Module):
        """
        super().__init__()
        self.lazy_decode = lazy_decode
-        self.encoder = torch.nn.Sequential(*encoder)
+        self.encoder = torch.nn.ModuleList(encoder)
        self.decoders = torch.nn.ModuleDict(
            sort_keys(
                {
@@ -120,32 +128,28 @@ class MultiTaskModel(torch.nn.Module):

     def apply_decoder(
         self,
-        features: list[torch.Tensor],
-        inputs: list[dict[str, Any]],
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None,
-        decoder: list[torch.nn.Module],
+        decoder: list[IntermediateComponent | Predictor],
         task_name: str,
-        outputs: list[dict[str, Any]],
-        losses: dict[str, torch.Tensor],
-    ) -> tuple[list[dict[str, Any]], dict[str, torch.Tensor]]:
+    ) -> ModelOutput:
        """Apply a decoder to a list of inputs and targets.

        Args:
-            features: list of features
-            inputs: list of input dicts
+            intermediates: the intermediate output from the encoder.
+            context: the model context.
            targets: list of target dicts
            decoder: list of decoder modules
            task_name: the name of the task
-            outputs: list of output dicts
-            losses: dictionary of loss values

        Returns:
-            tuple of (outputs, losses)
+            a ModelOutput containing outputs across all the decoders.
        """
        # First, apply all but the last module in the decoder to the features
-        cur = features
+        cur = intermediates
        for module in decoder[:-1]:
-            cur = module(cur, inputs)
+            cur = module(cur, context)

        if targets is None:
            cur_targets = None
@@ -153,14 +157,7 @@ class MultiTaskModel(torch.nn.Module):
            cur_targets = [target[task_name] for target in targets]

        # Then, apply the last module to the features and targets
-        cur_output, cur_loss_dict = decoder[-1](cur, inputs, cur_targets)
-        for idx, entry in enumerate(cur_output):
-            outputs[idx][task_name] = entry
-        for loss_name, loss_value in cur_loss_dict.items():
-            losses[f"{task_name}_{loss_name}"] = (
-                loss_value * self.loss_weights[task_name]
-            )
-        return outputs, losses
+        return decoder[-1](cur, context, cur_targets)

    def _get_tasks_from_decoder(self, decoder: str) -> list[str]:
        """Get the tasks corresponding to this decoder.
@@ -172,66 +169,84 @@ class MultiTaskModel(torch.nn.Module):

    def apply_decoders(
        self,
-        features: list[torch.Tensor],
-        inputs: list[dict[str, Any]],
+        intermediates: Any,
+        context: ModelContext,
        targets: list[dict[str, Any]] | None,
-    ) -> dict[str, Any]:
+    ) -> ModelOutput:
        """Apply all the decoders to the features and targets.

        Args:
-            features: list of features
-            inputs: list of input dicts
+            intermediates: the intermediates from the encoder.
+            context: the model context
            targets: list of target dicts

        Returns:
-            dict of outputs and losses
+            combined ModelOutput. The outputs is a list of output dicts, one per example,
+            where the dict maps from task name to the corresponding task output. The
+            losses is a flat dict but the task name is prepended to the loss names.
        """
-        outputs: list[dict[str, torch.Tensor | dict]] = [{} for _ in inputs]
+        outputs: list[dict[str, torch.Tensor | dict]] = [{} for _ in context.inputs]
        losses: dict[str, torch.Tensor] = {}

        if self.lazy_decode:
            # Assume that all inputs have the same dataset_source
-            dataset_source = inputs[0]["dataset_source"]
-            decoder = self.decoders[
-                self.target_to_decoder.get(dataset_source, dataset_source)
-            ]
-            self.apply_decoder(
-                features, inputs, targets, decoder, dataset_source, outputs, losses
+            task_name = context.metadatas[0].dataset_source
+
+            if task_name is None:
+                raise ValueError("dataset_source must be set for lazy decoding")
+
+            decoder = self.decoders[self.target_to_decoder.get(task_name, task_name)]
+            model_output = self.apply_decoder(
+                intermediates, context, targets, decoder, task_name
            )
+            for idx, entry in enumerate(model_output.outputs):
+                outputs[idx][task_name] = entry
+            for loss_name, loss_value in model_output.loss_dict.items():
+                losses[f"{task_name}_{loss_name}"] = (
+                    loss_value * self.loss_weights[task_name]
+                )
        else:
            for decoder_name, decoder in self.decoders.items():
                for task_name in self._get_tasks_from_decoder(decoder_name):
-                    self.apply_decoder(
-                        features, inputs, targets, decoder, task_name, outputs, losses
+                    model_output = self.apply_decoder(
+                        intermediates, context, targets, decoder, task_name
                    )
-
-        return {
-            "outputs": outputs,
-            "loss_dict": losses,
-        }
+                    for idx, entry in enumerate(model_output.outputs):
+                        outputs[idx][task_name] = entry
+                    for loss_name, loss_value in model_output.loss_dict.items():
+                        losses[f"{task_name}_{loss_name}"] = (
+                            loss_value * self.loss_weights[task_name]
+                        )
+
+        return ModelOutput(
+            outputs=outputs,
+            loss_dict=losses,
+        )

    def forward(
        self,
-        inputs: list[dict[str, Any]],
+        context: ModelContext,
        targets: list[dict[str, Any]] | None = None,
-    ) -> dict[str, Any]:
+    ) -> ModelOutput:
        """Apply the sequence of modules on the inputs, including shared trunk.

        Args:
-            inputs: list of input dicts
+            context: the model context.
            targets: optional list of target dicts

        Returns:
-            dict with keys "outputs" and "loss_dict".
+            the model output from apply_decoders.
        """
-        features = self.encoder(inputs)
+        cur = self.encoder[0](context)
+        for module in self.encoder[1:]:
+            cur = module(cur, context)
        if self.trunk is not None:
-            trunk_out = self.trunk(features, inputs)
-            outs = self.apply_decoders(trunk_out.pop("outputs"), inputs, targets)
+            trunk_out = self.trunk(cur, context)
+            outs = self.apply_decoders(trunk_out.pop("outputs"), context, targets)
            self.trunk.apply_auxiliary_losses(trunk_out, outs)
            return outs | trunk_out
        else:
-            return self.apply_decoders(features, inputs, targets)
+            return self.apply_decoders(cur, context, targets)


 class MultiTaskMergedModel(MultiTaskModel):
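With the encoder now stored as a torch.nn.ModuleList, forward calls self.encoder[0](context) and then chains every remaining encoder module as module(cur, context), mirroring the non-final decoder modules. A minimal sketch of an intermediate module that fits that chain, assuming IntermediateComponent can be subclassed with a forward of this shape (the class itself is made up; the intermediate type it passes along depends on the upstream FeatureExtractor, e.g. FeatureMaps elsewhere in this release):

from typing import Any

from rslearn.models.component import IntermediateComponent
from rslearn.train.model_context import ModelContext


class ScaleFeatures(IntermediateComponent):
    """Hypothetical component that rescales whatever feature maps it is handed."""

    def __init__(self, factor: float = 0.5):
        super().__init__()
        self.factor = factor

    def forward(self, intermediates: Any, context: ModelContext) -> Any:
        # Invoked as module(cur, context); must return something the next
        # encoder module or the task Predictor understands.
        return [feat * self.factor for feat in intermediates]

Only position 0 of the encoder list has to be a FeatureExtractor; a module like this could sit anywhere after it, or before the final Predictor of a decoder list.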
@@ -247,8 +262,8 @@ class MultiTaskMergedModel(MultiTaskModel):

    def __init__(
        self,
-        encoder: list[torch.nn.Module],
-        decoders: dict[str, list[torch.nn.Module]],
+        encoder: list[FeatureExtractor | IntermediateComponent],
+        decoders: dict[str, list[IntermediateComponent | Predictor]],
        decoder_to_target: dict[str, list[str]],
        task_label_offsets: dict[str, dict[str, Any]],
        lazy_decode: bool = False,
@@ -273,7 +288,7 @@ class MultiTaskMergedModel(MultiTaskModel):
        torch.nn.Module.__init__(self)

        self.lazy_decode = lazy_decode
-        self.encoder = torch.nn.Sequential(*encoder)
+        self.encoder = torch.nn.ModuleList(encoder)
        self.decoders = torch.nn.ModuleDict(
            sort_keys(
                {
@@ -329,9 +344,9 @@ class MultiTaskMergedModel(MultiTaskModel):
        return offset_targets

    def unmerge_output_labels(
-        self, outputs: list[dict[str, torch.Tensor | dict]], task_name: str
-    ) -> None:
-        """Unmerge the task labels in place.
+        self, outputs: Iterable[Any], task_name: str
+    ) -> list[dict[str, torch.Tensor | dict]]:
+        """Unmerge the task outputs.

        For most tasks, this means chopping off the corresponding label dimensions.
        For some, we might just need to subtract an offset from the target (ex: segmentation).
@@ -340,10 +355,15 @@ class MultiTaskMergedModel(MultiTaskModel):
        Args:
            outputs: the predictions
            task_name: the name of the task
+
+        Returns:
+            the unmerged outputs.
        """
        offset = self.task_label_offsets[task_name]["offset"]
        num_outputs = self.task_label_offsets[task_name]["num_outputs"]
        output_key = self.task_label_offsets[task_name]["outputs_key"]
+
+        unmerged_outputs: list[dict[str, torch.Tensor | dict]] = [{} for _ in outputs]
        with torch.no_grad():
            for i, output in enumerate(outputs):
                if not output:
@@ -353,35 +373,44 @@ class MultiTaskMergedModel(MultiTaskModel):
                if isinstance(output, dict):
                    # For some tasks (eg object detection), we have discrete label
                    # predictions instead of a distribution over labels
-                    output[output_key] -= offset
+                    unmerged_output = output.copy()
+                    unmerged_output[output_key] = unmerged_output[output_key] - offset
+                    unmerged_outputs[i][task_name] = unmerged_output
                elif isinstance(output, torch.Tensor):
                    # For classification/segmentation tasks, we have a distribution
                    # over labels, so we need to scale the predictions so that they
                    # sum to 1 since we chop off some of the probability densities
-                    tensor: torch.Tensor = output[offset : offset + num_outputs, ...]
-                    tensor /= tensor.sum(dim=0, keepdim=True).type(torch.float32)
-                    outputs[i][task_name] = tensor
+                    unmerged_output = output[offset : offset + num_outputs, ...]
+                    unmerged_output /= unmerged_output.sum(dim=0, keepdim=True).type(
+                        torch.float32
+                    )
+                    unmerged_outputs[i][task_name] = unmerged_output
+
+        return unmerged_outputs

    def forward(
        self,
-        inputs: list[dict[str, Any]],
+        context: ModelContext,
        targets: list[dict[str, Any]] | None = None,
-    ) -> dict[str, Any]:
+    ) -> ModelOutput:
        """Apply the sequence of modules on the inputs.

        Args:
-            inputs: list of input dicts
+            context: the model context.
            targets: optional list of target dicts

        Returns:
-            dict with keys "outputs" and "loss_dict", and possibly other keys.
+            the model output.
        """
-        dataset_source = inputs[0].get("dataset_source", None)
-        assert isinstance(dataset_source, str)
-        targets = self.merge_task_labels(targets, dataset_source)
-        outs = super().forward(inputs, targets)
-        self.unmerge_output_labels(outs["outputs"], dataset_source)
-        return outs
+        dataset_source = context.metadatas[0].dataset_source
+        assert dataset_source is not None
+        merged_targets = self.merge_task_labels(targets, dataset_source)
+        outs = super().forward(context, merged_targets)
+        unmerged_outputs = self.unmerge_output_labels(outs.outputs, dataset_source)
+        return ModelOutput(
+            outputs=unmerged_outputs,
+            loss_dict=outs.loss_dict,
+        )

    def _get_tasks_from_decoder(self, decoder: str) -> list[str]:
        """Get the tasks corresponding to this decoder.
rslearn/models/olmoearth_pretrain/model.py
@@ -19,6 +19,8 @@ from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
 from upath import UPath

 from rslearn.log_utils import get_logger
+from rslearn.models.component import FeatureExtractor, FeatureMaps, TokenFeatureMaps
+from rslearn.train.model_context import ModelContext

 logger = get_logger(__name__)

@@ -44,7 +46,7 @@ EMBEDDING_SIZES = {
 }


-class OlmoEarth(torch.nn.Module):
+class OlmoEarth(FeatureExtractor):
    """A wrapper to support the OlmoEarth model."""

    def __init__(
@@ -58,6 +60,7 @@ class OlmoEarth(torch.nn.Module):
        random_initialization: bool = False,
        embedding_size: int | None = None,
        autocast_dtype: str | None = "bfloat16",
+        token_pooling: bool = True,
    ):
        """Create a new OlmoEarth model.

@@ -81,6 +84,9 @@ class OlmoEarth(torch.nn.Module):
            embedding_size: optional embedding size to report via
                get_backbone_channels (if model_id is not set).
            autocast_dtype: which dtype to use for autocasting, or set None to disable.
+            token_pooling: whether or not to pool the tokens. If True, the output will be BxCxHxW. If False,
+                there will be an extra dimension, N, (BxCxHxWxN) representing the temporal and channel
+                dimensions.
        """
        if (
            sum(
@@ -131,6 +137,7 @@ class OlmoEarth(torch.nn.Module):
        else:
            model = model[part]
        self.model = model
+        self.token_pooling = token_pooling

    def _load_model_from_checkpoint(
        self, checkpoint_upath: UPath, random_initialization: bool
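The new token_pooling flag controls whether that extra token dimension survives. In the token_pooling=False branch further below, each modality's BHWTSC tokens are flattened with rearrange("b h w t s c -> b c h w (t s)") and modalities are concatenated on the last axis, so N is the number of (timestep, band set) tokens. A quick shape check of that rearrange with toy sizes (not real model dimensions):

import torch
from einops import rearrange

b, h, w, t, s, c = 2, 8, 8, 4, 2, 768  # toy sizes for illustration only
tokens = torch.randn(b, h, w, t, s, c)
flat = rearrange(tokens, "b h w t s c -> b c h w (t s)")
print(flat.shape)  # torch.Size([2, 768, 8, 8, 8]) -- BxCxHxWxN with N = t * s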
@@ -158,45 +165,89 @@ class OlmoEarth(torch.nn.Module):

        return model

-    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
-        """Compute feature maps from the OlmoEarth backbone.
+    def _prepare_modality_inputs(
+        self, context: ModelContext
+    ) -> tuple[MaskedOlmoEarthSample, list[str], torch.device]:
+        """Prepare modality tensors and masks for the OlmoEarth model.
+
+        Uses a two-pass approach to ensure all modalities have consistent timestep
+        dimensions for position encoding.
+
+        Args:
+            context: the model context with input tensors.

-        Inputs:
-            inputs: input dicts. It should include keys corresponding to the modalities
-                that should be passed to the OlmoEarth model.
+        Returns:
+            tuple of (sample, present_modalities, device)
        """
        kwargs = {}
        present_modalities = []
        device = None
-        # Handle the case where some modalities are multitemporal and some are not.
-        # We assume all multitemporal modalities have the same number of timesteps.
+
+        # First pass: find global max_timesteps across all modalities and samples
+        # TODO: currently we assume all modalities have the same number of timesteps,
+        # which is not true for all cases, and time series time steps are assumed to
+        # be 1-month apart. It also assumes continuity between available timesteps.
+        # We'll have to fix all that.
        max_timesteps = 1
+        modality_data = {}
        for modality in MODALITY_NAMES:
-            if modality not in inputs[0]:
+            if modality not in context.inputs[0]:
                continue
            present_modalities.append(modality)
-            cur = torch.stack([inp[modality] for inp in inputs], dim=0)
-            device = cur.device
-            # Check if it's single or multitemporal, and reshape accordingly
+            tensors = [inp[modality] for inp in context.inputs]
+            device = tensors[0].device
            num_bands = Modality.get(modality).num_bands
-            num_timesteps = cur.shape[1] // num_bands
-            max_timesteps = max(max_timesteps, num_timesteps)
-            cur = rearrange(cur, "b (t c) h w -> b h w t c", t=num_timesteps)
+            max_t = max(t.shape[0] for t in tensors) // num_bands
+            max_timesteps = max(max_timesteps, max_t)
+            modality_data[modality] = (
+                tensors,
+                num_bands,
+                len(Modality.get(modality).band_sets),
+            )
+
+        # Second pass: pad and process each modality with global max_timesteps
+        for modality in present_modalities:
+            tensors, num_bands, num_band_sets = modality_data[modality]
+            target_ch = max_timesteps * num_bands
+
+            # Pad tensors to target_ch and track original timesteps for masking
+            padded = []
+            original_timesteps = []
+            for t in tensors:
+                orig_t = t.shape[0] // num_bands
+                original_timesteps.append(orig_t)
+                if t.shape[0] < target_ch:
+                    pad = torch.zeros(
+                        (target_ch - t.shape[0],) + t.shape[1:],
+                        dtype=t.dtype,
+                        device=device,
+                    )
+                    t = torch.cat([t, pad], dim=0)
+                padded.append(t)
+
+            cur = torch.stack(padded, dim=0)
+            cur = rearrange(cur, "b (t c) h w -> b h w t c", t=max_timesteps)
            kwargs[modality] = cur
-            # Create mask array which is BHWTS (without channels but with band sets).
-            num_band_sets = len(Modality.get(modality).band_sets)
-            mask_shape = cur.shape[0:4] + (num_band_sets,)
-            mask = (
-                torch.ones(mask_shape, dtype=torch.int32, device=device)
-                * MaskValue.ONLINE_ENCODER.value
+
+            # Create mask: ONLINE_ENCODER for valid, MISSING for padded timesteps
+            b, h, w = cur.shape[0], cur.shape[1], cur.shape[2]
+            mask = torch.full(
+                (b, h, w, max_timesteps, num_band_sets),
+                fill_value=MaskValue.ONLINE_ENCODER.value,
+                dtype=torch.int32,
+                device=device,
            )
+            for sample_idx, orig_t in enumerate(original_timesteps):
+                if orig_t < max_timesteps:
+                    mask[sample_idx, :, :, orig_t:, :] = MaskValue.MISSING.value
            kwargs[f"{modality}_mask"] = mask

        # Timestamps is required.
        # Note that only months (0 to 11) are used in OlmoEarth position encoding.
-        # For now, we assign same timestamps to all inputs, but later we should handle varying timestamps per input.
+        # For now, we assign same timestamps to all inputs, but later we should
+        # handle varying timestamps per input.
        timestamps = torch.zeros(
-            (len(inputs), max_timesteps, 3), dtype=torch.int32, device=device
+            (len(context.inputs), max_timesteps, 3), dtype=torch.int32, device=device
        )
        timestamps[:, :, 0] = 1  # day
        timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
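The second pass above pads every sample of a modality up to the batch-wide maximum number of timesteps (each sample stored as a (t*c, h, w) tensor) and records each sample's original length so the padded timesteps can later be marked MISSING in the mask. A self-contained sketch of just the padding step, with made-up shapes and without the MaskValue enum:

import torch

num_bands = 2
# Two samples of one modality with 3 and 1 timesteps, stored as (t * c, h, w).
tensors = [torch.randn(3 * num_bands, 4, 4), torch.randn(1 * num_bands, 4, 4)]

max_timesteps = max(t.shape[0] for t in tensors) // num_bands  # 3
target_ch = max_timesteps * num_bands

padded, original_timesteps = [], []
for t in tensors:
    original_timesteps.append(t.shape[0] // num_bands)
    if t.shape[0] < target_ch:
        pad = torch.zeros((target_ch - t.shape[0],) + t.shape[1:], dtype=t.dtype)
        t = torch.cat([t, pad], dim=0)
    padded.append(t)

batch = torch.stack(padded, dim=0)
print(batch.shape, original_timesteps)  # torch.Size([2, 6, 4, 4]) [3, 1]
# The real code then sets mask[i, :, :, original_timesteps[i]:, :] to MISSING.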
@@ -205,25 +256,46 @@ class OlmoEarth(torch.nn.Module):
        timestamps[:, :, 2] = 2024  # year
        kwargs["timestamps"] = timestamps

-        sample = MaskedOlmoEarthSample(**kwargs)
+        return MaskedOlmoEarthSample(**kwargs), present_modalities, device
+
+    def forward(self, context: ModelContext) -> FeatureMaps | TokenFeatureMaps:
+        """Compute feature maps from the OlmoEarth backbone.
+
+        Args:
+            context: the model context. Input dicts should include keys corresponding
+                to the modalities that should be passed to the OlmoEarth model.
+
+        Returns:
+            a FeatureMaps consisting of one feature map, at 1/patch_size of the input
+            resolution. Embeddings will be pooled across modalities and timesteps.
+        """
+        sample, present_modalities, device = self._prepare_modality_inputs(context)

        # Decide context based on self.autocast_dtype.
        if self.autocast_dtype is None:
-            context = nullcontext()
+            torch_context = nullcontext()
        else:
            assert device is not None
-            context = torch.amp.autocast(
+            torch_context = torch.amp.autocast(
                device_type=device.type, dtype=self.autocast_dtype
            )

-        with context:
+        # Check if we can bypass masks (fast_pass=True)
+        missing_tokens = False
+        for modality in present_modalities:
+            modality_mask = getattr(sample, f"{modality}_mask")
+            if torch.any(modality_mask == MaskValue.MISSING.value):
+                missing_tokens = True
+                break
+
+        with torch_context:
            # Currently we assume the provided model always returns a TokensAndMasks object.
            tokens_and_masks: TokensAndMasks
            if isinstance(self.model, Encoder):
                # Encoder has a fast_pass argument to indicate mask is not needed.
                tokens_and_masks = self.model(
                    sample,
-                    fast_pass=True,
+                    fast_pass=not missing_tokens,
                    patch_size=self.patch_size,
                    **self.forward_kwargs,
                )["tokens_and_masks"]
@@ -235,16 +307,41 @@ class OlmoEarth(torch.nn.Module):

        # Apply temporal/modality pooling so we just have one feature per patch.
        features = []
-        for modality in present_modalities:
-            modality_features = getattr(tokens_and_masks, modality)
-            # Pool over band sets and timesteps (BHWTSC -> BHWC).
-            pooled = modality_features.mean(dim=[3, 4])
-            # We want BHWC -> BCHW.
-            pooled = rearrange(pooled, "b h w c -> b c h w")
-            features.append(pooled)
-        # Pool over the modalities, so we get one BCHW feature map.
-        pooled = torch.stack(features, dim=0).mean(dim=0)
-        return [pooled]
+        if self.token_pooling:
+            for modality in present_modalities:
+                modality_features = getattr(tokens_and_masks, modality)  # BHWTSC
+                # If fast_pass is False, we need to mask the missing tokens before pooling.
+                if missing_tokens:
+                    modality_masks = getattr(
+                        tokens_and_masks, f"{modality}_mask"
+                    )  # BHWTS
+                    modality_masks_bool = (
+                        modality_masks != MaskValue.MISSING.value
+                    ).unsqueeze(-1)
+                    count = modality_masks_bool.sum(dim=[3, 4])
+                    # Masked average over band sets and timesteps (BHWTSC -> BHWC).
+                    pooled = (modality_features * modality_masks_bool).sum(
+                        dim=[3, 4]
+                    ) / count.clamp(min=1)
+                else:
+                    # Pool over band sets and timesteps (BHWTSC -> BHWC).
+                    pooled = modality_features.mean(dim=[3, 4])
+                # We want BHWC -> BCHW.
+                pooled = rearrange(pooled, "b h w c -> b c h w")
+                features.append(pooled)
+            # Pool over the modalities, so we get one BCHW feature map.
+            pooled = torch.stack(features, dim=0).mean(dim=0)
+            return FeatureMaps([pooled])
+        else:
+            for modality in present_modalities:
+                modality_features = getattr(tokens_and_masks, modality)
+                # Combine band sets and timesteps into last dim (BHWTSC -> BHWCN).
+                modality_features = rearrange(
+                    modality_features, "b h w t s c -> b c h w (t s)"
+                )
+                features.append(modality_features)
+            pooled = torch.cat(features, dim=-1)
+            return TokenFeatureMaps([pooled])

    def get_backbone_channels(self) -> list:
        """Returns the output channels of this model when used as a backbone.
rslearn/models/panopticon.py
@@ -3,15 +3,16 @@
 import math
 from enum import StrEnum
 from importlib import resources
-from typing import Any

 import torch
 import torch.nn.functional as F
 import yaml
 from einops import rearrange, repeat
-from torch import nn

 from rslearn.log_utils import get_logger
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps

 logger = get_logger(__name__)

@@ -28,7 +29,7 @@ class PanopticonModalities(StrEnum):
    # Add more modalities as needed


-class Panopticon(nn.Module):
+class Panopticon(FeatureExtractor):
    """Class containing the Panopticon model that can ingest MaskedHeliosSample objects."""

    patch_size: int = 14
@@ -138,11 +139,11 @@ class Panopticon(nn.Module):
            "chn_ids": chn_ids,
        }

-    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+    def forward(self, context: ModelContext) -> FeatureMaps:
        """Forward pass through the panopticon model."""
        batch_inputs = {
-            key: torch.stack([inp[key] for inp in inputs], dim=0)
-            for key in inputs[0].keys()
+            key: torch.stack([inp[key] for inp in context.inputs], dim=0)
+            for key in context.inputs[0].keys()
        }
        panopticon_inputs = self.prepare_input(batch_inputs)
        output_features = self.model.forward_features(panopticon_inputs)[
@@ -154,7 +155,7 @@ class Panopticon(nn.Module):
        output_features = rearrange(
            output_features, "b (h w) d -> b d h w", h=height, w=height
        )
-        return [output_features]
+        return FeatureMaps([output_features])

    def get_backbone_channels(self) -> list:
        """Returns the output channels of this model when used as a backbone.