rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. rslearn/arg_parser.py +2 -9
  2. rslearn/config/__init__.py +2 -0
  3. rslearn/config/dataset.py +64 -20
  4. rslearn/dataset/add_windows.py +1 -1
  5. rslearn/dataset/dataset.py +34 -84
  6. rslearn/dataset/materialize.py +5 -5
  7. rslearn/dataset/storage/__init__.py +1 -0
  8. rslearn/dataset/storage/file.py +202 -0
  9. rslearn/dataset/storage/storage.py +140 -0
  10. rslearn/dataset/window.py +26 -80
  11. rslearn/lightning_cli.py +22 -11
  12. rslearn/main.py +12 -37
  13. rslearn/models/anysat.py +11 -9
  14. rslearn/models/attention_pooling.py +177 -0
  15. rslearn/models/clay/clay.py +8 -9
  16. rslearn/models/clip.py +18 -15
  17. rslearn/models/component.py +111 -0
  18. rslearn/models/concatenate_features.py +21 -11
  19. rslearn/models/conv.py +15 -8
  20. rslearn/models/croma.py +13 -8
  21. rslearn/models/detr/detr.py +25 -14
  22. rslearn/models/dinov3.py +11 -6
  23. rslearn/models/faster_rcnn.py +19 -9
  24. rslearn/models/feature_center_crop.py +12 -9
  25. rslearn/models/fpn.py +19 -8
  26. rslearn/models/galileo/galileo.py +23 -18
  27. rslearn/models/module_wrapper.py +26 -57
  28. rslearn/models/molmo.py +16 -14
  29. rslearn/models/multitask.py +102 -73
  30. rslearn/models/olmoearth_pretrain/model.py +135 -38
  31. rslearn/models/panopticon.py +8 -7
  32. rslearn/models/pick_features.py +18 -24
  33. rslearn/models/pooling_decoder.py +22 -14
  34. rslearn/models/presto/presto.py +16 -10
  35. rslearn/models/presto/single_file_presto.py +4 -10
  36. rslearn/models/prithvi.py +12 -8
  37. rslearn/models/resize_features.py +21 -7
  38. rslearn/models/sam2_enc.py +11 -9
  39. rslearn/models/satlaspretrain.py +15 -9
  40. rslearn/models/simple_time_series.py +37 -17
  41. rslearn/models/singletask.py +24 -17
  42. rslearn/models/ssl4eo_s12.py +15 -10
  43. rslearn/models/swin.py +22 -13
  44. rslearn/models/terramind.py +24 -7
  45. rslearn/models/trunk.py +6 -3
  46. rslearn/models/unet.py +18 -9
  47. rslearn/models/upsample.py +22 -9
  48. rslearn/train/all_patches_dataset.py +89 -37
  49. rslearn/train/dataset.py +105 -97
  50. rslearn/train/lightning_module.py +51 -32
  51. rslearn/train/model_context.py +54 -0
  52. rslearn/train/prediction_writer.py +111 -41
  53. rslearn/train/scheduler.py +15 -0
  54. rslearn/train/tasks/classification.py +34 -15
  55. rslearn/train/tasks/detection.py +24 -31
  56. rslearn/train/tasks/embedding.py +33 -29
  57. rslearn/train/tasks/multi_task.py +7 -7
  58. rslearn/train/tasks/per_pixel_regression.py +41 -19
  59. rslearn/train/tasks/regression.py +38 -21
  60. rslearn/train/tasks/segmentation.py +33 -15
  61. rslearn/train/tasks/task.py +3 -2
  62. rslearn/train/transforms/resize.py +74 -0
  63. rslearn/utils/geometry.py +73 -0
  64. rslearn/utils/jsonargparse.py +66 -0
  65. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
  66. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
  67. rslearn/dataset/index.py +0 -173
  68. rslearn/models/registry.py +0 -22
  69. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
  70. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
  71. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
  72. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
  73. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
rslearn/train/prediction_writer.py

@@ -1,6 +1,7 @@
 """rslearn PredictionWriter implementation."""
 
-from collections.abc import Sequence
+import json
+from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any
@@ -12,11 +13,15 @@ from lightning.pytorch.callbacks import BasePredictionWriter
 from upath import UPath
 
 from rslearn.config import (
+    DatasetConfig,
     LayerConfig,
     LayerType,
+    StorageConfig,
 )
-from rslearn.dataset import Dataset, Window
+from rslearn.dataset import Window
+from rslearn.dataset.storage.storage import WindowStorage
 from rslearn.log_utils import get_logger
+from rslearn.train.model_context import SampleMetadata
 from rslearn.utils.array import copy_spatial_array
 from rslearn.utils.feature import Feature
 from rslearn.utils.geometry import PixelBounds
@@ -27,6 +32,7 @@ from rslearn.utils.raster_format import (
 from rslearn.utils.vector_format import VectorFormat
 
 from .lightning_module import RslearnLightningModule
+from .model_context import ModelOutput
 from .tasks.task import Task
 
 logger = get_logger(__name__)
@@ -43,12 +49,18 @@ class PendingPatchOutput:
 class PatchPredictionMerger:
     """Base class for merging predictions from multiple patches."""
 
-    def merge(self, window: Window, outputs: Sequence[PendingPatchOutput]) -> Any:
+    def merge(
+        self,
+        window: Window,
+        outputs: Sequence[PendingPatchOutput],
+        layer_config: LayerConfig,
+    ) -> Any:
         """Merge the outputs.
 
         Args:
             window: the window we are merging the outputs for.
             outputs: the outputs to process.
+            layer_config: the output layer configuration.
 
         Returns:
             the merged outputs.
@@ -60,7 +72,10 @@ class VectorMerger(PatchPredictionMerger):
     """Merger for vector data that simply concatenates the features."""
 
     def merge(
-        self, window: Window, outputs: Sequence[PendingPatchOutput]
+        self,
+        window: Window,
+        outputs: Sequence[PendingPatchOutput],
+        layer_config: LayerConfig,
     ) -> list[Feature]:
         """Concatenate the vector features."""
         return [feat for output in outputs for feat in output.output]
@@ -83,18 +98,20 @@ class RasterMerger(PatchPredictionMerger):
         self.downsample_factor = downsample_factor
 
     def merge(
-        self, window: Window, outputs: Sequence[PendingPatchOutput]
+        self,
+        window: Window,
+        outputs: Sequence[PendingPatchOutput],
+        layer_config: LayerConfig,
     ) -> npt.NDArray:
         """Merge the raster outputs."""
         num_channels = outputs[0].output.shape[0]
-        dtype = outputs[0].output.dtype
         merged_image = np.zeros(
             (
                 num_channels,
                 (window.bounds[3] - window.bounds[1]) // self.downsample_factor,
                 (window.bounds[2] - window.bounds[0]) // self.downsample_factor,
             ),
-            dtype=dtype,
+            dtype=layer_config.band_sets[0].dtype.get_numpy_dtype(),
        )
 
         # Ensure the outputs are sorted by height then width.
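
Mergers now receive the output layer's LayerConfig, so for rasters the merged dtype comes from the configured band set rather than from the patch tensors. A minimal sketch of the widened merge() call; VectorMerger ignores the window and layer_config arguments, and the SimpleNamespace stand-ins for PendingPatchOutput are illustrative only:

    from types import SimpleNamespace

    from rslearn.train.prediction_writer import VectorMerger

    # Two fake patch outputs; VectorMerger only reads the .output attribute.
    patches = [
        SimpleNamespace(output=["feature_a"]),
        SimpleNamespace(output=["feature_b", "feature_c"]),
    ]
    merger = VectorMerger()
    # window and layer_config are part of the new signature but unused here.
    merged = merger.merge(window=None, outputs=patches, layer_config=None)
    assert merged == ["feature_a", "feature_b", "feature_c"]
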
@@ -148,6 +165,7 @@ class RslearnWriter(BasePredictionWriter):
         merger: PatchPredictionMerger | None = None,
         output_path: str | Path | None = None,
         layer_config: LayerConfig | None = None,
+        storage_config: StorageConfig | None = None,
     ):
         """Create a new RslearnWriter.
 
@@ -163,28 +181,24 @@ class RslearnWriter(BasePredictionWriter):
             layer_config: optional layer configuration. If provided, this config will be
                 used instead of reading from the dataset config, allowing usage without
                 requiring dataset config at the output path.
+            storage_config: optional storage configuration, needed (like layer_config)
+                if there is no dataset config.
         """
         super().__init__(write_interval="batch")
         self.output_layer = output_layer
         self.selector = selector or []
-        self.path = UPath(path, **path_options or {})
-        self.output_path = (
+        ds_upath = UPath(path, **path_options or {})
+        output_upath = (
             UPath(output_path, **path_options or {})
             if output_path is not None
-            else None
+            else ds_upath
         )
 
-        # Handle dataset and layer config
-        self.layer_config: LayerConfig
-        if layer_config:
-            self.layer_config = layer_config
-        else:
-            dataset = Dataset(self.path)
-            if self.output_layer not in dataset.layers:
-                raise KeyError(
-                    f"Output layer '{self.output_layer}' not found in dataset layers."
-                )
-            self.layer_config = dataset.layers[self.output_layer]
+        self.layer_config, self.dataset_storage = (
+            self._get_layer_config_and_dataset_storage(
+                ds_upath, output_upath, layer_config, storage_config
+            )
+        )
 
         self.format: RasterFormat | VectorFormat
         if self.layer_config.type == LayerType.RASTER:
@@ -207,11 +221,73 @@ class RslearnWriter(BasePredictionWriter):
         # patches of each window need to be reconstituted.
         self.pending_outputs: dict[str, list[PendingPatchOutput]] = {}
 
+    def _get_layer_config_and_dataset_storage(
+        self,
+        ds_upath: UPath,
+        output_upath: UPath,
+        layer_config: LayerConfig | None,
+        storage_config: StorageConfig | None,
+    ) -> tuple[LayerConfig, WindowStorage]:
+        """Get the layer config and dataset storage to use.
+
+        This is a helper function for the init method.
+
+        If layer_config is set, we use that. If storage_config is set, we use it to
+        instantiate a WindowStorage using the output_upath.
+
+        If one of them is not set, we load the config from the ds_upath. Otherwise, we
+        avoid reading the dataset config; this way, RslearnWriter can be used with
+        output directories that do not contain the dataset config, as long as
+        layer_config and storage_config are both provided.
+
+        Args:
+            ds_upath: the dataset path, where a dataset config can be loaded from if
+                layer_config or storage_config is not provided.
+            output_upath: the output directory, which could be different from the
+                dataset path.
+            layer_config: optional LayerConfig to provide.
+            storage_config: optional StorageConfig to provide.
+
+        Returns:
+            a tuple (layer_config, dataset_storage)
+        """
+        dataset_storage: WindowStorage | None = None
+
+        # Instantiate the WindowStorage from the storage_config if provided.
+        if storage_config:
+            dataset_storage = (
+                storage_config.instantiate_window_storage_factory().get_storage(
+                    output_upath
+                )
+            )
+
+        if not layer_config or not dataset_storage:
+            # Need to load dataset config since one of LayerConfig/StorageConfig is missing.
+            # We use DatasetConfig.model_validate instead of initializing the Dataset
+            # because we want to get a WindowStorage that has the dataset path set to
+            # output_upath instead of ds_upath.
+            with (ds_upath / "config.json").open() as f:
+                dataset_config = DatasetConfig.model_validate(json.load(f))
+
+            if not layer_config:
+                if self.output_layer not in dataset_config.layers:
+                    raise KeyError(
+                        f"Output layer '{self.output_layer}' not found in dataset layers."
+                    )
+                layer_config = dataset_config.layers[self.output_layer]
+
+            if not dataset_storage:
+                dataset_storage = dataset_config.storage.instantiate_window_storage_factory().get_storage(
+                    output_upath
+                )
+
+        return (layer_config, dataset_storage)
+
     def write_on_batch_end(
         self,
         trainer: Trainer,
         pl_module: LightningModule,
-        prediction: dict[str, Sequence],
+        prediction: ModelOutput,
         batch_indices: Sequence[int] | None,
         batch: tuple[list, list, list],
         batch_idx: int,
@@ -232,13 +308,13 @@
         assert isinstance(pl_module, RslearnLightningModule)
         task = pl_module.task
         _, _, metadatas = batch
-        self.process_output_batch(task, prediction["outputs"], metadatas)
+        self.process_output_batch(task, prediction.outputs, metadatas)
 
     def process_output_batch(
         self,
         task: Task,
-        prediction: Sequence,
-        metadatas: Sequence,
+        prediction: Iterable[Any],
+        metadatas: Iterable[SampleMetadata],
     ) -> None:
         """Write a prediction batch with simplified API.
 
@@ -263,25 +339,19 @@
             for k in self.selector:
                 output = output[k]
 
-            # Use custom output_path if provided, otherwise use dataset path
-            window_base_path = (
-                self.output_path if self.output_path is not None else self.path
-            )
             window = Window(
-                path=Window.get_window_root(
-                    window_base_path, metadata["group"], metadata["window_name"]
-                ),
-                group=metadata["group"],
-                name=metadata["window_name"],
-                projection=metadata["projection"],
-                bounds=metadata["window_bounds"],
-                time_range=metadata["time_range"],
+                storage=self.dataset_storage,
+                group=metadata.window_group,
+                name=metadata.window_name,
+                projection=metadata.projection,
+                bounds=metadata.window_bounds,
+                time_range=metadata.time_range,
             )
             self.process_output(
                 window,
-                metadata["patch_idx"],
-                metadata["num_patches"],
-                metadata["bounds"],
+                metadata.patch_idx,
+                metadata.num_patches_in_window,
+                metadata.patch_bounds,
                 output,
             )
@@ -320,7 +390,7 @@
 
         # Merge outputs from overlapped patches if merger is set.
         logger.debug(f"Merging and writing for window {window.name}")
-        merged_output = self.merger.merge(window, pending_output)
+        merged_output = self.merger.merge(window, pending_output, self.layer_config)
 
         if self.layer_config.type == LayerType.RASTER:
             raster_dir = window.get_raster_dir(
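
A hedged construction sketch: with both layer_config and storage_config supplied, the writer skips reading config.json from the output directory. Parameter names follow the init body above; the "output" layer name and paths are placeholders, and ds_config.storage is assumed to be the StorageConfig consumed by the helper method:

    import json

    from rslearn.config import DatasetConfig
    from rslearn.train.prediction_writer import RslearnWriter

    # Load the dataset config once, then hand its pieces to the writer so the
    # output directory itself does not need a config.json.
    with open("/data/rslearn_dataset/config.json") as f:
        ds_config = DatasetConfig.model_validate(json.load(f))

    writer = RslearnWriter(
        path="/data/rslearn_dataset",
        output_layer="output",
        output_path="/data/predictions",
        layer_config=ds_config.layers["output"],
        storage_config=ds_config.storage,
    )
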
rslearn/train/scheduler.py

@@ -8,6 +8,7 @@ from torch.optim.lr_scheduler import (
     CosineAnnealingLR,
     CosineAnnealingWarmRestarts,
     LRScheduler,
+    MultiStepLR,
     ReduceLROnPlateau,
 )
 
@@ -50,6 +51,20 @@ class PlateauScheduler(SchedulerFactory):
         return ReduceLROnPlateau(optimizer, **self.get_kwargs())
 
 
+@dataclass
+class MultiStepScheduler(SchedulerFactory):
+    """Multi-step learning rate scheduler."""
+
+    milestones: list[int]
+    gamma: float | None = None
+    last_epoch: int | None = None
+
+    def build(self, optimizer: Optimizer) -> LRScheduler:
+        """Build the MultiStepLR scheduler."""
+        super().build(optimizer)
+        return MultiStepLR(optimizer, **self.get_kwargs())
+
+
 @dataclass
 class CosineAnnealingScheduler(SchedulerFactory):
     """Cosine annealing learning rate scheduler."""
rslearn/train/tasks/classification.py

@@ -15,6 +15,8 @@ from torchmetrics.classification import (
     MulticlassRecall,
 )
 
+from rslearn.models.component import FeatureVector, Predictor
+from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
 from rslearn.utils import Feature, STGeometry
 
 from .task import BasicTask
@@ -98,7 +100,7 @@ class ClassificationTask(BasicTask):
     def process_inputs(
         self,
         raw_inputs: dict[str, torch.Tensor | list[Feature]],
-        metadata: dict[str, Any],
+        metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Processes the data into targets.
@@ -154,17 +156,25 @@ class ClassificationTask(BasicTask):
         }
 
     def process_output(
-        self, raw_output: Any, metadata: dict[str, Any]
-    ) -> npt.NDArray[Any] | list[Feature]:
+        self, raw_output: Any, metadata: SampleMetadata
+    ) -> list[Feature]:
         """Processes an output into raster or vector data.
 
         Args:
-            raw_output: the output from prediction head.
+            raw_output: the output from the prediction head, which must be a tensor
+                containing output probabilities (one dimension).
             metadata: metadata about the patch being read
 
         Returns:
-            either raster or vector data.
+            a list with one Feature corresponding to the input patch extent, with a
+            property containing the predicted class. It will have another
+            property containing the probabilities if prob_property was set.
         """
+        if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 1:
+            raise ValueError(
+                "expected output for ClassificationTask to be a Tensor with one dimension"
+            )
+
         probs = raw_output.cpu().numpy()
         if len(self.classes) == 2 and self.positive_class_threshold != 0.5:
             positive_class_prob = probs[self.positive_class_id]
@@ -184,8 +194,8 @@
 
         feature = Feature(
             STGeometry(
-                metadata["projection"],
-                shapely.Point(metadata["bounds"][0], metadata["bounds"][1]),
+                metadata.projection,
+                shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1]),
                 None,
             ),
             {
@@ -265,25 +275,31 @@ class ClassificationTask(BasicTask):
         return MetricCollection(metrics)
 
 
-class ClassificationHead(torch.nn.Module):
+class ClassificationHead(Predictor):
     """Head for classification task."""
 
     def forward(
         self,
-        logits: torch.Tensor,
-        inputs: list[dict[str, Any]],
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+    ) -> ModelOutput:
         """Compute the classification outputs and loss from logits and targets.
 
         Args:
-            logits: tensor that is (BatchSize, NumClasses) in shape.
-            inputs: original inputs (ignored).
-            targets: should contain class key that stores the class label.
+            intermediates: output from the previous model component; it should be a
+                FeatureVector with a tensor that is (BatchSize, NumClasses) in shape.
+            context: the model context.
+            targets: must contain a "class" key that stores the class label, along with
+                a "valid" key indicating whether the label is valid for each example.
 
         Returns:
            a ModelOutput with the softmax outputs and the loss dict
         """
+        if not isinstance(intermediates, FeatureVector):
+            raise ValueError("the input to ClassificationHead must be a FeatureVector")
+
+        logits = intermediates.feature_vector
         outputs = torch.nn.functional.softmax(logits, dim=1)
 
         losses = {}
@@ -298,7 +314,10 @@ class ClassificationHead(torch.nn.Module):
             )
         losses["cls"] = torch.mean(loss)
 
-        return outputs, losses
+        return ModelOutput(
+            outputs=outputs,
+            loss_dict=losses,
+        )
 
 
 class ClassificationMetric(Metric):
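
A hedged sketch of the new component contract. The FeatureVector and ModelOutput field names are taken from the hunks above, but their constructors and the unseen loss code are assumptions; ClassificationHead does not read the context argument, so None is passed purely for illustration:

    import torch

    from rslearn.models.component import FeatureVector
    from rslearn.train.tasks.classification import ClassificationHead

    head = ClassificationHead()
    logits = torch.randn(4, 3)  # (BatchSize, NumClasses)
    targets = [
        {"class": torch.tensor(i % 3), "valid": torch.tensor(1.0)} for i in range(4)
    ]

    result = head.forward(
        FeatureVector(feature_vector=logits), context=None, targets=targets
    )
    print(result.outputs.shape)     # softmax probabilities, (4, 3)
    print(result.loss_dict["cls"])  # classification loss term
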
rslearn/train/tasks/detection.py

@@ -12,6 +12,7 @@ import torchmetrics.classification
 import torchvision
 from torchmetrics import Metric, MetricCollection
 
+from rslearn.train.model_context import SampleMetadata
 from rslearn.utils import Feature, STGeometry
 
 from .task import BasicTask
@@ -127,7 +128,7 @@ class DetectionTask(BasicTask):
     def process_inputs(
         self,
         raw_inputs: dict[str, torch.Tensor | list[Feature]],
-        metadata: dict[str, Any],
+        metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Processes the data into targets.
@@ -144,6 +145,8 @@
         if not load_targets:
             return {}, {}
 
+        bounds = metadata.patch_bounds
+
         boxes = []
         class_labels = []
         valid = 1
@@ -186,39 +189,33 @@
             else:
                 box = [int(val) for val in shp.bounds]
 
-            if box[0] >= metadata["bounds"][2] or box[2] <= metadata["bounds"][0]:
+            if box[0] >= bounds[2] or box[2] <= bounds[0]:
                 continue
-            if box[1] >= metadata["bounds"][3] or box[3] <= metadata["bounds"][1]:
+            if box[1] >= bounds[3] or box[3] <= bounds[1]:
                 continue
 
             if self.exclude_by_center:
                 center_col = (box[0] + box[2]) // 2
                 center_row = (box[1] + box[3]) // 2
-                if (
-                    center_col <= metadata["bounds"][0]
-                    or center_col >= metadata["bounds"][2]
-                ):
+                if center_col <= bounds[0] or center_col >= bounds[2]:
                     continue
-                if (
-                    center_row <= metadata["bounds"][1]
-                    or center_row >= metadata["bounds"][3]
-                ):
+                if center_row <= bounds[1] or center_row >= bounds[3]:
                     continue
 
             if self.clip_boxes:
                 box = [
-                    np.clip(box[0], metadata["bounds"][0], metadata["bounds"][2]),
-                    np.clip(box[1], metadata["bounds"][1], metadata["bounds"][3]),
-                    np.clip(box[2], metadata["bounds"][0], metadata["bounds"][2]),
-                    np.clip(box[3], metadata["bounds"][1], metadata["bounds"][3]),
+                    np.clip(box[0], bounds[0], bounds[2]),
+                    np.clip(box[1], bounds[1], bounds[3]),
+                    np.clip(box[2], bounds[0], bounds[2]),
+                    np.clip(box[3], bounds[1], bounds[3]),
                 ]
 
             # Convert to relative coordinates.
             box = [
-                box[0] - metadata["bounds"][0],
-                box[1] - metadata["bounds"][1],
-                box[2] - metadata["bounds"][0],
-                box[3] - metadata["bounds"][1],
+                box[0] - bounds[0],
+                box[1] - bounds[1],
+                box[2] - bounds[0],
+                box[3] - bounds[1],
             ]
 
             boxes.append(box)
@@ -238,16 +235,12 @@
             "valid": torch.tensor(valid, dtype=torch.int32),
             "boxes": boxes,
             "labels": class_labels,
-            "width": torch.tensor(
-                metadata["bounds"][2] - metadata["bounds"][0], dtype=torch.float32
-            ),
-            "height": torch.tensor(
-                metadata["bounds"][3] - metadata["bounds"][1], dtype=torch.float32
-            ),
+            "width": torch.tensor(bounds[2] - bounds[0], dtype=torch.float32),
+            "height": torch.tensor(bounds[3] - bounds[1], dtype=torch.float32),
         }
 
     def process_output(
-        self, raw_output: Any, metadata: dict[str, Any]
+        self, raw_output: Any, metadata: SampleMetadata
     ) -> npt.NDArray[Any] | list[Feature]:
         """Processes an output into raster or vector data.
 
@@ -267,12 +260,12 @@
         features = []
         for box, class_id, score in zip(boxes, class_ids, scores):
             shp = shapely.box(
-                metadata["bounds"][0] + float(box[0]),
-                metadata["bounds"][1] + float(box[1]),
-                metadata["bounds"][0] + float(box[2]),
-                metadata["bounds"][1] + float(box[3]),
+                metadata.patch_bounds[0] + float(box[0]),
+                metadata.patch_bounds[1] + float(box[1]),
+                metadata.patch_bounds[0] + float(box[2]),
+                metadata.patch_bounds[1] + float(box[3]),
             )
-            geom = STGeometry(metadata["projection"], shp, None)
+            geom = STGeometry(metadata.projection, shp, None)
             properties: dict[str, Any] = {
                 "score": float(score),
             }
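
Across these task files, the metadata dict gives way to the SampleMetadata dataclass from rslearn/train/model_context.py. A small hedged sketch of the attribute-based access pattern (the field names are the ones referenced in the hunks; the dataclass definition itself is not shown in this diff):

    from rslearn.train.model_context import SampleMetadata

    def patch_size(metadata: SampleMetadata) -> tuple[int, int]:
        # patch_bounds is a pixel-coordinate (min_col, min_row, max_col, max_row)
        # tuple, matching how the tasks above compute widths and heights.
        bounds = metadata.patch_bounds
        return (bounds[2] - bounds[0], bounds[3] - bounds[1])
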
rslearn/train/tasks/embedding.py

@@ -6,6 +6,8 @@ import numpy.typing as npt
 import torch
 from torchmetrics import MetricCollection
 
+from rslearn.models.component import FeatureMaps
+from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
 from rslearn.utils import Feature
 
 from .task import Task
@@ -21,7 +23,7 @@ class EmbeddingTask(Task):
     def process_inputs(
         self,
         raw_inputs: dict[str, torch.Tensor],
-        metadata: dict[str, Any],
+        metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Processes the data into targets.
@@ -38,17 +40,22 @@
         return {}, {}
 
     def process_output(
-        self, raw_output: Any, metadata: dict[str, Any]
+        self, raw_output: Any, metadata: SampleMetadata
     ) -> npt.NDArray[Any] | list[Feature]:
         """Processes an output into raster or vector data.
 
         Args:
-            raw_output: the output from prediction head.
+            raw_output: the output from prediction head, which must be a CxHxW tensor.
             metadata: metadata about the patch being read
 
         Returns:
             either raster or vector data.
         """
+        if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 3:
+            raise ValueError(
+                "output for EmbeddingTask must be a tensor with three dimensions"
+            )
+
         # Just convert the raw output to numpy array that can be saved to GeoTIFF.
         return raw_output.cpu().numpy()
 
@@ -76,41 +83,38 @@ class EmbeddingTask(Task):
         return MetricCollection({})
 
 
-class EmbeddingHead(torch.nn.Module):
+class EmbeddingHead:
     """Head for embedding task.
 
-    This picks one feature map from the input list of feature maps to output. It also
-    returns a dummy loss.
+    It just adds a dummy loss to act as a Predictor.
     """
 
-    def __init__(self, feature_map_index: int | None = 0):
-        """Create a new EmbeddingHead.
-
-        Args:
-            feature_map_index: the index of the feature map to choose from the input
-                list of multi-scale feature maps (default 0). If the input is already
-                a single feature map, then set to None.
-        """
-        super().__init__()
-        self.feature_map_index = feature_map_index
-
     def forward(
         self,
-        features: torch.Tensor,
-        inputs: list[dict[str, Any]],
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) -> tuple[torch.Tensor, dict[str, Any]]:
-        """Select the desired feature map and return it along with a dummy loss.
+    ) -> ModelOutput:
+        """Return the feature map along with a dummy loss.
 
         Args:
-            features: list of BCHW feature maps (or one feature map, if feature_map_index is None).
-            inputs: original inputs (ignored).
-            targets: should contain classes key that stores the per-pixel class labels.
+            intermediates: output from the previous model component, which must be a
+                FeatureMaps consisting of a single feature map.
+            context: the model context.
+            targets: the targets (ignored).
 
         Returns:
-            tuple of outputs and loss dict
+            model output with the feature map that was input to this component along
+            with a dummy loss.
         """
-        if self.feature_map_index is not None:
-            features = features[self.feature_map_index]
-
-        return features, {"loss": 0}
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to EmbeddingHead must be a FeatureMaps")
+        if len(intermediates.feature_maps) != 1:
+            raise ValueError(
+                f"input to EmbeddingHead must have one feature map, but got {len(intermediates.feature_maps)}"
+            )
+
+        return ModelOutput(
+            outputs=intermediates.feature_maps[0],
+            loss_dict={"loss": 0},
+        )
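
A hedged sketch of the simplified head. Only the .feature_maps attribute is confirmed by the hunk; the FeatureMaps constructor is assumed to accept the list of BCHW maps, and context/targets are unused by this head, so None suffices for illustration:

    import torch

    from rslearn.models.component import FeatureMaps
    from rslearn.train.tasks.embedding import EmbeddingHead

    head = EmbeddingHead()
    feature_map = torch.randn(2, 128, 16, 16)  # one BCHW feature map

    result = head.forward(FeatureMaps(feature_maps=[feature_map]), context=None)
    assert result.outputs.shape == (2, 128, 16, 16)
    assert result.loss_dict == {"loss": 0}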