rslearn 0.0.16__py3-none-any.whl → 0.0.18__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (67):
  1. rslearn/config/__init__.py +2 -0
  2. rslearn/config/dataset.py +55 -4
  3. rslearn/dataset/add_windows.py +1 -1
  4. rslearn/dataset/dataset.py +9 -65
  5. rslearn/dataset/materialize.py +5 -5
  6. rslearn/dataset/storage/__init__.py +1 -0
  7. rslearn/dataset/storage/file.py +202 -0
  8. rslearn/dataset/storage/storage.py +140 -0
  9. rslearn/dataset/window.py +26 -80
  10. rslearn/lightning_cli.py +10 -3
  11. rslearn/main.py +11 -36
  12. rslearn/models/anysat.py +11 -9
  13. rslearn/models/clay/clay.py +8 -9
  14. rslearn/models/clip.py +18 -15
  15. rslearn/models/component.py +99 -0
  16. rslearn/models/concatenate_features.py +21 -11
  17. rslearn/models/conv.py +15 -8
  18. rslearn/models/croma.py +13 -8
  19. rslearn/models/detr/detr.py +25 -14
  20. rslearn/models/dinov3.py +11 -6
  21. rslearn/models/faster_rcnn.py +19 -9
  22. rslearn/models/feature_center_crop.py +12 -9
  23. rslearn/models/fpn.py +19 -8
  24. rslearn/models/galileo/galileo.py +23 -18
  25. rslearn/models/module_wrapper.py +26 -57
  26. rslearn/models/molmo.py +16 -14
  27. rslearn/models/multitask.py +102 -73
  28. rslearn/models/olmoearth_pretrain/model.py +20 -17
  29. rslearn/models/panopticon.py +8 -7
  30. rslearn/models/pick_features.py +18 -24
  31. rslearn/models/pooling_decoder.py +22 -14
  32. rslearn/models/presto/presto.py +16 -10
  33. rslearn/models/presto/single_file_presto.py +4 -10
  34. rslearn/models/prithvi.py +12 -8
  35. rslearn/models/resize_features.py +21 -7
  36. rslearn/models/sam2_enc.py +11 -9
  37. rslearn/models/satlaspretrain.py +15 -9
  38. rslearn/models/simple_time_series.py +31 -17
  39. rslearn/models/singletask.py +24 -17
  40. rslearn/models/ssl4eo_s12.py +15 -10
  41. rslearn/models/swin.py +22 -13
  42. rslearn/models/terramind.py +24 -7
  43. rslearn/models/trunk.py +6 -3
  44. rslearn/models/unet.py +18 -9
  45. rslearn/models/upsample.py +22 -9
  46. rslearn/train/all_patches_dataset.py +22 -18
  47. rslearn/train/dataset.py +69 -54
  48. rslearn/train/lightning_module.py +51 -32
  49. rslearn/train/model_context.py +54 -0
  50. rslearn/train/prediction_writer.py +111 -41
  51. rslearn/train/tasks/classification.py +34 -15
  52. rslearn/train/tasks/detection.py +24 -31
  53. rslearn/train/tasks/embedding.py +33 -29
  54. rslearn/train/tasks/multi_task.py +7 -7
  55. rslearn/train/tasks/per_pixel_regression.py +41 -19
  56. rslearn/train/tasks/regression.py +38 -21
  57. rslearn/train/tasks/segmentation.py +33 -15
  58. rslearn/train/tasks/task.py +3 -2
  59. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/METADATA +58 -25
  60. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/RECORD +65 -62
  61. rslearn/dataset/index.py +0 -173
  62. rslearn/models/registry.py +0 -22
  63. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/WHEEL +0 -0
  64. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/entry_points.txt +0 -0
  65. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/licenses/LICENSE +0 -0
  66. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/licenses/NOTICE +0 -0
  67. {rslearn-0.0.16.dist-info → rslearn-0.0.18.dist-info}/top_level.txt +0 -0
@@ -12,6 +12,7 @@ import torchmetrics.classification
12
12
  import torchvision
