rslearn 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. rslearn/arg_parser.py +2 -9
  2. rslearn/config/dataset.py +15 -16
  3. rslearn/dataset/dataset.py +28 -22
  4. rslearn/lightning_cli.py +22 -11
  5. rslearn/main.py +1 -1
  6. rslearn/models/anysat.py +35 -33
  7. rslearn/models/attention_pooling.py +177 -0
  8. rslearn/models/clip.py +5 -2
  9. rslearn/models/component.py +12 -0
  10. rslearn/models/croma.py +11 -3
  11. rslearn/models/dinov3.py +2 -1
  12. rslearn/models/faster_rcnn.py +2 -1
  13. rslearn/models/galileo/galileo.py +58 -31
  14. rslearn/models/module_wrapper.py +6 -1
  15. rslearn/models/molmo.py +4 -2
  16. rslearn/models/olmoearth_pretrain/model.py +206 -51
  17. rslearn/models/olmoearth_pretrain/norm.py +5 -3
  18. rslearn/models/panopticon.py +3 -1
  19. rslearn/models/presto/presto.py +45 -15
  20. rslearn/models/prithvi.py +9 -7
  21. rslearn/models/sam2_enc.py +3 -1
  22. rslearn/models/satlaspretrain.py +4 -1
  23. rslearn/models/simple_time_series.py +43 -17
  24. rslearn/models/ssl4eo_s12.py +19 -14
  25. rslearn/models/swin.py +3 -1
  26. rslearn/models/terramind.py +5 -4
  27. rslearn/train/all_patches_dataset.py +96 -28
  28. rslearn/train/dataset.py +102 -53
  29. rslearn/train/model_context.py +35 -1
  30. rslearn/train/scheduler.py +15 -0
  31. rslearn/train/tasks/classification.py +8 -2
  32. rslearn/train/tasks/detection.py +3 -2
  33. rslearn/train/tasks/multi_task.py +2 -3
  34. rslearn/train/tasks/per_pixel_regression.py +14 -5
  35. rslearn/train/tasks/regression.py +8 -2
  36. rslearn/train/tasks/segmentation.py +13 -4
  37. rslearn/train/tasks/task.py +2 -2
  38. rslearn/train/transforms/concatenate.py +45 -5
  39. rslearn/train/transforms/crop.py +22 -8
  40. rslearn/train/transforms/flip.py +13 -5
  41. rslearn/train/transforms/mask.py +11 -2
  42. rslearn/train/transforms/normalize.py +46 -15
  43. rslearn/train/transforms/pad.py +15 -3
  44. rslearn/train/transforms/resize.py +83 -0
  45. rslearn/train/transforms/select_bands.py +11 -2
  46. rslearn/train/transforms/sentinel1.py +18 -3
  47. rslearn/utils/geometry.py +73 -0
  48. rslearn/utils/jsonargparse.py +66 -0
  49. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/METADATA +1 -1
  50. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/RECORD +55 -53
  51. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/WHEEL +0 -0
  52. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/entry_points.txt +0 -0
  53. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/licenses/LICENSE +0 -0
  54. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/licenses/NOTICE +0 -0
  55. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/top_level.txt +0 -0
rslearn/train/dataset.py CHANGED
@@ -8,6 +8,7 @@ import random
8
8
  import tempfile
9
9
  import time
10
10
  import uuid
11
+ from datetime import datetime
11
12
  from typing import Any
12
13
 
13
14
  import torch
@@ -19,12 +20,18 @@ from rslearn.config import (
19
20
  DType,
20
21
  LayerConfig,
21
22
  )
23
+ from rslearn.data_sources.data_source import Item
22
24
  from rslearn.dataset.dataset import Dataset
23
25
  from rslearn.dataset.storage.file import FileWindowStorage
24
- from rslearn.dataset.window import Window, get_layer_and_group_from_dir_name
26
+ from rslearn.dataset.window import (
27
+ Window,
28
+ WindowLayerData,
29
+ get_layer_and_group_from_dir_name,
30
+ )
25
31
  from rslearn.log_utils import get_logger
32
+ from rslearn.train.model_context import RasterImage
26
33
  from rslearn.utils.feature import Feature
27
- from rslearn.utils.geometry import PixelBounds
34
+ from rslearn.utils.geometry import PixelBounds, ResolutionFactor
28
35
  from rslearn.utils.mp import star_imap_unordered
