rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. rslearn/arg_parser.py +2 -9
  2. rslearn/config/__init__.py +2 -0
  3. rslearn/config/dataset.py +64 -20
  4. rslearn/dataset/add_windows.py +1 -1
  5. rslearn/dataset/dataset.py +34 -84
  6. rslearn/dataset/materialize.py +5 -5
  7. rslearn/dataset/storage/__init__.py +1 -0
  8. rslearn/dataset/storage/file.py +202 -0
  9. rslearn/dataset/storage/storage.py +140 -0
  10. rslearn/dataset/window.py +26 -80
  11. rslearn/lightning_cli.py +22 -11
  12. rslearn/main.py +12 -37
  13. rslearn/models/anysat.py +11 -9
  14. rslearn/models/attention_pooling.py +177 -0
  15. rslearn/models/clay/clay.py +8 -9
  16. rslearn/models/clip.py +18 -15
  17. rslearn/models/component.py +111 -0
  18. rslearn/models/concatenate_features.py +21 -11
  19. rslearn/models/conv.py +15 -8
  20. rslearn/models/croma.py +13 -8
  21. rslearn/models/detr/detr.py +25 -14
  22. rslearn/models/dinov3.py +11 -6
  23. rslearn/models/faster_rcnn.py +19 -9
  24. rslearn/models/feature_center_crop.py +12 -9
  25. rslearn/models/fpn.py +19 -8
  26. rslearn/models/galileo/galileo.py +23 -18
  27. rslearn/models/module_wrapper.py +26 -57
  28. rslearn/models/molmo.py +16 -14
  29. rslearn/models/multitask.py +102 -73
  30. rslearn/models/olmoearth_pretrain/model.py +135 -38
  31. rslearn/models/panopticon.py +8 -7
  32. rslearn/models/pick_features.py +18 -24
  33. rslearn/models/pooling_decoder.py +22 -14
  34. rslearn/models/presto/presto.py +16 -10
  35. rslearn/models/presto/single_file_presto.py +4 -10
  36. rslearn/models/prithvi.py +12 -8
  37. rslearn/models/resize_features.py +21 -7
  38. rslearn/models/sam2_enc.py +11 -9
  39. rslearn/models/satlaspretrain.py +15 -9
  40. rslearn/models/simple_time_series.py +37 -17
  41. rslearn/models/singletask.py +24 -17
  42. rslearn/models/ssl4eo_s12.py +15 -10
  43. rslearn/models/swin.py +22 -13
  44. rslearn/models/terramind.py +24 -7
  45. rslearn/models/trunk.py +6 -3
  46. rslearn/models/unet.py +18 -9
  47. rslearn/models/upsample.py +22 -9
  48. rslearn/train/all_patches_dataset.py +89 -37
  49. rslearn/train/dataset.py +105 -97
  50. rslearn/train/lightning_module.py +51 -32
  51. rslearn/train/model_context.py +54 -0
  52. rslearn/train/prediction_writer.py +111 -41
  53. rslearn/train/scheduler.py +15 -0
  54. rslearn/train/tasks/classification.py +34 -15
  55. rslearn/train/tasks/detection.py +24 -31
  56. rslearn/train/tasks/embedding.py +33 -29
  57. rslearn/train/tasks/multi_task.py +7 -7
  58. rslearn/train/tasks/per_pixel_regression.py +41 -19
  59. rslearn/train/tasks/regression.py +38 -21
  60. rslearn/train/tasks/segmentation.py +33 -15
  61. rslearn/train/tasks/task.py +3 -2
  62. rslearn/train/transforms/resize.py +74 -0
  63. rslearn/utils/geometry.py +73 -0
  64. rslearn/utils/jsonargparse.py +66 -0
  65. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
  66. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
  67. rslearn/dataset/index.py +0 -173
  68. rslearn/models/registry.py +0 -22
  69. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
  70. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
  71. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
  72. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
  73. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
rslearn/train/dataset.py CHANGED
@@ -20,13 +20,15 @@ from rslearn.config import (
     LayerConfig,
 )
 from rslearn.dataset.dataset import Dataset
+from rslearn.dataset.storage.file import FileWindowStorage
 from rslearn.dataset.window import Window, get_layer_and_group_from_dir_name
 from rslearn.log_utils import get_logger
-from rslearn.train.tasks import Task
 from rslearn.utils.feature import Feature
-from rslearn.utils.geometry import PixelBounds
+from rslearn.utils.geometry import PixelBounds, ResolutionFactor
 from rslearn.utils.mp import star_imap_unordered

+from .model_context import SampleMetadata
+from .tasks import Task
 from .transforms import Sequential

 logger = get_logger(__name__)
@@ -128,6 +130,10 @@ class DataInput:
     """Specification of a piece of data from a window that is needed for training.

     The DataInput includes which layer(s) the data can be obtained from for each window.
+
+    Note that this class is not a dataclass because jsonargparse does not play well
+    with dataclasses without enabling specialized options which we have not validated
+    will work with the rest of our code.
     """

     def __init__(
@@ -141,7 +147,9 @@
         dtype: DType = DType.FLOAT32,
         load_all_layers: bool = False,
         load_all_item_groups: bool = False,
-    ) -> None:
+        resolution_factor: ResolutionFactor = ResolutionFactor(),
+        resampling: Resampling = Resampling.nearest,
+    ):
         """Initialize a new DataInput.

         Args:
@@ -164,6 +172,11 @@
             are reading from. By default, we assume the specified layer name is of
             the form "{layer_name}.{group_idx}" and read that item group only. With
             this option enabled, we ignore the group_idx and read all item groups.
+            resolution_factor: controls the resolution at which raster data is loaded for training.
+                By default (factor=1), data is loaded at the window resolution.
+                E.g. for a 64x64 window at 10 m/pixel with resolution_factor=1/2,
+                the resulting tensor is 32x32 (covering the same geographic area at 20 m/pixel).
+            resampling: resampling method (default nearest neighbor).
         """
         self.data_type = data_type
         self.layers = layers
@@ -174,6 +187,8 @@
         self.dtype = dtype
         self.load_all_layers = load_all_layers
         self.load_all_item_groups = load_all_item_groups
+        self.resolution_factor = resolution_factor
+        self.resampling = resampling


 def read_raster_layer_for_data_input(
@@ -231,15 +246,23 @@ def read_raster_layer_for_data_input(
             + f"window {window.name} layer {layer_name} group {group_idx}"
         )

+    # Get the projection and bounds to read under (multiply window resolution by
+    # the specified resolution factor).
+    final_projection = data_input.resolution_factor.multiply_projection(
+        window.projection
+    )
+    final_bounds = data_input.resolution_factor.multiply_bounds(bounds)
+
     image = torch.zeros(
-        (len(needed_bands), bounds[3] - bounds[1], bounds[2] - bounds[0]),
+        (
+            len(needed_bands),
+            final_bounds[3] - final_bounds[1],
+            final_bounds[2] - final_bounds[0],
+        ),
         dtype=get_torch_dtype(data_input.dtype),
     )

     for band_set, src_indexes, dst_indexes in needed_sets_and_indexes:
-        final_projection, final_bounds = band_set.get_final_projection_and_bounds(
-            window.projection, bounds
-        )
         if band_set.format is None:
             raise ValueError(f"No format specified for {layer_name}")
         raster_format = band_set.instantiate_raster_format()
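Note: the new resolution_factor plumbing above relies on ResolutionFactor.multiply_projection and ResolutionFactor.multiply_bounds from rslearn.utils.geometry (added in this release; see rslearn/utils/geometry.py in the file list). A minimal sketch of the bounds arithmetic these calls imply, matching the docstring example in DataInput; the Fraction-based stand-in below is illustrative, not the rslearn implementation:

```python
# Assumed semantics: factor 1/2 halves the pixel dimensions while doubling the
# ground resolution, so a 64x64 window at 10 m/pixel reads as 32x32 at 20 m/pixel.
from fractions import Fraction

factor = Fraction(1, 2)
window_bounds = (0, 0, 64, 64)  # (x1, y1, x2, y2) in pixels at 10 m/pixel

final_bounds = tuple(int(v * factor) for v in window_bounds)
width = final_bounds[2] - final_bounds[0]
height = final_bounds[3] - final_bounds[1]
print(width, height)  # 32 32
```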
@@ -247,44 +270,16 @@
             layer_name, band_set.bands, group_idx=group_idx
         )

-        # Previously we always read in the native projection of the data, and then
-        # zoom in or out (the resolution must be a power of two off) to match the
-        # window's resolution.
-        # However, this fails if the bounds are not multiples of the resolution factor.
-        # So we fallback to reading directly in the window projection if that is the
-        # case (which may be a bit slower).
-        is_bounds_zoomable = True
-        if band_set.zoom_offset < 0:
-            zoom_factor = 2 ** (-band_set.zoom_offset)
-            is_bounds_zoomable = (final_bounds[2] - final_bounds[0]) * zoom_factor == (
-                bounds[2] - bounds[0]
-            ) and (final_bounds[3] - final_bounds[1]) * zoom_factor == (
-                bounds[3] - bounds[1]
-            )
-
-        if is_bounds_zoomable:
-            src = raster_format.decode_raster(
-                raster_dir, final_projection, final_bounds
-            )
-
-            # Resize to patch size if needed.
-            # This is for band sets that are stored at a lower resolution.
-            # Here we assume that it is a multiple.
-            if src.shape[1:3] != image.shape[1:3]:
-                if src.shape[1] < image.shape[1]:
-                    factor = image.shape[1] // src.shape[1]
-                    src = src.repeat(repeats=factor, axis=1).repeat(
-                        repeats=factor, axis=2
-                    )
-                else:
-                    factor = src.shape[1] // image.shape[1]
-                    src = src[:, ::factor, ::factor]
-
-        else:
-            src = raster_format.decode_raster(
-                raster_dir, window.projection, bounds, resampling=Resampling.nearest
-            )
+        # TODO: previously we try to read based on band_set.zoom_offset when possible,
+        # and handle zooming in with torch.repeat (if resampling method is nearest
+        # neighbor). However, we have not benchmarked whether this actually improves
+        # data loading speed, so for simplicity, for now we let rasterio handle the
+        # resampling. If it really is much faster to handle it via torch, then it may
+        # make sense to bring back that functionality.

+        src = raster_format.decode_raster(
+            raster_dir, final_projection, final_bounds, resampling=Resampling.nearest
+        )
         image[dst_indexes, :, :] = torch.as_tensor(
             src[src_indexes, :, :].astype(data_input.dtype.get_numpy_dtype())
         )
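Note: for reference, the nearest-neighbor zoom that the removed is_bounds_zoomable branch performed can be reproduced in a few lines; this sketch mirrors the deleted repeat-based logic (each source pixel becomes a factor x factor block):

```python
import numpy as np

# Nearest-neighbor upsampling by an integer factor, as the removed branch did:
# repeat along the height axis, then along the width axis.
src = np.arange(4, dtype=np.float32).reshape(1, 2, 2)  # (bands, H, W)
factor = 2
zoomed = src.repeat(factor, axis=1).repeat(factor, axis=2)
assert zoomed.shape == (1, 4, 4)
print(zoomed[0])
```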
@@ -575,37 +570,7 @@ class ModelDataset(torch.utils.data.Dataset):
         else:
             self.patch_size = split_config.get_patch_size()

-        if split_config.names:
-            windows = self.dataset.load_windows(
-                groups=split_config.groups,
-                names=split_config.names,
-                show_progress=True,
-                workers=workers,
-            )
-        elif split_config.groups:
-            windows = self.dataset.load_windows(
-                groups=split_config.groups, show_progress=True, workers=workers
-            )
-        else:
-            windows = self.dataset.load_windows(show_progress=True, workers=workers)
-
-        if split_config.tags:
-            # Filter the window.options.
-            new_windows = []
-            num_removed: dict[str, int] = {}
-            for window in windows:
-                for k, v in split_config.tags.items():
-                    if k not in window.options or (v and window.options[k] != v):
-                        num_removed[k] = num_removed.get(k, 0) + 1
-                        break
-                else:
-                    new_windows.append(window)
-            logger.info(
-                f"Started with {len(windows)} windows, ended with {len(new_windows)} windows for {self.dataset.path}"
-            )
-            for k, v in num_removed.items():
-                logger.info(f"Removed {v} windows due to tag {k}")
-            windows = new_windows
+        windows = self._get_initial_windows(split_config, workers)

         # If targets are not needed, remove them from the inputs.
         if split_config.get_skip_targets():
@@ -615,17 +580,11 @@

         # Eliminate windows that are missing either a requisite input layer, or missing
         # all target layers.
-        # We use only main thread if the index is set, since that can take a long time
-        # to send to the worker threads, it may get serialized for each window.
         new_windows = []
-        if workers == 0 or (len(windows) >= 1 and windows[0].index is not None):
+        if workers == 0:
            for window in windows:
                if check_window(self.inputs, window) is None:
                    continue
-                # The index may be set, but now that this check is done, from here on
-                # we no longer need it. We set it None so that we don't end up passing
-                # it later to the dataloader workers.
-                window.index = None
                new_windows.append(window)
        else:
            p = multiprocessing.Pool(workers)
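Note: the branch above keeps filtering on the main process when workers == 0 and otherwise farms the check out to a pool (rslearn uses its star_imap_unordered helper from rslearn.utils.mp for this). The general shape of the pooled path, sketched with the stdlib only and a toy predicate standing in for check_window:

```python
import multiprocessing

def check(x: int) -> int | None:
    # Toy stand-in for check_window: return the item if it passes, else None.
    return x if x % 2 == 0 else None

if __name__ == "__main__":
    with multiprocessing.Pool(4) as p:
        kept = [x for x in p.imap_unordered(check, range(10)) if x is not None]
    print(sorted(kept))  # [0, 2, 4, 6, 8]
```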
@@ -681,12 +640,62 @@
             with open(self.dataset_examples_fname, "w") as f:
                 json.dump([self._serialize_item(example) for example in windows], f)

+    def _get_initial_windows(
+        self, split_config: SplitConfig, workers: int
+    ) -> list[Window]:
+        """Get the initial windows before input layer filtering.
+
+        The windows are filtered based on configured window names, groups, and tags.
+
+        This is a helper for the init function.
+
+        Args:
+            split_config: the split configuration.
+            workers: number of worker processes.
+
+        Returns:
+            list of windows from the dataset after applying the aforementioned filters.
+        """
+        # Load windows from dataset.
+        # If the window storage is FileWindowStorage, we pass the workers/show_progress arguments.
+        kwargs: dict[str, Any] = {}
+        if isinstance(self.dataset.storage, FileWindowStorage):
+            kwargs["workers"] = workers
+            kwargs["show_progress"] = True
+        # We also add the name/group filters to the kwargs.
+        if split_config.names:
+            kwargs["names"] = split_config.names
+        if split_config.groups:
+            kwargs["groups"] = split_config.groups
+
+        windows = self.dataset.load_windows(**kwargs)
+
+        # Filter by tags (if provided) using the window.options.
+        if split_config.tags:
+            new_windows = []
+            num_removed: dict[str, int] = {}
+            for window in windows:
+                for k, v in split_config.tags.items():
+                    if k not in window.options or (v and window.options[k] != v):
+                        num_removed[k] = num_removed.get(k, 0) + 1
+                        break
+                else:
+                    new_windows.append(window)
+            logger.info(
+                f"Started with {len(windows)} windows, ended with {len(new_windows)} windows for {self.dataset.path}"
+            )
+            for k, v in num_removed.items():
+                logger.info(f"Removed {v} windows due to tag {k}")
+            windows = new_windows
+
+        return windows
+
     def _serialize_item(self, example: Window) -> dict[str, Any]:
         return example.get_metadata()

     def _deserialize_item(self, d: dict[str, Any]) -> Window:
         return Window.from_metadata(
-            Window.get_window_root(self.dataset.path, d["group"], d["name"]),
+            self.dataset.storage,
             d,
         )
 
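Note: the tag filter in _get_initial_windows leans on Python's for/else: the else body runs only when the inner loop finishes without break, i.e. when every tag matched. A self-contained illustration of that control flow, using the same condition as the diff:

```python
options = {"split": "train", "region": "eu"}
tags = {"split": "train", "region": ""}  # empty value means "key must exist"

for k, v in tags.items():
    if k not in options or (v and options[k] != v):
        print(f"rejected on tag {k}")
        break
else:
    # Runs only if no break occurred, i.e. the window passes all tag checks.
    print("window kept")
```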
@@ -713,7 +722,7 @@

     def get_raw_inputs(
         self, idx: int
-    ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
+    ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
         """Get the raw inputs and base metadata for this example.

         This is the raster or vector data before being processed by the Task. So it
@@ -775,21 +784,23 @@
         if data_input.passthrough:
             passthrough_inputs[name] = raw_inputs[name]

-        metadata = {
-            "group": window.group,
-            "window_name": window.name,
-            "window_bounds": window.bounds,
-            "bounds": bounds,
-            "time_range": window.time_range,
-            "projection": window.projection,
-            "dataset_source": self.name,
-        }
+        metadata = SampleMetadata(
+            window_group=window.group,
+            window_name=window.name,
+            window_bounds=window.bounds,
+            patch_bounds=bounds,
+            patch_idx=0,
+            num_patches_in_window=1,
+            time_range=window.time_range,
+            projection=window.projection,
+            dataset_source=self.name,
+        )

         return raw_inputs, passthrough_inputs, metadata

     def __getitem__(
         self, idx: int
-    ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
+    ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
         """Read one training example.

         Args:
@@ -801,8 +812,6 @@
         logger.debug("__getitem__ start pid=%d item_idx=%d", os.getpid(), idx)

         raw_inputs, passthrough_inputs, metadata = self.get_raw_inputs(idx)
-        metadata["patch_idx"] = 0
-        metadata["num_patches"] = 1

         input_dict, target_dict = self.task.process_inputs(
             raw_inputs,
@@ -811,7 +820,6 @@
         )
         input_dict.update(passthrough_inputs)
         input_dict, target_dict = self.transforms(input_dict, target_dict)
-        input_dict["dataset_source"] = self.name

         logger.debug("__getitem__ finish pid=%d item_idx=%d", os.getpid(), idx)
 
rslearn/train/lightning_module.py CHANGED
@@ -12,6 +12,7 @@ from upath import UPath

 from rslearn.log_utils import get_logger

+from .model_context import ModelContext, ModelOutput
 from .optimizer import AdamW, OptimizerFactory
 from .scheduler import PlateauScheduler, SchedulerFactory
 from .tasks import Task
@@ -231,12 +232,16 @@ class RslearnLightningModule(L.LightningModule):
         Returns:
             The loss tensor.
         """
-        inputs, targets, _ = batch
+        inputs, targets, metadatas = batch
+        context = ModelContext(
+            inputs=inputs,
+            metadatas=metadatas,
+        )
         batch_size = len(inputs)
-        model_outputs = self(inputs, targets)
-        self.on_train_forward(inputs, targets, model_outputs)
+        model_outputs = self(context, targets)
+        self.on_train_forward(context, targets, model_outputs)

-        loss_dict = model_outputs["loss_dict"]
+        loss_dict = model_outputs.loss_dict
         train_loss = sum(loss_dict.values())
         self.log_dict(
             {"train_" + k: v for k, v in loss_dict.items()},
@@ -266,13 +271,17 @@
             batch_idx: Integer displaying index of this batch.
             dataloader_idx: Index of the current dataloader.
         """
-        inputs, targets, _ = batch
+        inputs, targets, metadatas = batch
+        context = ModelContext(
+            inputs=inputs,
+            metadatas=metadatas,
+        )
         batch_size = len(inputs)
-        model_outputs = self(inputs, targets)
-        self.on_val_forward(inputs, targets, model_outputs)
+        model_outputs = self(context, targets)
+        self.on_val_forward(context, targets, model_outputs)

-        loss_dict = model_outputs["loss_dict"]
-        outputs = model_outputs["outputs"]
+        loss_dict = model_outputs.loss_dict
+        outputs = model_outputs.outputs
         val_loss = sum(loss_dict.values())
         self.log_dict(
             {"val_" + k: v for k, v in loss_dict.items()},
@@ -304,12 +313,16 @@
             dataloader_idx: Index of the current dataloader.
         """
         inputs, targets, metadatas = batch
+        context = ModelContext(
+            inputs=inputs,
+            metadatas=metadatas,
+        )
         batch_size = len(inputs)
-        model_outputs = self(inputs, targets)
-        self.on_test_forward(inputs, targets, model_outputs)
+        model_outputs = self(context, targets)
+        self.on_test_forward(context, targets, model_outputs)

-        loss_dict = model_outputs["loss_dict"]
-        outputs = model_outputs["outputs"]
+        loss_dict = model_outputs.loss_dict
+        outputs = model_outputs.outputs
         test_loss = sum(loss_dict.values())
         self.log_dict(
             {"test_" + k: v for k, v in loss_dict.items()},
@@ -345,7 +358,7 @@

     def predict_step(
         self, batch: Any, batch_idx: int, dataloader_idx: int = 0
-    ) -> torch.Tensor:
+    ) -> ModelOutput:
         """Compute the predicted class probabilities.

         Args:
@@ -356,63 +369,69 @@
         Returns:
             Output predicted probabilities.
         """
-        inputs, _, _ = batch
-        model_outputs = self(inputs)
+        inputs, _, metadatas = batch
+        context = ModelContext(
+            inputs=inputs,
+            metadatas=metadatas,
+        )
+        model_outputs = self(context)
         return model_outputs

-    def forward(self, *args: Any, **kwargs: Any) -> Any:
+    def forward(
+        self, context: ModelContext, targets: list[dict[str, Any]] | None = None
+    ) -> ModelOutput:
         """Forward pass of the model.

         Args:
-            args: Arguments to pass to model.
-            kwargs: Keyword arguments to pass to model.
+            context: the model context.
+            targets: the target dicts.

         Returns:
             Output of the model.
         """
-        return self.model(*args, **kwargs)
+        return self.model(context, targets)

     def on_train_forward(
         self,
-        inputs: list[dict[str, Any]],
+        context: ModelContext,
         targets: list[dict[str, Any]],
-        model_outputs: dict[str, Any],
+        model_outputs: ModelOutput,
     ) -> None:
         """Hook to run after the forward pass of the model during training.

         Args:
-            inputs: The input batch.
+            context: The model context.
             targets: The target batch.
-            model_outputs: The output of the model, with keys "outputs" and "loss_dict", and possibly other keys.
+            model_outputs: The output of the model.
         """
         pass

     def on_val_forward(
         self,
-        inputs: list[dict[str, Any]],
+        context: ModelContext,
         targets: list[dict[str, Any]],
-        model_outputs: dict[str, Any],
+        model_outputs: ModelOutput,
    ) -> None:
         """Hook to run after the forward pass of the model during validation.

         Args:
-            inputs: The input batch.
+            context: The model context.
             targets: The target batch.
-            model_outputs: The output of the model, with keys "outputs" and "loss_dict", and possibly other keys.
+            model_outputs: The output of the model.
         """
         pass

     def on_test_forward(
         self,
-        inputs: list[dict[str, Any]],
+        context: ModelContext,
         targets: list[dict[str, Any]],
-        model_outputs: dict[str, Any],
+        model_outputs: ModelOutput,
     ) -> None:
         """Hook to run after the forward pass of the model during testing.

         Args:
-            inputs: The input batch.
+            context: The model context.
             targets: The target batch.
-            model_outputs: The output of the model, with keys "outputs" and "loss_dict", and possibly other keys.
+            model_outputs: The output of the model.
         """
         pass
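Note: with these changes, RslearnLightningModule.forward calls self.model(context, targets) and expects a ModelOutput back. A minimal sketch of a model satisfying the new contract; the layer sizes and the "image"/"mask" keys are illustrative, not part of rslearn:

```python
from typing import Any

import torch

from rslearn.train.model_context import ModelContext, ModelOutput

class ToyModel(torch.nn.Module):
    """Illustrative model following the context-based forward contract."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 1, kernel_size=1)

    def forward(
        self, context: ModelContext, targets: list[dict[str, Any]] | None = None
    ) -> ModelOutput:
        # Stack the per-example "image" tensors into a batch (hypothetical key).
        images = torch.stack([inp["image"] for inp in context.inputs])
        preds = self.conv(images)
        loss_dict = {}
        if targets is not None:
            gt = torch.stack([t["mask"] for t in targets])
            loss_dict["mse"] = torch.nn.functional.mse_loss(preds, gt)
        return ModelOutput(outputs=list(preds), loss_dict=loss_dict)
```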
rslearn/train/model_context.py ADDED
@@ -0,0 +1,54 @@
+"""Data classes to provide various context to models."""
+
+from collections.abc import Iterable
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Any
+
+import torch
+
+from rslearn.utils.geometry import PixelBounds, Projection
+
+
+@dataclass
+class SampleMetadata:
+    """Metadata pertaining to an example."""
+
+    window_group: str
+    window_name: str
+    window_bounds: PixelBounds
+    patch_bounds: PixelBounds
+    patch_idx: int
+    num_patches_in_window: int
+    time_range: tuple[datetime, datetime] | None
+    projection: Projection
+
+    # Task name to differentiate different tasks.
+    dataset_source: str | None
+
+
+@dataclass
+class ModelContext:
+    """Context to pass to all model components."""
+
+    # One input dict per example in the batch.
+    inputs: list[dict[str, torch.Tensor]]
+    # One SampleMetadata per example in the batch.
+    metadatas: list[SampleMetadata]
+    # Arbitrary dict that components can add to.
+    context_dict: dict[str, Any] = field(default_factory=lambda: {})
+
+
+@dataclass
+class ModelOutput:
+    """The output from the Predictor.
+
+    Args:
+        outputs: output compatible with the configured Task.
+        loss_dict: map from loss names to scalar tensors.
+        metadata: arbitrary dict that can be used to store other outputs.
+    """
+
+    outputs: Iterable[Any]
+    loss_dict: dict[str, torch.Tensor]
+    metadata: dict[str, Any] = field(default_factory=lambda: {})
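Note: how the pieces fit together, per the dataset and lightning-module changes above: ModelDataset.get_raw_inputs builds one SampleMetadata per example, the Lightning module wraps the collated batch in a ModelContext, and the model returns a ModelOutput. A small construction sketch; the "image" key, tensor shape, and None projection are placeholders, not rslearn defaults:

```python
import torch

from rslearn.train.model_context import ModelContext, SampleMetadata

# One example's metadata, mirroring what get_raw_inputs now constructs.
meta = SampleMetadata(
    window_group="default",
    window_name="window_0",
    window_bounds=(0, 0, 64, 64),
    patch_bounds=(0, 0, 64, 64),
    patch_idx=0,
    num_patches_in_window=1,
    time_range=None,
    projection=None,  # placeholder; a real rslearn Projection in practice
    dataset_source=None,
)

context = ModelContext(
    inputs=[{"image": torch.zeros(3, 64, 64)}],  # hypothetical input key
    metadatas=[meta],
)
context.context_dict["notes"] = "components may stash shared state here"
print(context.metadatas[0].window_name)
```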