13
13
  from torchmetrics import Metric, MetricCollection
14
14
 
15
+ from rslearn.train.model_context import SampleMetadata
15
16
  from rslearn.utils import Feature, STGeometry
16
17
 
17
18
  from .task import BasicTask
@@ -127,7 +128,7 @@ class DetectionTask(BasicTask):
127
128
  def process_inputs(
128
129
  self,
129
130
  raw_inputs: dict[str, torch.Tensor | list[Feature]],
130
- metadata: dict[str, Any],
131
+ metadata: SampleMetadata,
131
132
  load_targets: bool = True,
132
133
  ) -> tuple[dict[str, Any], dict[str, Any]]:
133
134
  """Processes the data into targets.
@@ -144,6 +145,8 @@ class DetectionTask(BasicTask):
144
145
  if not load_targets:
145
146
  return {}, {}
146
147
 
148
+ bounds = metadata.patch_bounds
149
+
147
150
  boxes = []
148
151
  class_labels = []
149
152
  valid = 1
@@ -186,39 +189,33 @@ class DetectionTask(BasicTask):
186
189
  else:
187
190
  box = [int(val) for val in shp.bounds]
188
191
 
189
- if box[0] >= metadata["bounds"][2] or box[2] <= metadata["bounds"][0]:
192
+ if box[0] >= bounds[2] or box[2] <= bounds[0]:
190
193
  continue
191
- if box[1] >= metadata["bounds"][3] or box[3] <= metadata["bounds"][1]:
194
+ if box[1] >= bounds[3] or box[3] <= bounds[1]:
192
195
  continue
193
196
 
194
197
  if self.exclude_by_center:
195
198
  center_col = (box[0] + box[2]) // 2
196
199
  center_row = (box[1] + box[3]) // 2
197
- if (
198
- center_col <= metadata["bounds"][0]
199
- or center_col >= metadata["bounds"][2]
200
- ):
200
+ if center_col <= bounds[0] or center_col >= bounds[2]:
201
201
  continue
202
- if (
203
- center_row <= metadata["bounds"][1]
204
- or center_row >= metadata["bounds"][3]
205
- ):
202
+ if center_row <= bounds[1] or center_row >= bounds[3]:
206
203
  continue
207
204
 
208
205
  if self.clip_boxes:
209
206
  box = [
210
- np.clip(box[0], metadata["bounds"][0], metadata["bounds"][2]),
211
- np.clip(box[1], metadata["bounds"][1], metadata["bounds"][3]),
212
- np.clip(box[2], metadata["bounds"][0], metadata["bounds"][2]),
213
- np.clip(box[3], metadata["bounds"][1], metadata["bounds"][3]),
207
+ np.clip(box[0], bounds[0], bounds[2]),
208
+ np.clip(box[1], bounds[1], bounds[3]),
209
+ np.clip(box[2], bounds[0], bounds[2]),
210
+ np.clip(box[3], bounds[1], bounds[3]),
214
211
  ]
215
212
 
216
213
  # Convert to relative coordinates.
217
214
  box = [
218
- box[0] - metadata["bounds"][0],
219
- box[1] - metadata["bounds"][1],
220
- box[2] - metadata["bounds"][0],
221
- box[3] - metadata["bounds"][1],
215
+ box[0] - bounds[0],
216
+ box[1] - bounds[1],
217
+ box[2] - bounds[0],
218
+ box[3] - bounds[1],
222
219
  ]
223
220
 
224
221
  boxes.append(box)
@@ -238,16 +235,12 @@ class DetectionTask(BasicTask):
238
235
  "valid": torch.tensor(valid, dtype=torch.int32),
239
236
  "boxes": boxes,
240
237
  "labels": class_labels,
241
- "width": torch.tensor(
242
- metadata["bounds"][2] - metadata["bounds"][0], dtype=torch.float32
243
- ),
244
- "height": torch.tensor(
245
- metadata["bounds"][3] - metadata["bounds"][1], dtype=torch.float32
246
- ),
238
+ "width": torch.tensor(bounds[2] - bounds[0], dtype=torch.float32),
239
+ "height": torch.tensor(bounds[3] - bounds[1], dtype=torch.float32),
247
240
  }
248
241
 
249
242
  def process_output(
250
- self, raw_output: Any, metadata: dict[str, Any]
243
+ self, raw_output: Any, metadata: SampleMetadata
251
244
  ) -> npt.NDArray[Any] | list[Feature]:
252
245
  """Processes an output into raster or vector data.