29
36
 
30
37
  from .model_context import SampleMetadata
@@ -130,6 +137,10 @@ class DataInput:
130
137
  """Specification of a piece of data from a window that is needed for training.
131
138
 
132
139
  The DataInput includes which layer(s) the data can be obtained from for each window.
140
+
141
+ Note that this class is not a dataclass because jsonargparse does not play well
142
+ with dataclasses without enabling specialized options which we have not validated
143
+ will work with the rest of our code.
133
144
  """
134
145
 
135
146
  def __init__(
@@ -143,7 +154,9 @@ class DataInput:
143
154
  dtype: DType = DType.FLOAT32,
144
155
  load_all_layers: bool = False,
145
156
  load_all_item_groups: bool = False,
146
- ) -> None:
157
+ resolution_factor: ResolutionFactor = ResolutionFactor(),
158
+ resampling: Resampling = Resampling.nearest,
159
+ ):
147
160
  """Initialize a new DataInput.
148
161
 
149
162
  Args:
@@ -166,6 +179,11 @@ class DataInput:
166
179
  are reading from. By default, we assume the specified layer name is of
167
180
  the form "{layer_name}.{group_idx}" and read that item group only. With
168
181
  this option enabled, we ignore the group_idx and read all item groups.
182
+ resolution_factor: controls the resolution at which raster data is loaded for training.
183
+ By default (factor=1), data is loaded at the window resolution.
184
+ E.g. for a 64x64 window at 10 m/pixel with resolution_factor=1/2,
185
+ the resulting tensor is 32x32 (covering the same geographic area at 20 m/pixel).
186
+ resampling: resampling method (default nearest neighbor).
169
187
  """
170
188
  self.data_type = data_type
171
189
  self.layers = layers
@@ -176,6 +194,8 @@ class DataInput:
176
194
  self.dtype = dtype
177
195
  self.load_all_layers = load_all_layers
178
196
  self.load_all_item_groups = load_all_item_groups
197
+ self.resolution_factor = resolution_factor
198
+ self.resampling = resampling
179
199
 
180
200
 
181
201
  def read_raster_layer_for_data_input(
@@ -185,7 +205,8 @@ def read_raster_layer_for_data_input(
185
205
  group_idx: int,
186
206
  layer_config: LayerConfig,
187
207
  data_input: DataInput,
188
- ) -> torch.Tensor:
208
+ layer_data: WindowLayerData | None,
209
+ ) -> tuple[torch.Tensor, tuple[datetime, datetime] | None]:
189
210
  """Read a raster layer for a DataInput.
190
211
 
191
212
  This scans the available rasters for the layer at the window to determine which
@@ -198,9 +219,11 @@ def read_raster_layer_for_data_input(
198
219
  group_idx: the item group.
199
220
  layer_config: the layer configuration.
200
221
  data_input: the DataInput that specifies the bands and dtype.
222
+ layer_data: the WindowLayerData associated with this layer and window.
201
223
 
202
224
  Returns:
203
- tensor containing raster data.
225
+ tuple of the tensor containing raster data and the time range associated
226
+ with that data (or None if no time range is available).
204
227
  """
205
228
  # See what different sets of bands we need to read to get all the
206
229
  # configured bands.
@@ -233,15 +256,23 @@ def read_raster_layer_for_data_input(
233
256
  + f"window {window.name} layer {layer_name} group {group_idx}"
234
257
  )
235
258
 
259
+ # Get the projection and bounds to read under (multiply window resolution
260
+ # by the specified resolution factor).
261
+ final_projection = data_input.resolution_factor.multiply_projection(
262
+ window.projection
263
+ )
264
+ final_bounds = data_input.resolution_factor.multiply_bounds(bounds)
265
+
236
266
  image = torch.zeros(
237
- (len(needed_bands), bounds[3] - bounds[1], bounds[2] - bounds[0]),
267
+ (
268
+ len(needed_bands),
269
+ final_bounds[3] - final_bounds[1],
270
+ final_bounds[2] - final_bounds[0],
271
+ ),
238
272
  dtype=get_torch_dtype(data_input.dtype),
239
273
  )
240
274
 
241
275
  for band_set, src_indexes, dst_indexes in needed_sets_and_indexes:
242
- final_projection, final_bounds = band_set.get_final_projection_and_bounds(
243
- window.projection, bounds
244
- )
245
276
  if band_set.format is None:
