rslearn 0.0.16__py3-none-any.whl → 0.0.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. rslearn/config/__init__.py +2 -0
  2. rslearn/config/dataset.py +55 -4
  3. rslearn/dataset/add_windows.py +1 -1
  4. rslearn/dataset/dataset.py +9 -65
  5. rslearn/dataset/materialize.py +5 -5
  6. rslearn/dataset/storage/__init__.py +1 -0
  7. rslearn/dataset/storage/file.py +202 -0
  8. rslearn/dataset/storage/storage.py +140 -0
  9. rslearn/dataset/window.py +26 -80
  10. rslearn/lightning_cli.py +10 -3
  11. rslearn/main.py +11 -36
  12. rslearn/models/anysat.py +11 -9
  13. rslearn/models/clay/clay.py +8 -9
  14. rslearn/models/clip.py +18 -15
  15. rslearn/models/component.py +99 -0
  16. rslearn/models/concatenate_features.py +21 -11
  17. rslearn/models/conv.py +15 -8
  18. rslearn/models/croma.py +13 -8
  19. rslearn/models/detr/detr.py +25 -14
  20. rslearn/models/dinov3.py +11 -6
  21. rslearn/models/faster_rcnn.py +19 -9
  22. rslearn/models/feature_center_crop.py +12 -9
  23. rslearn/models/fpn.py +19 -8
  24. rslearn/models/galileo/galileo.py +23 -18
  25. rslearn/models/module_wrapper.py +26 -57
  26. rslearn/models/molmo.py +16 -14
  27. rslearn/models/multitask.py +102 -73
  28. rslearn/models/olmoearth_pretrain/model.py +20 -17
  29. rslearn/models/panopticon.py +8 -7
  30. rslearn/models/pick_features.py +18 -24
  31. rslearn/models/pooling_decoder.py +22 -14
  32. rslearn/models/presto/presto.py +16 -10
  33. rslearn/models/presto/single_file_presto.py +4 -10
  34. rslearn/models/prithvi.py +12 -8
  35. rslearn/models/resize_features.py +21 -7
  36. rslearn/models/sam2_enc.py +11 -9
  37. rslearn/models/satlaspretrain.py +15 -9
  38. rslearn/models/simple_time_series.py +31 -17
  39. rslearn/models/singletask.py +24 -17
  40. rslearn/models/ssl4eo_s12.py +15 -10
  41. rslearn/models/swin.py +22 -13
  42. rslearn/models/terramind.py +24 -7
  43. rslearn/models/trunk.py +6 -3
  44. rslearn/models/unet.py +18 -9
  45. rslearn/models/upsample.py +22 -9
  46. rslearn/train/all_patches_dataset.py +22 -18
  47. rslearn/train/dataset.py +69 -54
  48. rslearn/train/lightning_module.py +51 -32
  49. rslearn/train/model_context.py +54 -0
  50. rslearn/train/prediction_writer.py +111 -41
  51. rslearn/train/tasks/classification.py +34 -15
  52. rslearn/train/tasks/detection.py +24 -31
  53. rslearn/train/tasks/embedding.py +33 -29
  54. rslearn/train/tasks/multi_task.py +7 -7
  55. rslearn/train/tasks/per_pixel_regression.py +41 -19
  56. rslearn/train/tasks/regression.py +38 -21
  57. rslearn/train/tasks/segmentation.py +33 -15
  58. rslearn/train/tasks/task.py +3 -2
  59. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/METADATA +58 -25
  60. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/RECORD +65 -62
  61. rslearn/dataset/index.py +0 -173
  62. rslearn/models/registry.py +0 -22
  63. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/WHEEL +0 -0
  64. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/entry_points.txt +0 -0
  65. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/licenses/LICENSE +0 -0
  66. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/licenses/NOTICE +0 -0
  67. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/top_level.txt +0 -0
rslearn/train/lightning_module.py

@@ -12,6 +12,7 @@ from upath import UPath
 
 from rslearn.log_utils import get_logger
 
+from .model_context import ModelContext, ModelOutput
 from .optimizer import AdamW, OptimizerFactory
 from .scheduler import PlateauScheduler, SchedulerFactory
 from .tasks import Task
@@ -231,12 +232,16 @@ class RslearnLightningModule(L.LightningModule):
         Returns:
             The loss tensor.
         """
-        inputs, targets, _ = batch
+        inputs, targets, metadatas = batch
+        context = ModelContext(
+            inputs=inputs,
+            metadatas=metadatas,
+        )
         batch_size = len(inputs)
-        model_outputs = self(inputs, targets)
-        self.on_train_forward(inputs, targets, model_outputs)
+        model_outputs = self(context, targets)
+        self.on_train_forward(context, targets, model_outputs)
 
-        loss_dict = model_outputs["loss_dict"]
+        loss_dict = model_outputs.loss_dict
         train_loss = sum(loss_dict.values())
         self.log_dict(
             {"train_" + k: v for k, v in loss_dict.items()},
@@ -266,13 +271,17 @@ class RslearnLightningModule(L.LightningModule):
             batch_idx: Integer displaying index of this batch.
             dataloader_idx: Index of the current dataloader.
         """
-        inputs, targets, _ = batch
+        inputs, targets, metadatas = batch
+        context = ModelContext(
+            inputs=inputs,
+            metadatas=metadatas,
+        )
         batch_size = len(inputs)
-        model_outputs = self(inputs, targets)
-        self.on_val_forward(inputs, targets, model_outputs)
+        model_outputs = self(context, targets)
+        self.on_val_forward(context, targets, model_outputs)
 
-        loss_dict = model_outputs["loss_dict"]
-        outputs = model_outputs["outputs"]
+        loss_dict = model_outputs.loss_dict
+        outputs = model_outputs.outputs
         val_loss = sum(loss_dict.values())
         self.log_dict(
             {"val_" + k: v for k, v in loss_dict.items()},
@@ -304,12 +313,16 @@ class RslearnLightningModule(L.LightningModule):
             dataloader_idx: Index of the current dataloader.
         """
         inputs, targets, metadatas = batch
+        context = ModelContext(
+            inputs=inputs,
+            metadatas=metadatas,
+        )
         batch_size = len(inputs)
-        model_outputs = self(inputs, targets)
-        self.on_test_forward(inputs, targets, model_outputs)
+        model_outputs = self(context, targets)
+        self.on_test_forward(context, targets, model_outputs)
 
-        loss_dict = model_outputs["loss_dict"]
-        outputs = model_outputs["outputs"]
+        loss_dict = model_outputs.loss_dict
+        outputs = model_outputs.outputs
         test_loss = sum(loss_dict.values())
         self.log_dict(
             {"test_" + k: v for k, v in loss_dict.items()},
@@ -345,7 +358,7 @@ class RslearnLightningModule(L.LightningModule):
 
     def predict_step(
         self, batch: Any, batch_idx: int, dataloader_idx: int = 0
-    ) -> torch.Tensor:
+    ) -> ModelOutput:
         """Compute the predicted class probabilities.
 
         Args:
@@ -356,63 +369,69 @@ class RslearnLightningModule(L.LightningModule):
         Returns:
             Output predicted probabilities.
         """
-        inputs, _, _ = batch
-        model_outputs = self(inputs)
+        inputs, _, metadatas = batch
+        context = ModelContext(
+            inputs=inputs,
+            metadatas=metadatas,
+        )
+        model_outputs = self(context)
         return model_outputs
 
-    def forward(self, *args: Any, **kwargs: Any) -> Any:
+    def forward(
+        self, context: ModelContext, targets: list[dict[str, Any]] | None = None
+    ) -> ModelOutput:
         """Forward pass of the model.
 
         Args:
-            args: Arguments to pass to model.
-            kwargs: Keyword arguments to pass to model.
+            context: the model context.
+            targets: the target dicts.
 
         Returns:
            Output of the model.
         """
-        return self.model(*args, **kwargs)
+        return self.model(context, targets)
 
     def on_train_forward(
         self,
-        inputs: list[dict[str, Any]],
+        context: ModelContext,
         targets: list[dict[str, Any]],
-        model_outputs: dict[str, Any],
+        model_outputs: ModelOutput,
     ) -> None:
         """Hook to run after the forward pass of the model during training.
 
         Args:
-            inputs: The input batch.
+            context: The model context.
             targets: The target batch.
-            model_outputs: The output of the model, with keys "outputs" and "loss_dict", and possibly other keys.
+            model_outputs: The output of the model.
         """
         pass
 
     def on_val_forward(
         self,
-        inputs: list[dict[str, Any]],
+        context: ModelContext,
         targets: list[dict[str, Any]],
-        model_outputs: dict[str, Any],
+        model_outputs: ModelOutput,
    ) -> None:
         """Hook to run after the forward pass of the model during validation.
 
         Args:
-            inputs: The input batch.
+            context: The model context.
             targets: The target batch.
-            model_outputs: The output of the model, with keys "outputs" and "loss_dict", and possibly other keys.
+            model_outputs: The output of the model.
         """
         pass
 
     def on_test_forward(
         self,
-        inputs: list[dict[str, Any]],
+        context: ModelContext,
         targets: list[dict[str, Any]],
-        model_outputs: dict[str, Any],
+        model_outputs: ModelOutput,
     ) -> None:
         """Hook to run after the forward pass of the model during testing.
 
         Args:
-            inputs: The input batch.
+            context: The model context.
             targets: The target batch.
-            model_outputs: The output of the model, with keys "outputs" and "loss_dict", and possibly other keys.
+            model_outputs: The output of the model.
         """
         pass
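
Example (illustrative): a minimal sketch of the new calling convention from the perspective of a subclass. The hooks now receive a ModelContext plus a structured ModelOutput instead of the raw inputs list and a {"outputs", "loss_dict"} dict. The MyModule name is hypothetical and not part of the package.

from typing import Any

from rslearn.train.lightning_module import RslearnLightningModule
from rslearn.train.model_context import ModelContext, ModelOutput


class MyModule(RslearnLightningModule):
    """Hypothetical subclass that inspects losses after each training step."""

    def on_train_forward(
        self,
        context: ModelContext,
        targets: list[dict[str, Any]],
        model_outputs: ModelOutput,
    ) -> None:
        # model_outputs.loss_dict replaces model_outputs["loss_dict"].
        for name, value in model_outputs.loss_dict.items():
            print(f"{name}: {float(value)} ({len(context.metadatas)} examples)")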
rslearn/train/model_context.py (new file)

@@ -0,0 +1,54 @@
+"""Data classes to provide various context to models."""
+
+from collections.abc import Iterable
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Any
+
+import torch
+
+from rslearn.utils.geometry import PixelBounds, Projection
+
+
+@dataclass
+class SampleMetadata:
+    """Metadata pertaining to an example."""
+
+    window_group: str
+    window_name: str
+    window_bounds: PixelBounds
+    patch_bounds: PixelBounds
+    patch_idx: int
+    num_patches_in_window: int
+    time_range: tuple[datetime, datetime] | None
+    projection: Projection
+
+    # Task name to differentiate different tasks.
+    dataset_source: str | None
+
+
+@dataclass
+class ModelContext:
+    """Context to pass to all model components."""
+
+    # One input dict per example in the batch.
+    inputs: list[dict[str, torch.Tensor]]
+    # One SampleMetadata per example in the batch.
+    metadatas: list[SampleMetadata]
+    # Arbitrary dict that components can add to.
+    context_dict: dict[str, Any] = field(default_factory=lambda: {})
+
+
+@dataclass
+class ModelOutput:
+    """The output from the Predictor.
+
+    Args:
+        outputs: output compatible with the configured Task.
+        loss_dict: map from loss names to scalar tensors.
+        metadata: arbitrary dict that can be used to store other outputs.
+    """
+
+    outputs: Iterable[Any]
+    loss_dict: dict[str, torch.Tensor]
+    metadata: dict[str, Any] = field(default_factory=lambda: {})
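
Example (illustrative): a minimal sketch of how a component might consume and extend these dataclasses. The function names and the "scene_embedding" key are made up for illustration.

from typing import Any

import torch

from rslearn.train.model_context import ModelContext, ModelOutput


def describe_batch(context: ModelContext) -> None:
    # Each example pairs an input dict with its SampleMetadata.
    for inputs, metadata in zip(context.inputs, context.metadatas):
        print(metadata.window_name, metadata.patch_idx, list(inputs.keys()))
    # context_dict is a scratch space that components can add to.
    context.context_dict["scene_embedding"] = torch.zeros(16)


def wrap_predictions(
    outputs: list[Any], losses: dict[str, torch.Tensor]
) -> ModelOutput:
    # ModelOutput replaces the old {"outputs": ..., "loss_dict": ...} dict.
    return ModelOutput(outputs=outputs, loss_dict=losses)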
rslearn/train/prediction_writer.py

@@ -1,6 +1,7 @@
 """rslearn PredictionWriter implementation."""
 
-from collections.abc import Sequence
+import json
+from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any
@@ -12,11 +13,15 @@ from lightning.pytorch.callbacks import BasePredictionWriter
 from upath import UPath
 
 from rslearn.config import (
+    DatasetConfig,
     LayerConfig,
     LayerType,
+    StorageConfig,
 )
-from rslearn.dataset import Dataset, Window
+from rslearn.dataset import Window
+from rslearn.dataset.storage.storage import WindowStorage
 from rslearn.log_utils import get_logger
+from rslearn.train.model_context import SampleMetadata
 from rslearn.utils.array import copy_spatial_array
 from rslearn.utils.feature import Feature
 from rslearn.utils.geometry import PixelBounds
@@ -27,6 +32,7 @@ from rslearn.utils.raster_format import (
 from rslearn.utils.vector_format import VectorFormat
 
 from .lightning_module import RslearnLightningModule
+from .model_context import ModelOutput
 from .tasks.task import Task
 
 logger = get_logger(__name__)
@@ -43,12 +49,18 @@ class PendingPatchOutput:
 class PatchPredictionMerger:
     """Base class for merging predictions from multiple patches."""
 
-    def merge(self, window: Window, outputs: Sequence[PendingPatchOutput]) -> Any:
+    def merge(
+        self,
+        window: Window,
+        outputs: Sequence[PendingPatchOutput],
+        layer_config: LayerConfig,
+    ) -> Any:
         """Merge the outputs.
 
         Args:
            window: the window we are merging the outputs for.
            outputs: the outputs to process.
+           layer_config: the output layer configuration.
 
         Returns:
            the merged outputs.
@@ -60,7 +72,10 @@ class VectorMerger(PatchPredictionMerger):
     """Merger for vector data that simply concatenates the features."""
 
     def merge(
-        self, window: Window, outputs: Sequence[PendingPatchOutput]
+        self,
+        window: Window,
+        outputs: Sequence[PendingPatchOutput],
+        layer_config: LayerConfig,
     ) -> list[Feature]:
         """Concatenate the vector features."""
         return [feat for output in outputs for feat in output.output]
@@ -83,18 +98,20 @@ class RasterMerger(PatchPredictionMerger):
         self.downsample_factor = downsample_factor
 
     def merge(
-        self, window: Window, outputs: Sequence[PendingPatchOutput]
+        self,
+        window: Window,
+        outputs: Sequence[PendingPatchOutput],
+        layer_config: LayerConfig,
     ) -> npt.NDArray:
         """Merge the raster outputs."""
         num_channels = outputs[0].output.shape[0]
-        dtype = outputs[0].output.dtype
         merged_image = np.zeros(
            (
                num_channels,
                (window.bounds[3] - window.bounds[1]) // self.downsample_factor,
                (window.bounds[2] - window.bounds[0]) // self.downsample_factor,
            ),
-            dtype=dtype,
+            dtype=layer_config.band_sets[0].dtype.get_numpy_dtype(),
         )
 
         # Ensure the outputs are sorted by height then width.
@@ -148,6 +165,7 @@ class RslearnWriter(BasePredictionWriter):
         merger: PatchPredictionMerger | None = None,
         output_path: str | Path | None = None,
         layer_config: LayerConfig | None = None,
+        storage_config: StorageConfig | None = None,
     ):
         """Create a new RslearnWriter.
 
@@ -163,28 +181,24 @@ class RslearnWriter(BasePredictionWriter):
            layer_config: optional layer configuration. If provided, this config will be
                used instead of reading from the dataset config, allowing usage without
                requiring dataset config at the output path.
+           storage_config: optional storage configuration, needed similar to layer_config
+               if there is no dataset config.
         """
         super().__init__(write_interval="batch")
         self.output_layer = output_layer
         self.selector = selector or []
-        self.path = UPath(path, **path_options or {})
-        self.output_path = (
+        ds_upath = UPath(path, **path_options or {})
+        output_upath = (
            UPath(output_path, **path_options or {})
            if output_path is not None
-            else None
+            else ds_upath
        )
 
-        # Handle dataset and layer config
-        self.layer_config: LayerConfig
-        if layer_config:
-            self.layer_config = layer_config
-        else:
-            dataset = Dataset(self.path)
-            if self.output_layer not in dataset.layers:
-                raise KeyError(
-                    f"Output layer '{self.output_layer}' not found in dataset layers."
-                )
-            self.layer_config = dataset.layers[self.output_layer]
+        self.layer_config, self.dataset_storage = (
+            self._get_layer_config_and_dataset_storage(
+                ds_upath, output_upath, layer_config, storage_config
+            )
+        )
 
         self.format: RasterFormat | VectorFormat
         if self.layer_config.type == LayerType.RASTER:
@@ -207,11 +221,73 @@ class RslearnWriter(BasePredictionWriter):
         # patches of each window need to be reconstituted.
         self.pending_outputs: dict[str, list[PendingPatchOutput]] = {}
 
+    def _get_layer_config_and_dataset_storage(
+        self,
+        ds_upath: UPath,
+        output_upath: UPath,
+        layer_config: LayerConfig | None,
+        storage_config: StorageConfig | None,
+    ) -> tuple[LayerConfig, WindowStorage]:
+        """Get the layer config and dataset storage to use.
+
+        This is a helper function for the init method.
+
+        If layer_config is set, we use that. If storage_config is set, we use it to
+        instantiate a WindowStorage using the output_upath.
+
+        If one of them is not set, we load the config from the ds_upath. Otherwise, we
+        avoid reading the dataset config; this way, RslearnWriter can be used with
+        output directories that do not contain the dataset config, as long as
+        layer_config and storage_config are both provided.
+
+        Args:
+            ds_upath: the dataset path, where a dataset config can be loaded from if
+                layer_config or storage_config is not provided.
+            output_upath: the output directory, which could be different from the
+                dataset path.
+            layer_config: optional LayerConfig to provide.
+            storage_config: optional StorageConfig to provide.
+
+        Returns:
+            a tuple (layer_config, dataset_storage)
+        """
+        dataset_storage: WindowStorage | None = None
+
+        # Instantiate the WindowStorage from the storage_config if provided.
+        if storage_config:
+            dataset_storage = (
+                storage_config.instantiate_window_storage_factory().get_storage(
+                    output_upath
+                )
+            )
+
+        if not layer_config or not dataset_storage:
+            # Need to load dataset config since one of LayerConfig/StorageConfig is missing.
+            # We use DatasetConfig.model_validate instead of initializing the Dataset
+            # because we want to get a WindowStorage that has the dataset path set to
+            # output_upath instead of ds_upath.
+            with (ds_upath / "config.json").open() as f:
+                dataset_config = DatasetConfig.model_validate(json.load(f))
+
+            if not layer_config:
+                if self.output_layer not in dataset_config.layers:
+                    raise KeyError(
+                        f"Output layer '{self.output_layer}' not found in dataset layers."
+                    )
+                layer_config = dataset_config.layers[self.output_layer]
+
+            if not dataset_storage:
+                dataset_storage = dataset_config.storage.instantiate_window_storage_factory().get_storage(
+                    output_upath
+                )
+
+        return (layer_config, dataset_storage)
+
     def write_on_batch_end(
         self,
         trainer: Trainer,
         pl_module: LightningModule,
-        prediction: dict[str, Sequence],
+        prediction: ModelOutput,
         batch_indices: Sequence[int] | None,
         batch: tuple[list, list, list],
         batch_idx: int,
@@ -232,13 +308,13 @@ class RslearnWriter(BasePredictionWriter):
         assert isinstance(pl_module, RslearnLightningModule)
         task = pl_module.task
         _, _, metadatas = batch
-        self.process_output_batch(task, prediction["outputs"], metadatas)
+        self.process_output_batch(task, prediction.outputs, metadatas)
 
     def process_output_batch(
         self,
         task: Task,
-        prediction: Sequence,
-        metadatas: Sequence,
+        prediction: Iterable[Any],
+        metadatas: Iterable[SampleMetadata],
     ) -> None:
         """Write a prediction batch with simplified API.
 
@@ -263,25 +339,19 @@ class RslearnWriter(BasePredictionWriter):
            for k in self.selector:
                output = output[k]
 
-            # Use custom output_path if provided, otherwise use dataset path
-            window_base_path = (
-                self.output_path if self.output_path is not None else self.path
-            )
            window = Window(
-                path=Window.get_window_root(
-                    window_base_path, metadata["group"], metadata["window_name"]
-                ),
-                group=metadata["group"],
-                name=metadata["window_name"],
-                projection=metadata["projection"],
-                bounds=metadata["window_bounds"],
-                time_range=metadata["time_range"],
+                storage=self.dataset_storage,
+                group=metadata.window_group,
+                name=metadata.window_name,
+                projection=metadata.projection,
+                bounds=metadata.window_bounds,
+                time_range=metadata.time_range,
            )
            self.process_output(
                window,
-                metadata["patch_idx"],
-                metadata["num_patches"],
-                metadata["bounds"],
+                metadata.patch_idx,
+                metadata.num_patches_in_window,
+                metadata.patch_bounds,
                output,
            )
 
@@ -320,7 +390,7 @@ class RslearnWriter(BasePredictionWriter):
 
         # Merge outputs from overlapped patches if merger is set.
         logger.debug(f"Merging and writing for window {window.name}")
-        merged_output = self.merger.merge(window, pending_output)
+        merged_output = self.merger.merge(window, pending_output, self.layer_config)
 
         if self.layer_config.type == LayerType.RASTER:
            raster_dir = window.get_raster_dir(
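
Example (illustrative): a minimal sketch of constructing the writer when layer_config and storage_config are not passed, in which case it now reads config.json from the dataset path and writes windows under output_path. The paths and layer name are placeholders.

from rslearn.train.prediction_writer import RslearnWriter

# With neither layer_config nor storage_config given, the writer loads
# config.json from `path` and builds a WindowStorage rooted at output_path.
writer = RslearnWriter(
    path="/data/my_dataset",
    output_layer="output",
    output_path="/data/predictions",
)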
rslearn/train/tasks/classification.py

@@ -15,6 +15,8 @@ from torchmetrics.classification import (
     MulticlassRecall,
 )
 
+from rslearn.models.component import FeatureVector, Predictor
+from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
 from rslearn.utils import Feature, STGeometry
 
 from .task import BasicTask
@@ -98,7 +100,7 @@ class ClassificationTask(BasicTask):
     def process_inputs(
         self,
         raw_inputs: dict[str, torch.Tensor | list[Feature]],
-        metadata: dict[str, Any],
+        metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Processes the data into targets.
@@ -154,17 +156,25 @@ class ClassificationTask(BasicTask):
         }
 
     def process_output(
-        self, raw_output: Any, metadata: dict[str, Any]
-    ) -> npt.NDArray[Any] | list[Feature]:
+        self, raw_output: Any, metadata: SampleMetadata
+    ) -> list[Feature]:
         """Processes an output into raster or vector data.
 
         Args:
-            raw_output: the output from prediction head.
+            raw_output: the output from prediction head, which must be a tensor
+                containing output probabilities (one dimension).
            metadata: metadata about the patch being read
 
         Returns:
-            either raster or vector data.
+            a list with one Feature corresponding to the input patch extent with a
+                property name containing the predicted class. It will have another
+                property containing the probabilities if prob_property was set.
         """
+        if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 1:
+            raise ValueError(
+                "expected output for ClassificationTask to be a Tensor with one dimension"
+            )
+
         probs = raw_output.cpu().numpy()
         if len(self.classes) == 2 and self.positive_class_threshold != 0.5:
            positive_class_prob = probs[self.positive_class_id]
@@ -184,8 +194,8 @@ class ClassificationTask(BasicTask):
 
         feature = Feature(
            STGeometry(
-                metadata["projection"],
-                shapely.Point(metadata["bounds"][0], metadata["bounds"][1]),
+                metadata.projection,
+                shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1]),
                None,
            ),
            {
@@ -265,25 +275,31 @@ class ClassificationTask(BasicTask):
         return MetricCollection(metrics)
 
 
-class ClassificationHead(torch.nn.Module):
+class ClassificationHead(Predictor):
     """Head for classification task."""
 
     def forward(
         self,
-        logits: torch.Tensor,
-        inputs: list[dict[str, Any]],
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
+    ) -> ModelOutput:
         """Compute the classification outputs and loss from logits and targets.
 
         Args:
-            logits: tensor that is (BatchSize, NumClasses) in shape.
-            inputs: original inputs (ignored).
-            targets: should contain class key that stores the class label.
+            intermediates: output from the previous model component, it should be a
+                FeatureVector with a tensor that is (BatchSize, NumClasses) in shape.
+            context: the model context.
+            targets: must contain "class" key that stores the class label, along with
+                "valid" key indicating whether the label is valid for each example.
 
         Returns:
            tuple of outputs and loss dict
         """
+        if not isinstance(intermediates, FeatureVector):
+            raise ValueError("the input to ClassificationHead must be a FeatureVector")
+
+        logits = intermediates.feature_vector
         outputs = torch.nn.functional.softmax(logits, dim=1)
 
         losses = {}
@@ -298,7 +314,10 @@ class ClassificationHead(torch.nn.Module):
            )
            losses["cls"] = torch.mean(loss)
 
-        return outputs, losses
+        return ModelOutput(
+            outputs=outputs,
+            loss_dict=losses,
+        )
 
 
 class ClassificationMetric(Metric):
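
Example (illustrative): a minimal sketch of invoking the head under the new Predictor interface. It assumes ClassificationHead() and FeatureVector(feature_vector=...) take no other constructor arguments and that targets use integer class-label tensors; none of these details are shown in this diff.

import torch

from rslearn.models.component import FeatureVector
from rslearn.train.model_context import ModelContext
from rslearn.train.tasks.classification import ClassificationHead

head = ClassificationHead()
logits = torch.randn(2, 3)  # (BatchSize, NumClasses)
context = ModelContext(inputs=[{}, {}], metadatas=[])  # metadatas elided for brevity
targets = [
    {"class": torch.tensor(0), "valid": torch.tensor(1.0)},
    {"class": torch.tensor(2), "valid": torch.tensor(1.0)},
]
# The head now returns a ModelOutput rather than an (outputs, loss_dict) tuple.
model_output = head.forward(FeatureVector(feature_vector=logits), context, targets)
probs = model_output.outputs      # softmax probabilities, shape (2, 3)
losses = model_output.loss_dict   # e.g. {"cls": scalar loss tensor}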