253
246
 
@@ -267,12 +260,12 @@ class DetectionTask(BasicTask):
267
260
  features = []
268
261
  for box, class_id, score in zip(boxes, class_ids, scores):
269
262
  shp = shapely.box(
270
- metadata["bounds"][0] + float(box[0]),
271
- metadata["bounds"][1] + float(box[1]),
272
- metadata["bounds"][0] + float(box[2]),
273
- metadata["bounds"][1] + float(box[3]),
263
+ metadata.patch_bounds[0] + float(box[0]),
264
+ metadata.patch_bounds[1] + float(box[1]),
265
+ metadata.patch_bounds[0] + float(box[2]),
266
+ metadata.patch_bounds[1] + float(box[3]),
274
267
  )
275
- geom = STGeometry(metadata["projection"], shp, None)
268
+ geom = STGeometry(metadata.projection, shp, None)
276
269
  properties: dict[str, Any] = {
277
270
  "score": float(score),
278
271
  }
@@ -6,6 +6,8 @@ import numpy.typing as npt
6
6
  import torch
7
7
  from torchmetrics import MetricCollection
8
8
 
9
+ from rslearn.models.component import FeatureMaps
10
+ from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
9
11
  from rslearn.utils import Feature
10
12
 
11
13
  from .task import Task
@@ -21,7 +23,7 @@ class EmbeddingTask(Task):
21
23
  def process_inputs(
22
24
  self,
23
25
  raw_inputs: dict[str, torch.Tensor],
24
- metadata: dict[str, Any],
26
+ metadata: SampleMetadata,
25
27
  load_targets: bool = True,
26
28
  ) -> tuple[dict[str, Any], dict[str, Any]]:
27
29
  """Processes the data into targets.
@@ -38,17 +40,22 @@ class EmbeddingTask(Task):
38
40
  return {}, {}
39
41
 
40
42
  def process_output(
41
- self, raw_output: Any, metadata: dict[str, Any]
43
+ self, raw_output: Any, metadata: SampleMetadata
42
44
  ) -> npt.NDArray[Any] | list[Feature]:
43
45
  """Processes an output into raster or vector data.
44
46
 
45
47
  Args:
46
- raw_output: the output from prediction head.
48
+ raw_output: the output from prediction head, which must be a CxHxW tensor.
47
49
  metadata: metadata about the patch being read
48
50
 
49
51
  Returns:
50
52
  either raster or vector data.
51
53
  """
54
+ if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 3:
55
+ raise ValueError(
56
+ "output for EmbeddingTask must be a tensor with three dimensions"
57
+ )
58
+
52
59
  # Just convert the raw output to numpy array that can be saved to GeoTIFF.
53
60
  return raw_output.cpu().numpy()
54
61
 
@@ -76,41 +83,38 @@ class EmbeddingTask(Task):
76
83
  return MetricCollection({})
77
84
 
78
85
 
79
- class EmbeddingHead(torch.nn.Module):
86
+ class EmbeddingHead:
80
87
  """Head for embedding task.
81
88
 