246
277
  raise ValueError(f"No format specified for {layer_name}")
247
278
  raster_format = band_set.instantiate_raster_format()
@@ -249,49 +280,48 @@ def read_raster_layer_for_data_input(
249
280
  layer_name, band_set.bands, group_idx=group_idx
250
281
  )
251
282
 
252
- # Previously we always read in the native projection of the data, and then
253
- # zoom in or out (the resolution must be a power of two off) to match the
254
- # window's resolution.
255
- # However, this fails if the bounds are not multiples of the resolution factor.
256
- # So we fallback to reading directly in the window projection if that is the
257
- # case (which may be a bit slower).
258
- is_bounds_zoomable = True
259
- if band_set.zoom_offset < 0:
260
- zoom_factor = 2 ** (-band_set.zoom_offset)
261
- is_bounds_zoomable = (final_bounds[2] - final_bounds[0]) * zoom_factor == (
262
- bounds[2] - bounds[0]
263
- ) and (final_bounds[3] - final_bounds[1]) * zoom_factor == (
264
- bounds[3] - bounds[1]
265
- )
266
-
267
- if is_bounds_zoomable:
268
- src = raster_format.decode_raster(
269
- raster_dir, final_projection, final_bounds
270
- )
271
-
272
- # Resize to patch size if needed.
273
- # This is for band sets that are stored at a lower resolution.
274
- # Here we assume that it is a multiple.
275
- if src.shape[1:3] != image.shape[1:3]:
276
- if src.shape[1] < image.shape[1]:
277
- factor = image.shape[1] // src.shape[1]
278
- src = src.repeat(repeats=factor, axis=1).repeat(
279
- repeats=factor, axis=2
280
- )
281
- else:
282
- factor = src.shape[1] // image.shape[1]
283
- src = src[:, ::factor, ::factor]
284
-
285
- else:
286
- src = raster_format.decode_raster(
287
- raster_dir, window.projection, bounds, resampling=Resampling.nearest
288
- )
283
+ # TODO: previously we try to read based on band_set.zoom_offset when possible,
284
+ # and handle zooming in with torch.repeat (if resampling method is nearest
285
+ # neighbor). However, we have not benchmarked whether this actually improves
286
+ # data loading speed, so for simplicity, for now we let rasterio handle the
287
+ # resampling. If it really is much faster to handle it via torch, then it may
288
+ # make sense to bring back that functionality.
289
289
 
290
+ src = raster_format.decode_raster(
291
+ raster_dir, final_projection, final_bounds, resampling=Resampling.nearest
292
+ )
290
293
  image[dst_indexes, :, :] = torch.as_tensor(
291
294
  src[src_indexes, :, :].astype(data_input.dtype.get_numpy_dtype())
292
295
  )
293
296
 
294
- return image
297
+ # add the timestamp. this is a tuple defining the start and end of the time range.
298
+ time_range = None
299
+ if layer_data is not None:
300
+ item = Item.deserialize(layer_data.serialized_item_groups[group_idx][0])
301
+ if item.geometry.time_range is not None:
302
+ # we assume if one layer data has a geometry & time range, all of them do
303
+ time_ranges = [
304
+ (
305
+ datetime.fromisoformat(
306
+ Item.deserialize(
307
+ layer_data.serialized_item_groups[group_idx][idx]
308
+ ).geometry.time_range[0] # type: ignore
309
+ ),
310
+ datetime.fromisoformat(
311
+ Item.deserialize(
312
+ layer_data.serialized_item_groups[group_idx][idx]
313
+ ).geometry.time_range[1] # type: ignore
314
+ ),
315
+ )
316
+ for idx in range(len(layer_data.serialized_item_groups[group_idx]))
317
+ ]
318
+ # take the min and max
319
+ time_range = (
320
+ min([t[0] for t in time_ranges]),
321
+ max([t[1] for t in time_ranges]),
322
+ )
323
+
324
+ return image, time_range
295
325
 
296
326
 