82
- This picks one feature map from the input list of feature maps to output. It also
83
- returns a dummy loss.
89
+ It just adds a dummy loss to act as a Predictor.
84
90
  """
85
91
 
86
- def __init__(self, feature_map_index: int | None = 0):
87
- """Create a new EmbeddingHead.
88
-
89
- Args:
90
- feature_map_index: the index of the feature map to choose from the input
91
- list of multi-scale feature maps (default 0). If the input is already
92
- a single feature map, then set to None.
93
- """
94
- super().__init__()
95
- self.feature_map_index = feature_map_index
96
-
97
92
  def forward(
98
93
  self,
99
- features: torch.Tensor,
100
- inputs: list[dict[str, Any]],
94
+ intermediates: Any,
95
+ context: ModelContext,
101
96
  targets: list[dict[str, Any]] | None = None,
102
- ) -> tuple[torch.Tensor, dict[str, Any]]:
103
- """Select the desired feature map and return it along with a dummy loss.
97
+ ) -> ModelOutput:
98
+ """Return the feature map along with a dummy loss.
104
99
 
105
100
  Args:
106
- features: list of BCHW feature maps (or one feature map, if feature_map_index is None).
107
- inputs: original inputs (ignored).
108
- targets: should contain classes key that stores the per-pixel class labels.
101
+ intermediates: output from the previous model component, which must be a
102
+ FeatureMaps consisting of a single feature map.
103
+ context: the model context.
104
+ targets: the targets (ignored).
109
105
 
110
106
  Returns:
111
- tuple of outputs and loss dict
107
+ model output with the feature map that was input to this component along
108
+ with a dummy loss.
112
109
  """
113
- if self.feature_map_index is not None:
114
- features = features[self.feature_map_index]
115
-
116
- return features, {"loss": 0}
110
+ if not isinstance(intermediates, FeatureMaps):
111
+ raise ValueError("input to EmbeddingHead must be a FeatureMaps")
112
+ if len(intermediates.feature_maps) != 1:
113
+ raise ValueError(
114
+ f"input to EmbeddingHead must have one feature map, but got {len(intermediates.feature_maps)}"
115
+ )
116
+
117
+ return ModelOutput(
118
+ outputs=intermediates.feature_maps[0],
119
+ loss_dict={"loss": 0},
120
+ )
@@ -6,6 +6,7 @@ import numpy.typing as npt
6
6
  import torch
7
7
  from torchmetrics import Metric, MetricCollection
8
8
 
9
+ from rslearn.train.model_context import SampleMetadata
9
10
  from rslearn.utils import Feature
10
11
 
11
12
  from .task import Task
@@ -30,7 +31,7 @@ class MultiTask(Task):
30
31
  def process_inputs(
31
32
  self,
32
33
  raw_inputs: dict[str, torch.Tensor | list[Feature]],
33
- metadata: dict[str, Any],
34
+ metadata: SampleMetadata,
34
35
  load_targets: bool = True,
35
36
  ) -> tuple[dict[str, Any], dict[str, Any]]:
36
37
  """Processes the data into targets.
@@ -46,14 +47,12 @@ class MultiTask(Task):
46
47
  """
47
48
  input_dict = {}
48
49
  target_dict = {}
49
- if metadata["dataset_source"] is None:
50
+ if metadata.dataset_source is None:
50
51
  # No multi-dataset, so always compute across all tasks
51
52
  task_iter = list(self.tasks.items())
52
53
  else:
53
54
  # Multi-dataset, so only compute for the task in this dataset
54
- task_iter = [
55
- (metadata["dataset_source"], self.tasks[metadata["dataset_source"]])
56
- ]
55
+ task_iter = [(metadata.dataset_source, self.tasks[metadata.dataset_source])]
57
56
 
58
57
  for task_name, task in task_iter:
59
58
  cur_raw_inputs = {}
@@ -71,12 +70,13 @@ class MultiTask(Task):
71
70
  return input_dict, target_dict
72
71
 
73
72
  def process_output(
74
- self, raw_output: Any, metadata: dict[str, Any]
73
+ self, raw_output: Any, metadata: SampleMetadata
75
74
  ) -> dict[str, Any]:
76
75
  """Processes an output into raster or vector data.
77
76
 
78
77
  Args:
79
- raw_output: the output from prediction head.
78
+ raw_output: the output from prediction head. It must be a dict mapping from
79
+ task name to per-task output for this sample.
80
80
  metadata: metadata about the patch being read
81
81
 
82
82
  Returns:
@@ -8,6 +8,8 @@ import torch
8
8
  import torchmetrics
9
9
  from torchmetrics import Metric, MetricCollection
10
10
 
11
+ from rslearn.models.component import FeatureMaps, Predictor
12
+ from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
11
13
  from rslearn.utils.feature import Feature
12
14
 
13
15
  from .task import BasicTask
@@ -41,7 +43,7 @@ class PerPixelRegressionTask(BasicTask):
41
43
  def process_inputs(
42
44
  self,
43
45
  raw_inputs: dict[str, torch.Tensor],
44
- metadata: dict[str, Any],
46
+ metadata: SampleMetadata,
45
47
  load_targets: bool = True,
46
48
  ) -> tuple[dict[str, Any], dict[str, Any]]:
47
49
  """Processes the data into targets.
@@ -72,20 +74,23 @@ class PerPixelRegressionTask(BasicTask):
72
74
  }
73
75
 
74
76
  def process_output(
75
- self, raw_output: Any, metadata: dict[str, Any]
77
+ self, raw_output: Any, metadata: SampleMetadata
76
78
  ) -> npt.NDArray[Any] | list[Feature]:
77
79
  """Processes an output into raster or vector data.
78
80
 
79
81
  Args:
80
- raw_output: the output from prediction head.
82
+ raw_output: the output from prediction head, which must be an HW tensor.
81
83
  metadata: metadata about the patch being read
82
84
 
83
85
  Returns:
84
86
  either raster or vector data.
85
87
  """
86
- # Input could be CHW (with single channel) or just HW.
87
- if len(raw_output.shape) == 2:
88
- raw_output = raw_output[None, :, :]
88
+ if not isinstance(raw_output, torch.Tensor):
89
+ raise ValueError("output for PerPixelRegressionTask must be a tensor")
90
+ if len(raw_output.shape) != 2:
91
+ raise ValueError(
92
+ f"PerPixelRegressionTask output must be an HW tensor, but got shape {raw_output.shape}"
93
+ )
89
94
  return (raw_output / self.scale_factor).cpu().numpy()
90
95
 