297
327
  def read_data_input(
@@ -300,7 +330,7 @@ def read_data_input(
300
330
  bounds: PixelBounds,
301
331
  data_input: DataInput,
302
332
  rng: random.Random,
303
- ) -> torch.Tensor | list[Feature]:
333
+ ) -> RasterImage | list[Feature]:
304
334
  """Read the data specified by the DataInput from the window.
305
335
 
306
336
  Args:
@@ -342,15 +372,34 @@ def read_data_input(
342
372
  layers_to_read = [rng.choice(layer_options)]
343
373
 
344
374
  if data_input.data_type == "raster":
375
+ # load it once here
376
+ layer_datas = window.load_layer_datas()
345
377
  images: list[torch.Tensor] = []
378
+ time_ranges: list[tuple[datetime, datetime] | None] = []
346
379
  for layer_name, group_idx in layers_to_read:
347
380
  layer_config = dataset.layers[layer_name]
348
- images.append(
349
- read_raster_layer_for_data_input(
350
- window, bounds, layer_name, group_idx, layer_config, data_input
351
- )
381
+ image, time_range = read_raster_layer_for_data_input(
382
+ window,
383
+ bounds,
384
+ layer_name,
385
+ group_idx,
386
+ layer_config,
387
+ data_input,
388
+ # some layers (e.g. "label_raster") won't have associated
389
+ # layer datas
390
+ layer_datas[layer_name] if layer_name in layer_datas else None,
352
391
  )
353
- return torch.cat(images, dim=0)
392
+ if len(time_ranges) > 0:
393
+ if type(time_ranges[-1]) is not type(time_range):
394
+ raise ValueError(
395
+ f"All time ranges should be datetime tuples or None. Got {type(time_range)} and {type(time_ranges[-1])}"
396
+ )
397
+ images.append(image)
398
+ time_ranges.append(time_range)
399
+ return RasterImage(
400
+ torch.stack(images, dim=1),
401
+ time_ranges if time_ranges[0] is not None else None, # type: ignore
402
+ )
354
403
 
355
404
  elif data_input.data_type == "vector":
356
405
  # We don't really support time series for vector data currently, we just
@@ -10,6 +10,40 @@ import torch
10
10
  from rslearn.utils.geometry import PixelBounds, Projection
11
11
 
12
12
 
13
+ @dataclass
14
+ class RasterImage:
15
+ """A raster image pairs a torch.Tensor of images with their associated timestamps."""
16
+
17
+ # image is a 4D CTHW tensor
18
+ image: torch.Tensor
19
+ # if timestamps is not None, len(timestamps) must match the T dimension of the tensor
20
+ timestamps: list[tuple[datetime, datetime]] | None = None
21
+
22
+ @property
23
+ def shape(self) -> torch.Size:
24
+ """The shape of the image."""
25
+ return self.image.shape
26
+
27
+ def dim(self) -> int:
28
+ """The dim of the image."""
29
+ return self.image.dim()
30
+
31
+ @property
32
+ def dtype(self) -> torch.dtype:
33
+ """The image dtype."""
34
+ return self.image.dtype
35
+
36
+ def single_ts_to_chw_tensor(self) -> torch.Tensor:
37
+ """Single timestep models expect single timestep inputs.
38
+
39
+ This function (1) checks this raster image only has 1 timestep and
40
+ (2) returns the tensor for that (single) timestep (going from CTHW to CHW).
41
+ """
42
+ if self.image.shape[1] != 1:
43
+ raise ValueError(f"Expected a single timestep, got {self.image.shape[1]}")
44
+ return self.image[:, 0]
45
+
46
+
13
47
  @dataclass
14
48
  class SampleMetadata:
15
49
  """Metadata pertaining to an example."""
@@ -32,7 +66,7 @@ class ModelContext:
32
66
  """Context to pass to all model components."""
33
67
 
34
68
  # One input dict per example in the batch.
35
- inputs: list[dict[str, torch.Tensor]]
69
+ inputs: list[dict[str, torch.Tensor | RasterImage]]
36
70
  # One SampleMetadata per example in the batch.
37
71
  metadatas: list[SampleMetadata]
38
72
  # Arbitrary dict that components can add to.
@@ -8,6 +8,7 @@ from torch.optim.lr_scheduler import (
8
8
  CosineAnnealingLR,
9
9
  CosineAnnealingWarmRestarts,
10
10
  LRScheduler,
11
+ MultiStepLR,
11
12
  ReduceLROnPlateau,
12
13
  )
13
14
 
@@ -50,6 +51,20 @@ class PlateauScheduler(SchedulerFactory):
50
51
  return ReduceLROnPlateau(optimizer, **self.get_kwargs())
51
52
 
52
53
 
54
+ @dataclass
55
+ class MultiStepScheduler(SchedulerFactory):
56
+ """Step learning rate scheduler."""
57
+
58
+ milestones: list[int]
59
+ gamma: float | None = None
60
+ last_epoch: int | None = None
61
+
62
+ def build(self, optimizer: Optimizer) -> LRScheduler:
63
+ """Build the MultiStepLR scheduler."""
64
+ super().build(optimizer)
65
+ return MultiStepLR(optimizer, **self.get_kwargs())
66
+
67
+
53
68
  @dataclass
54
69
  class CosineAnnealingScheduler(SchedulerFactory):
55
70
  """Cosine annealing learning rate scheduler."""
@@ -16,7 +16,12 @@ from torchmetrics.classification import (
16
16
  )
17
17
 
18
18
  from rslearn.models.component import FeatureVector, Predictor
19
- from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
19
+ from rslearn.train.model_context import (
20
+ ModelContext,
21
+ ModelOutput,
22
+ RasterImage,
23
+ SampleMetadata,
24
+ )
20
25
  from rslearn.utils import Feature, STGeometry
21
26
 
22
27
  from .task import BasicTask
@@ -99,7 +104,7 @@ class ClassificationTask(BasicTask):
99
104
 
100
105
  def process_inputs(
101
106
  self,
102
- raw_inputs: dict[str, torch.Tensor | list[Feature]],
107
+ raw_inputs: dict[str, RasterImage | list[Feature]],
103
108
  metadata: SampleMetadata,
104
109
  load_targets: bool = True,
105
110
  ) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -118,6 +123,7 @@ class ClassificationTask(BasicTask):
118
123
  return {}, {}
119
124
 
120
125
  data = raw_inputs["targets"]
126
+ assert isinstance(data, list)
121
127
  for feat in data:
122
128
  if feat.properties is None:
123
129
  continue
@@ -12,7 +12,7 @@ import torchmetrics.classification
12
12
  import torchvision
13
13
  from torchmetrics import Metric, MetricCollection
14
14
 
15
- from rslearn.train.model_context import SampleMetadata
15
+ from rslearn.train.model_context import RasterImage, SampleMetadata
16
16
  from rslearn.utils import Feature, STGeometry
17
17
 
18
18
  from .task import BasicTask
@@ -127,7 +127,7 @@ class DetectionTask(BasicTask):
127
127
 
128
128
  def process_inputs(
129
129
  self,
130
- raw_inputs: dict[str, torch.Tensor | list[Feature]],
130
+ raw_inputs: dict[str, RasterImage | list[Feature]],
131
131
  metadata: SampleMetadata,
132
132
  load_targets: bool = True,
133
133
  ) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -152,6 +152,7 @@ class DetectionTask(BasicTask):
152
152
  valid = 1
153
153
 
154
154
  data = raw_inputs["targets"]
155
+ assert isinstance(data, list)
155
156
  for feat in data:
156
157
  if feat.properties is None:
157
158
  continue
@@ -3,10 +3,9 @@
3
3
  from typing import Any
4
4
 
5
5
  import numpy.typing as npt
6
- import torch
7
6
  from torchmetrics import Metric, MetricCollection
8
7
 
9
- from rslearn.train.model_context import SampleMetadata
8
+ from rslearn.train.model_context import RasterImage, SampleMetadata
10
9
  from rslearn.utils import Feature
11
10
 
12
11
  from .task import Task
@@ -30,7 +29,7 @@ class MultiTask(Task):
30
29
 
31
30
  def process_inputs(
32
31
  self,
33
- raw_inputs: dict[str, torch.Tensor | list[Feature]],
32
+ raw_inputs: dict[str, RasterImage | list[Feature]],
34
33
  metadata: SampleMetadata,
35
34
  load_targets: bool = True,
36
35
  ) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -9,7 +9,12 @@ import torchmetrics
9
9
  from torchmetrics import Metric, MetricCollection
10
10
 
11
11
  from rslearn.models.component import FeatureMaps, Predictor
12
- from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
12
+ from rslearn.train.model_context import (
13
+ ModelContext,
14
+ ModelOutput,
15
+ RasterImage,
16
+ SampleMetadata,
17
+ )
13
18
  from rslearn.utils.feature import Feature
14
19
 
15
20
  from .task import BasicTask
@@ -42,7 +47,7 @@ class PerPixelRegressionTask(BasicTask):
42
47
 
43
48
  def process_inputs(
44
49
  self,
45
- raw_inputs: dict[str, torch.Tensor],
50
+ raw_inputs: dict[str, RasterImage | list[Feature]],
46
51
  metadata: SampleMetadata,
47
52
  load_targets: bool = True,
48
53
  ) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -60,11 +65,15 @@ class PerPixelRegressionTask(BasicTask):
60
65
  if not load_targets:
61
66
  return {}, {}
62
67
 
63
- assert raw_inputs["targets"].shape[0] == 1
64
- labels = raw_inputs["targets"][0, :, :].float() * self.scale_factor
68
+ assert isinstance(raw_inputs["targets"], RasterImage)
69
+ assert raw_inputs["targets"].image.shape[0] == 1
70
+ assert raw_inputs["targets"].image.shape[1] == 1
71
+ labels = raw_inputs["targets"].image[0, 0, :, :].float() * self.scale_factor
65
72
 
66
73
  if self.nodata_value is not None:
67
- valid = (raw_inputs["targets"][0, :, :] != self.nodata_value).float()
74
+ valid = (
75
+ raw_inputs["targets"].image[0, 0, :, :] != self.nodata_value
76
+ ).float()
68
77
  else:
69
78
  valid = torch.ones(labels.shape, dtype=torch.float32)
70
79
 
@@ -11,7 +11,12 @@ from PIL import Image, ImageDraw
11
11
  from torchmetrics import Metric, MetricCollection
12
12
 
13
13
  from rslearn.models.component import FeatureVector, Predictor
14
- from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
14
+ from rslearn.train.model_context import (
15
+ ModelContext,
16
+ ModelOutput,
17
+ RasterImage,
18
+ SampleMetadata,
19
+ )
15
20
  from rslearn.utils.feature import Feature
16
21
  from rslearn.utils.geometry import STGeometry
17
22
 
@@ -63,7 +68,7 @@ class RegressionTask(BasicTask):
63
68
 
64
69
  def process_inputs(
65
70
  self,
66
- raw_inputs: dict[str, torch.Tensor | list[Feature]],
71
+ raw_inputs: dict[str, RasterImage | list[Feature]],
67
72
  metadata: SampleMetadata,
68
73
  load_targets: bool = True,
69
74
  ) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -82,6 +87,7 @@ class RegressionTask(BasicTask):
82
87
  return {}, {}
83
88
 
84
89
  data = raw_inputs["targets"]
90
+ assert isinstance(data, list)
85
91
  for feat in data:
86
92
  if feat.properties is None or self.filters is None:
87
93
  continue
@@ -1,5 +1,6 @@
1
1
  """Segmentation task."""
2
2
 
3
+ from collections.abc import Mapping
3
4
  from typing import Any
4
5
 
5
6
  import numpy as np
@@ -9,7 +10,13 @@ import torchmetrics.classification
9
10
  from torchmetrics import Metric, MetricCollection
10
11
 
11
12
  from rslearn.models.component import FeatureMaps, Predictor
12
- from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
13
+ from rslearn.train.model_context import (
14
+ ModelContext,
15
+ ModelOutput,
16
+ RasterImage,
17
+ SampleMetadata,
18
+ )
19
+ from rslearn.utils import Feature
13
20
 
14
21
  from .task import BasicTask
15
22
 
@@ -108,7 +115,7 @@ class SegmentationTask(BasicTask):
108
115
 
109
116
  def process_inputs(
110
117
  self,
111
- raw_inputs: dict[str, torch.Tensor],
118
+ raw_inputs: Mapping[str, RasterImage | list[Feature]],
112
119
  metadata: SampleMetadata,
113
120
  load_targets: bool = True,
114
121
  ) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -126,8 +133,10 @@ class SegmentationTask(BasicTask):
126
133
  if not load_targets:
127
134
  return {}, {}
128
135
 
129
- assert raw_inputs["targets"].shape[0] == 1
130
- labels = raw_inputs["targets"][0, :, :].long()
136
+ assert isinstance(raw_inputs["targets"], RasterImage)
137
+ assert raw_inputs["targets"].image.shape[0] == 1
138
+ assert raw_inputs["targets"].image.shape[1] == 1
139
+ labels = raw_inputs["targets"].image[0, 0, :, :].long()
131
140
 
132
141
  if self.class_id_mapping is not None:
133
142
  new_labels = labels.clone()
@@ -7,7 +7,7 @@ import numpy.typing as npt
7
7
  import torch
8
8
  from torchmetrics import MetricCollection
9
9
 
10
- from rslearn.train.model_context import SampleMetadata
10
+ from rslearn.train.model_context import RasterImage, SampleMetadata
11
11
  from rslearn.utils import Feature
12
12
 
13
13
 
@@ -21,7 +21,7 @@ class Task:
21
21
 
22
22
  def process_inputs(
23
23
  self,
24
- raw_inputs: dict[str, torch.Tensor | list[Feature]],
24
+ raw_inputs: dict[str, RasterImage | list[Feature]],
25
25
  metadata: SampleMetadata,
26
26
  load_targets: bool = True,
27
27
  ) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -1,12 +1,23 @@
1
1
  """Concatenate bands across multiple image inputs."""
2
2
 
3
+ from datetime import datetime
4
+ from enum import Enum
3
5
  from typing import Any
4
6
 
5
7
  import torch
6
8
 
9
+ from rslearn.train.model_context import RasterImage
10
+
7
11
  from .transform import Transform, read_selector, write_selector
8
12
 
9
13
 
14
+ class ConcatenateDim(Enum):
15
+ """Enum for concatenation dimensions."""
16
+
17
+ CHANNEL = 0
18
+ TIME = 1
19
+
20
+
10
21
  class Concatenate(Transform):
11
22
  """Concatenate bands across multiple image inputs."""
12
23
 
@@ -14,6 +25,7 @@ class Concatenate(Transform):
14
25
  self,
15
26
  selections: dict[str, list[int]],
16
27
  output_selector: str,
28
+ concatenate_dim: ConcatenateDim | int = ConcatenateDim.TIME,
17
29
  ):
18
30
  """Initialize a new Concatenate.