91
96
  def visualize(
@@ -133,7 +138,7 @@ class PerPixelRegressionTask(BasicTask):
133
138
  return MetricCollection(metric_dict)
134
139
 
135
140
 
136
- class PerPixelRegressionHead(torch.nn.Module):
141
+ class PerPixelRegressionHead(Predictor):
137
142
  """Head for per-pixel regression task."""
138
143
 
139
144
  def __init__(
@@ -156,24 +161,38 @@ class PerPixelRegressionHead(torch.nn.Module):
156
161
 
157
162
  def forward(
158
163
  self,
159
- logits: torch.Tensor,
160
- inputs: list[dict[str, Any]],
164
+ intermediates: Any,
165
+ context: ModelContext,
161
166
  targets: list[dict[str, Any]] | None = None,
162
- ) -> tuple[torch.Tensor, dict[str, Any]]:
167
+ ) -> ModelOutput:
163
168
  """Compute the regression outputs and loss from logits and targets.
164
169
 
165
170
  Args:
166
- logits: BxHxW or BxCxHxW tensor.
167
- inputs: original inputs (ignored).
168
- targets: should contain target key that stores the regression labels.
171
+ intermediates: output from previous component, which must be a FeatureMaps
172
+ with one feature map corresponding to the logits. The channel dimension
173
+ size must be 1.
174
+ context: the model context.
175
+ targets: must contain values key that stores the regression labels, and
176
+ valid key containing mask image indicating where the labels are valid.
169
177
 
170
178
  Returns:
171
- tuple of outputs and loss dict
179
+ tuple of outputs and loss dict. The output is a BHW tensor so that the
180
+ per-sample output is an HW tensor.
172
181
  """
173
- assert len(logits.shape) in [3, 4]
174
- if len(logits.shape) == 4:
175
- assert logits.shape[1] == 1
176
- logits = logits[:, 0, :, :]
182
+ if not isinstance(intermediates, FeatureMaps):
183
+ raise ValueError(
184
+ "the input to PerPixelRegressionHead must be a FeatureMaps"
185
+ )
186
+ if len(intermediates.feature_maps) != 1:
187
+ raise ValueError(
188
+ "the input to PerPixelRegressionHead must have one feature map"
189
+ )
190
+ if intermediates.feature_maps[0].shape[1] != 1:
191
+ raise ValueError(
192
+ f"the input to PerPixelRegressionHead must have channel dimension size 1, but got {intermediates.feature_maps[0].shape}"
193
+ )
194
+
195
+ logits = intermediates.feature_maps[0][:, 0, :, :]
177
196
 
178
197
  if self.use_sigmoid:
179
198
  outputs = torch.nn.functional.sigmoid(logits)
@@ -200,7 +219,10 @@ class PerPixelRegressionHead(torch.nn.Module):
200
219
  else:
201
220
  losses["regress"] = (scores * mask).sum() / mask_total
202
221
 
203
- return outputs, losses
222
+ return ModelOutput(
223
+ outputs=outputs,
224
+ loss_dict=losses,
225
+ )
204
226
 
205
227
 
206
228
  class PerPixelRegressionMetricWrapper(Metric):
@@ -10,6 +10,8 @@ import torchmetrics
10
10
  from PIL import Image, ImageDraw
11
11
  from torchmetrics import Metric, MetricCollection
12
12
 
13
+ from rslearn.models.component import FeatureVector, Predictor
14
+ from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
13
15
  from rslearn.utils.feature import Feature
14
16
  from rslearn.utils.geometry import STGeometry
15
17
 
@@ -62,7 +64,7 @@ class RegressionTask(BasicTask):
62
64
  def process_inputs(
63
65
  self,
64
66
  raw_inputs: dict[str, torch.Tensor | list[Feature]],
65
- metadata: dict[str, Any],
67
+ metadata: SampleMetadata,
66
68
  load_targets: bool = True,
67
69
  ) -> tuple[dict[str, Any], dict[str, Any]]:
68
70
  """Processes the data into targets.
@@ -103,22 +105,26 @@ class RegressionTask(BasicTask):
103
105
  }
104
106
 
105
107
  def process_output(
106
- self, raw_output: Any, metadata: dict[str, Any]
107
- ) -> npt.NDArray[Any] | list[Feature]:
108
+ self, raw_output: Any, metadata: SampleMetadata
109
+ ) -> list[Feature]:
108
110
  """Processes an output into raster or vector data.
109
111
 
110
112
  Args:
111
- raw_output: the output from prediction head.
113
+ raw_output: the output from prediction head, which must be a scalar tensor.
112
114
  metadata: metadata about the patch being read
113
115
 
114
116
  Returns:
115
- either raster or vector data.
117
+ a list with a single Feature corresponding to the patch extent and with a
118
+ property containing the predicted value.
116
119
  """
120
+ if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 0:
121
+ raise ValueError("output for RegressionTask must be a scalar Tensor")
122
+
117
123
  output = raw_output.item() / self.scale_factor
118
124
  feature = Feature(
119
125
  STGeometry(
120
- metadata["projection"],
121
- shapely.Point(metadata["bounds"][0], metadata["bounds"][1]),
126
+ metadata.projection,
127
+ shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1]),
122
128
  None,
123
129
  ),
124
130
  {
@@ -180,7 +186,7 @@ class RegressionTask(BasicTask):
180
186
  return MetricCollection(metric_dict)
181
187
 
182
188
 
183
- class RegressionHead(torch.nn.Module):
189
+ class RegressionHead(Predictor):
184
190
  """Head for regression task."""
185
191
 
186
192
  def __init__(
@@ -199,24 +205,32 @@ class RegressionHead(torch.nn.Module):
199
205
 
200
206
  def forward(
201
207
  self,
202
- logits: torch.Tensor,
203
- inputs: list[dict[str, Any]],
208
+ intermediates: Any,
209
+ context: ModelContext,
204
210
  targets: list[dict[str, Any]] | None = None,
205
- ) -> tuple[torch.Tensor, dict[str, Any]]:
211
+ ) -> ModelOutput:
206
212
  """Compute the regression outputs and loss from logits and targets.
207
213
 
208
214
  Args:
209
- logits: tensor that is (BatchSize, 1) or (BatchSize) in shape.
210
- inputs: original inputs (ignored).
211
- targets: should contain target key that stores the regression label.
215
+ intermediates: output from previous model component, which must be a
216
+ FeatureVector with channel dimension size 1 (Bx1).
217
+ context: the model context.
218
+ targets: target dicts, which each must contain a "value" key containing the
219
+ regression label, along with a "valid" key containing a flag indicating
220
+ whether each example is valid for this task.
212
221
 
213
222
  Returns:
214
- tuple of outputs and loss dict
223
+ the model outputs. The output is a B tensor so that it is split up into a
224
+ scalar for each example.
215
225
  """
216
- assert len(logits.shape) in [1, 2]
217
- if len(logits.shape) == 2:
218
- assert logits.shape[1] == 1
219
- logits = logits[:, 0]
226
+ if not isinstance(intermediates, FeatureVector):
227
+ raise ValueError("the input to RegressionHead must be a FeatureVector")
228
+ if intermediates.feature_vector.shape[1] != 1:
229
+ raise ValueError(
230
+ f"the input to RegressionHead must have channel dimension size 1, but got shape {intermediates.feature_vector.shape}"
231
+ )
232
+
233
+ logits = intermediates.feature_vector[:, 0]
220
234
 
221
235
  if self.use_sigmoid:
222
236
  outputs = torch.nn.functional.sigmoid(logits)
@@ -232,9 +246,12 @@ class RegressionHead(torch.nn.Module):
232
246
  elif self.loss_mode == "l1":
233
247
  losses["regress"] = torch.mean(torch.abs(outputs - labels) * mask)
234
248
  else:
235
- assert False
249
+ raise ValueError(f"unknown loss mode {self.loss_mode}")
236
250
 
237
- return outputs, losses
251
+ return ModelOutput(
252
+ outputs=outputs,
253
+ loss_dict=losses,
254
+ )
238
255
 
239
256
 
240
257
  class RegressionMetricWrapper(Metric):
@@ -8,7 +8,8 @@ import torch
8
8
  import torchmetrics.classification
9
9
  from torchmetrics import Metric, MetricCollection
10
10
 
11
- from rslearn.utils import Feature
11
+ from rslearn.models.component import FeatureMaps, Predictor
12
+ from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
12
13
 
13
14
  from .task import BasicTask
14
15
 
@@ -108,7 +109,7 @@ class SegmentationTask(BasicTask):
108
109
  def process_inputs(
109
110
  self,
110
111
  raw_inputs: dict[str, torch.Tensor],
111
- metadata: dict[str, Any],
112
+ metadata: SampleMetadata,
112
113
  load_targets: bool = True,
113
114
  ) -> tuple[dict[str, Any], dict[str, Any]]:
114
115
  """Processes the data into targets.
@@ -148,17 +149,20 @@ class SegmentationTask(BasicTask):
148
149
  }
149
150
 
150
151
  def process_output(
151
- self, raw_output: Any, metadata: dict[str, Any]
152
- ) -> npt.NDArray[Any] | list[Feature]:
152
+ self, raw_output: Any, metadata: SampleMetadata
153
+ ) -> npt.NDArray[Any]:
153
154
  """Processes an output into raster or vector data.
154
155
 
155
156
  Args:
156
- raw_output: the output from prediction head.
157
+ raw_output: the output from prediction head, which must be a CHW tensor.
157
158
  metadata: metadata about the patch being read
158
159
 
159
160
  Returns:
160
- either raster or vector data.
161
+ CHW numpy array with one channel, containing the predicted class IDs.
161
162
  """
163
+ if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 3:
164
+ raise ValueError("the output for SegmentationTask must be a CHW tensor")
165
+
162
166
  if self.prob_scales is not None:
163
167
  raw_output = (
164
168
  raw_output
@@ -166,7 +170,7 @@ class SegmentationTask(BasicTask):
166
170
  self.prob_scales, device=raw_output.device, dtype=raw_output.dtype
167
171
  )[:, None, None]
168
172
  )
169
- classes = raw_output.argmax(dim=0).cpu().numpy().astype(np.uint8)
173
+ classes = raw_output.argmax(dim=0).cpu().numpy()
170
174
  return classes[None, :, :]