19
31
 
@@ -21,10 +33,16 @@ class Concatenate(Transform):
21
33
  selections: map from selector to list of band indices in that input to
22
34
  retain, or empty list to use all bands.
23
35
  output_selector: the output selector under which to save the concatenate image.
36
+ concatenate_dim: the dimension against which to concatenate the inputs
24
37
  """
25
38
  super().__init__()
26
39
  self.selections = selections
27
40
  self.output_selector = output_selector
41
+ self.concatenate_dim = (
42
+ concatenate_dim.value
43
+ if isinstance(concatenate_dim, ConcatenateDim)
44
+ else concatenate_dim
45
+ )
28
46
 
29
47
  def forward(
30
48
  self, input_dict: dict[str, Any], target_dict: dict[str, Any]
@@ -36,14 +54,36 @@ class Concatenate(Transform):
36
54
  target_dict: the target
37
55
 
38
56
  Returns:
39
- normalized (input_dicts, target_dicts) tuple
57
+ concatenated (input_dicts, target_dicts) tuple. If one of the
58
+ specified inputs is a RasterImage, a RasterImage will be returned.
59
+ Otherwise it will be a torch.Tensor.
40
60
  """
41
61
  images = []
62
+ return_raster_image: bool = False
63
+ timestamps: list[tuple[datetime, datetime]] | None = None
42
64
  for selector, wanted_bands in self.selections.items():
43
65
  image = read_selector(input_dict, target_dict, selector)
44
- if wanted_bands:
45
- image = image[wanted_bands, :, :]
46
- images.append(image)
47
- result = torch.concatenate(images, dim=0)
66
+ if isinstance(image, torch.Tensor):
67
+ if wanted_bands:
68
+ image = image[wanted_bands, :, :]
69
+ images.append(image)
70
+ elif isinstance(image, RasterImage):
71
+ return_raster_image = True
72
+ if wanted_bands:
73
+ images.append(image.image[wanted_bands, :, :])
74
+ else:
75
+ images.append(image.image)
76
+ if timestamps is None:
77
+ if image.timestamps is not None:
78
+ # assume all concatenated modalities have the same
79
+ # number of timestamps
80
+ timestamps = image.timestamps
81
+ if return_raster_image:
82
+ result = RasterImage(
83
+ torch.concatenate(images, dim=self.concatenate_dim),
84
+ timestamps=timestamps,
85
+ )
86
+ else:
87
+ result = torch.concatenate(images, dim=self.concatenate_dim)
48
88
  write_selector(input_dict, target_dict, self.output_selector, result)
49
89
  return input_dict, target_dict