171
175
 
172
176
  def visualize(
@@ -258,25 +262,36 @@ class SegmentationTask(BasicTask):
258
262
  return MetricCollection(metrics)
259
263
 
260
264
 
261
- class SegmentationHead(torch.nn.Module):
265
+ class SegmentationHead(Predictor):
262
266
  """Head for segmentation task."""
263
267
 
264
268
  def forward(
265
269
  self,
266
- logits: torch.Tensor,
267
- inputs: list[dict[str, Any]],
270
+ intermediates: Any,
271
+ context: ModelContext,
268
272
  targets: list[dict[str, Any]] | None = None,
269
- ) -> tuple[torch.Tensor, dict[str, Any]]:
273
+ ) -> ModelOutput:
270
274
  """Compute the segmentation outputs from logits and targets.
271
275
 
272
276
  Args:
273
- logits: tensor that is (BatchSize, NumClasses, Height, Width) in shape.
274
- inputs: original inputs (ignored).
275
- targets: should contain classes key that stores the per-pixel class labels.
277
+ intermediates: a FeatureMaps with a single feature map containing the
278
+ segmentation logits.
279
+ context: the model context
280
+ targets: list of target dicts, where each target dict must contain a key
281
+ "classes" containing the per-pixel class labels, along with "valid"
282
+ containing a mask indicating where the example is valid.
276
283
 
277
284
  Returns:
278
285
  tuple of outputs and loss dict
279
286
  """
287
+ if not isinstance(intermediates, FeatureMaps):
288
+ raise ValueError("input to SegmentationHead must be a FeatureMaps")
289
+ if len(intermediates.feature_maps) != 1:
290
+ raise ValueError(
291
+ f"input to SegmentationHead must have one feature map, but got {len(intermediates.feature_maps)}"
292
+ )
293
+
294
+ logits = intermediates.feature_maps[0]
280
295
  outputs = torch.nn.functional.softmax(logits, dim=1)
281
296
 
282
297
  losses = {}
@@ -295,7 +310,10 @@ class SegmentationHead(torch.nn.Module):
295
310
  # the summed mask loss be zero.
296
311
  losses["cls"] = torch.sum(per_pixel_loss * mask)
297
312
 
298
- return outputs, losses
313
+ return ModelOutput(
314
+ outputs=outputs,
315
+ loss_dict=losses,
316
+ )
299
317
 
300
318
 
301
319
  class SegmentationMetric(Metric):
@@ -7,6 +7,7 @@ import numpy.typing as npt
7
7
  import torch
8
8
  from torchmetrics import MetricCollection
9
9
 
10
+ from rslearn.train.model_context import SampleMetadata
10
11
  from rslearn.utils import Feature
11
12
 
12
13
 
@@ -21,7 +22,7 @@ class Task:
21
22
  def process_inputs(
22
23
  self,
23
24
  raw_inputs: dict[str, torch.Tensor | list[Feature]],
24
- metadata: dict[str, Any],
25
+ metadata: SampleMetadata,
25
26
  load_targets: bool = True,
26
27
  ) -> tuple[dict[str, Any], dict[str, Any]]:
27
28
  """Processes the data into targets.
@@ -38,7 +39,7 @@ class Task:
38
39
  raise NotImplementedError
39
40
 
40
41
  def process_output(
41
- self, raw_output: Any, metadata: dict[str, Any]
42
+ self, raw_output: Any, metadata: SampleMetadata
42
43
  ) -> npt.NDArray[Any] | list[Feature] | dict[str, Any]:
43
44
  """Processes an output into raster or vector data.
44
45