rslearn 0.0.19__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. rslearn/models/anysat.py +35 -33
  2. rslearn/models/clip.py +5 -2
  3. rslearn/models/croma.py +11 -3
  4. rslearn/models/dinov3.py +2 -1
  5. rslearn/models/faster_rcnn.py +2 -1
  6. rslearn/models/galileo/galileo.py +58 -31
  7. rslearn/models/module_wrapper.py +6 -1
  8. rslearn/models/molmo.py +4 -2
  9. rslearn/models/olmoearth_pretrain/model.py +93 -29
  10. rslearn/models/olmoearth_pretrain/norm.py +5 -3
  11. rslearn/models/panopticon.py +3 -1
  12. rslearn/models/presto/presto.py +45 -15
  13. rslearn/models/prithvi.py +9 -7
  14. rslearn/models/sam2_enc.py +3 -1
  15. rslearn/models/satlaspretrain.py +4 -1
  16. rslearn/models/simple_time_series.py +36 -16
  17. rslearn/models/ssl4eo_s12.py +19 -14
  18. rslearn/models/swin.py +3 -1
  19. rslearn/models/terramind.py +5 -4
  20. rslearn/train/all_patches_dataset.py +34 -14
  21. rslearn/train/dataset.py +66 -10
  22. rslearn/train/model_context.py +35 -1
  23. rslearn/train/tasks/classification.py +8 -2
  24. rslearn/train/tasks/detection.py +3 -2
  25. rslearn/train/tasks/multi_task.py +2 -3
  26. rslearn/train/tasks/per_pixel_regression.py +14 -5
  27. rslearn/train/tasks/regression.py +8 -2
  28. rslearn/train/tasks/segmentation.py +13 -4
  29. rslearn/train/tasks/task.py +2 -2
  30. rslearn/train/transforms/concatenate.py +45 -5
  31. rslearn/train/transforms/crop.py +22 -8
  32. rslearn/train/transforms/flip.py +13 -5
  33. rslearn/train/transforms/mask.py +11 -2
  34. rslearn/train/transforms/normalize.py +46 -15
  35. rslearn/train/transforms/pad.py +15 -3
  36. rslearn/train/transforms/resize.py +18 -9
  37. rslearn/train/transforms/select_bands.py +11 -2
  38. rslearn/train/transforms/sentinel1.py +18 -3
  39. {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/METADATA +1 -1
  40. {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/RECORD +45 -45
  41. {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/WHEEL +0 -0
  42. {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/entry_points.txt +0 -0
  43. {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/licenses/LICENSE +0 -0
  44. {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/licenses/NOTICE +0 -0
  45. {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,12 @@ from torchmetrics.classification import (
16
16
  )
17
17
 
18
18
  from rslearn.models.component import FeatureVector, Predictor
19
- from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
19
+ from rslearn.train.model_context import (
20
+ ModelContext,
21
+ ModelOutput,
22
+ RasterImage,
23
+ SampleMetadata,
24
+ )
20
25
  from rslearn.utils import Feature, STGeometry
21
26
 
22
27
  from .task import BasicTask
@@ -99,7 +104,7 @@ class ClassificationTask(BasicTask):
99
104
 
100
105
  def process_inputs(
101
106
  self,
102
- raw_inputs: dict[str, torch.Tensor | list[Feature]],
107
+ raw_inputs: dict[str, RasterImage | list[Feature]],
103
108
  metadata: SampleMetadata,
104
109
  load_targets: bool = True,
105
110
  ) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -118,6 +123,7 @@ class ClassificationTask(BasicTask):
118
123
  return {}, {}
119
124
 
120
125
  data = raw_inputs["targets"]
126
+ assert isinstance(data, list)
121
127
  for feat in data:
122
128
  if feat.properties is None:
123
129
  continue
@@ -12,7 +12,7 @@ import torchmetrics.classification
12
12
  import torchvision
13
13
  from torchmetrics import Metric, MetricCollection
14
14
 
15
- from rslearn.train.model_context import SampleMetadata
15
+ from rslearn.train.model_context import RasterImage, SampleMetadata
16
16
  from rslearn.utils import Feature, STGeometry
17
17
 
18
18
  from .task import BasicTask
@@ -127,7 +127,7 @@ class DetectionTask(BasicTask):
127
127
 
128
128
  def process_inputs(
129
129
  self,
130
- raw_inputs: dict[str, torch.Tensor | list[Feature]],
130
+ raw_inputs: dict[str, RasterImage | list[Feature]],
131
131
  metadata: SampleMetadata,
132
132
  load_targets: bool = True,
133
133
  ) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -152,6 +152,7 @@ class DetectionTask(BasicTask):
152
152
  valid = 1
153
153
 
154
154
  data = raw_inputs["targets"]
155
+ assert isinstance(data, list)
155
156
  for feat in data:
156
157
  if feat.properties is None:
157
158
  continue
@@ -3,10 +3,9 @@
3
3
  from typing import Any
4
4
 
5
5
  import numpy.typing as npt
6
- import torch
7
6
  from torchmetrics import Metric, MetricCollection
8
7
 
9
- from rslearn.train.model_context import SampleMetadata
8
+ from rslearn.train.model_context import RasterImage, SampleMetadata
10
9
  from rslearn.utils import Feature
11
10
 
12
11
  from .task import Task
@@ -30,7 +29,7 @@ class MultiTask(Task):
30
29
 
31
30
  def process_inputs(
32
31
  self,
33
- raw_inputs: dict[str, torch.Tensor | list[Feature]],
32
+ raw_inputs: dict[str, RasterImage | list[Feature]],
34
33
  metadata: SampleMetadata,
35
34
  load_targets: bool = True,
36
35
  ) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -9,7 +9,12 @@ import torchmetrics
9
9
  from torchmetrics import Metric, MetricCollection
10
10
 
11
11
  from rslearn.models.component import FeatureMaps, Predictor
12
- from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
12
+ from rslearn.train.model_context import (
13
+ ModelContext,
14
+ ModelOutput,
15
+ RasterImage,
16
+ SampleMetadata,
17
+ )
13
18
  from rslearn.utils.feature import Feature
14
19
 
15
20
  from .task import BasicTask
@@ -42,7 +47,7 @@ class PerPixelRegressionTask(BasicTask):
42
47
 
43
48
  def process_inputs(
44
49
  self,
45
- raw_inputs: dict[str, torch.Tensor],
50
+ raw_inputs: dict[str, RasterImage | list[Feature]],
46
51
  metadata: SampleMetadata,
47
52
  load_targets: bool = True,
48
53
  ) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -60,11 +65,15 @@ class PerPixelRegressionTask(BasicTask):
60
65
  if not load_targets:
61
66
  return {}, {}
62
67
 
63
- assert raw_inputs["targets"].shape[0] == 1
64
- labels = raw_inputs["targets"][0, :, :].float() * self.scale_factor
68
+ assert isinstance(raw_inputs["targets"], RasterImage)
69
+ assert raw_inputs["targets"].image.shape[0] == 1
70
+ assert raw_inputs["targets"].image.shape[1] == 1
71
+ labels = raw_inputs["targets"].image[0, 0, :, :].float() * self.scale_factor
65
72
 
66
73
  if self.nodata_value is not None:
67
- valid = (raw_inputs["targets"][0, :, :] != self.nodata_value).float()
74
+ valid = (
75
+ raw_inputs["targets"].image[0, 0, :, :] != self.nodata_value
76
+ ).float()
68
77
  else:
69
78
  valid = torch.ones(labels.shape, dtype=torch.float32)
70
79
 
@@ -11,7 +11,12 @@ from PIL import Image, ImageDraw
11
11
  from torchmetrics import Metric, MetricCollection
12
12
 
13
13
  from rslearn.models.component import FeatureVector, Predictor
14
- from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
14
+ from rslearn.train.model_context import (
15
+ ModelContext,
16
+ ModelOutput,
17
+ RasterImage,
18
+ SampleMetadata,
19
+ )
15
20
  from rslearn.utils.feature import Feature
16
21
  from rslearn.utils.geometry import STGeometry
17
22
 
@@ -63,7 +68,7 @@ class RegressionTask(BasicTask):
63
68
 
64
69
  def process_inputs(
65
70
  self,
66
- raw_inputs: dict[str, torch.Tensor | list[Feature]],
71
+ raw_inputs: dict[str, RasterImage | list[Feature]],
67
72
  metadata: SampleMetadata,
68
73
  load_targets: bool = True,
69
74
  ) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -82,6 +87,7 @@ class RegressionTask(BasicTask):
82
87
  return {}, {}
83
88
 
84
89
  data = raw_inputs["targets"]
90
+ assert isinstance(data, list)
85
91
  for feat in data:
86
92
  if feat.properties is None or self.filters is None:
87
93
  continue
@@ -1,5 +1,6 @@
1
1
  """Segmentation task."""
2
2
 
3
+ from collections.abc import Mapping
3
4
  from typing import Any
4
5
 
5
6
  import numpy as np
@@ -9,7 +10,13 @@ import torchmetrics.classification
9
10
  from torchmetrics import Metric, MetricCollection
10
11
 
11
12
  from rslearn.models.component import FeatureMaps, Predictor
12
- from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
13
+ from rslearn.train.model_context import (
14
+ ModelContext,
15
+ ModelOutput,
16
+ RasterImage,
17
+ SampleMetadata,
18
+ )
19
+ from rslearn.utils import Feature
13
20
 
14
21
  from .task import BasicTask
15
22
 
@@ -108,7 +115,7 @@ class SegmentationTask(BasicTask):
108
115
 
109
116
  def process_inputs(
110
117
  self,
111
- raw_inputs: dict[str, torch.Tensor],
118
+ raw_inputs: Mapping[str, RasterImage | list[Feature]],
112
119
  metadata: SampleMetadata,
113
120
  load_targets: bool = True,
114
121
  ) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -126,8 +133,10 @@ class SegmentationTask(BasicTask):
126
133
  if not load_targets:
127
134
  return {}, {}
128
135
 
129
- assert raw_inputs["targets"].shape[0] == 1
130
- labels = raw_inputs["targets"][0, :, :].long()
136
+ assert isinstance(raw_inputs["targets"], RasterImage)
137
+ assert raw_inputs["targets"].image.shape[0] == 1
138
+ assert raw_inputs["targets"].image.shape[1] == 1
139
+ labels = raw_inputs["targets"].image[0, 0, :, :].long()
131
140
 
132
141
  if self.class_id_mapping is not None:
133
142
  new_labels = labels.clone()
@@ -7,7 +7,7 @@ import numpy.typing as npt
7
7
  import torch
8
8
  from torchmetrics import MetricCollection
9
9
 
10
- from rslearn.train.model_context import SampleMetadata
10
+ from rslearn.train.model_context import RasterImage, SampleMetadata
11
11
  from rslearn.utils import Feature
12
12
 
13
13
 
@@ -21,7 +21,7 @@ class Task:
21
21
 
22
22
  def process_inputs(
23
23
  self,
24
- raw_inputs: dict[str, torch.Tensor | list[Feature]],
24
+ raw_inputs: dict[str, RasterImage | list[Feature]],
25
25
  metadata: SampleMetadata,
26
26
  load_targets: bool = True,
27
27
  ) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -1,12 +1,23 @@
1
1
  """Concatenate bands across multiple image inputs."""
2
2
 
3
+ from datetime import datetime
4
+ from enum import Enum
3
5
  from typing import Any
4
6
 
5
7
  import torch
6
8
 
9
+ from rslearn.train.model_context import RasterImage
10
+
7
11
  from .transform import Transform, read_selector, write_selector
8
12
 
9
13
 
14
+ class ConcatenateDim(Enum):
15
+ """Enum for concatenation dimensions."""
16
+
17
+ CHANNEL = 0
18
+ TIME = 1
19
+
20
+
10
21
  class Concatenate(Transform):
11
22
  """Concatenate bands across multiple image inputs."""
12
23
 
@@ -14,6 +25,7 @@ class Concatenate(Transform):
14
25
  self,
15
26
  selections: dict[str, list[int]],
16
27
  output_selector: str,
28
+ concatenate_dim: ConcatenateDim | int = ConcatenateDim.TIME,
17
29
  ):
18
30
  """Initialize a new Concatenate.
19
31
 
@@ -21,10 +33,16 @@ class Concatenate(Transform):
21
33
  selections: map from selector to list of band indices in that input to
22
34
  retain, or empty list to use all bands.
23
35
  output_selector: the output selector under which to save the concatenate image.
36
+ concatenate_dim: the dimension along which to concatenate the inputs, given either as a ConcatenateDim member or as a raw integer dimension index.
24
37
  """
25
38
  super().__init__()
26
39
  self.selections = selections
27
40
  self.output_selector = output_selector
41
+ self.concatenate_dim = (
42
+ concatenate_dim.value
43
+ if isinstance(concatenate_dim, ConcatenateDim)
44
+ else concatenate_dim
45
+ )
28
46
 
29
47
  def forward(
30
48
  self, input_dict: dict[str, Any], target_dict: dict[str, Any]
@@ -36,14 +54,36 @@ class Concatenate(Transform):
36
54
  target_dict: the target
37
55
 
38
56
  Returns:
39
- normalized (input_dicts, target_dicts) tuple
57
+ concatenated (input_dicts, target_dicts) tuple. If one of the
58
+ specified inputs is a RasterImage, a RasterImage will be returned.
59
+ Otherwise it will be a torch.Tensor.
40
60
  """
41
61
  images = []
62
+ return_raster_image: bool = False
63
+ timestamps: list[tuple[datetime, datetime]] | None = None
42
64
  for selector, wanted_bands in self.selections.items():
43
65
  image = read_selector(input_dict, target_dict, selector)
44
- if wanted_bands:
45
- image = image[wanted_bands, :, :]
46
- images.append(image)
47
- result = torch.concatenate(images, dim=0)
66
+ if isinstance(image, torch.Tensor):
67
+ if wanted_bands:
68
+ image = image[wanted_bands, :, :]
69
+ images.append(image)
70
+ elif isinstance(image, RasterImage):
71
+ return_raster_image = True
72
+ if wanted_bands:
73
+ images.append(image.image[wanted_bands, :, :])
74
+ else:
75
+ images.append(image.image)
76
+ if timestamps is None:
77
+ if image.timestamps is not None:
78
+ # assume all concatenated modalities have the same
79
+ # number of timestamps
80
+ timestamps = image.timestamps
81
+ if return_raster_image:
82
+ result = RasterImage(
83
+ torch.concatenate(images, dim=self.concatenate_dim),
84
+ timestamps=timestamps,
85
+ )
86
+ else:
87
+ result = torch.concatenate(images, dim=self.concatenate_dim)
48
88
  write_selector(input_dict, target_dict, self.output_selector, result)
49
89
  return input_dict, target_dict
@@ -5,6 +5,8 @@ from typing import Any
5
5
  import torch
6
6
  import torchvision
7
7
 
8
+ from rslearn.train.model_context import RasterImage
9
+
8
10
  from .transform import Transform, read_selector
9
11
 
10
12
 
@@ -69,7 +71,9 @@ class Crop(Transform):
69
71
  "remove_from_top": remove_from_top,
70
72
  }
71
73
 
72
- def apply_image(self, image: torch.Tensor, state: dict[str, Any]) -> torch.Tensor:
74
+ def apply_image(
75
+ self, image: RasterImage | torch.Tensor, state: dict[str, Any]
76
+ ) -> RasterImage | torch.Tensor:
73
77
  """Apply the sampled state on the specified image.
74
78
 
75
79
  Args:
@@ -80,13 +84,23 @@ class Crop(Transform):
80
84
  crop_size = state["crop_size"] * image.shape[-1] // image_shape[1]
81
85
  remove_from_left = state["remove_from_left"] * image.shape[-1] // image_shape[1]
82
86
  remove_from_top = state["remove_from_top"] * image.shape[-2] // image_shape[0]
83
- return torchvision.transforms.functional.crop(
84
- image,
85
- top=remove_from_top,
86
- left=remove_from_left,
87
- height=crop_size,
88
- width=crop_size,
89
- )
87
+ if isinstance(image, RasterImage):
88
+ image.image = torchvision.transforms.functional.crop(
89
+ image.image,
90
+ top=remove_from_top,
91
+ left=remove_from_left,
92
+ height=crop_size,
93
+ width=crop_size,
94
+ )
95
+ else:
96
+ image = torchvision.transforms.functional.crop(
97
+ image,
98
+ top=remove_from_top,
99
+ left=remove_from_left,
100
+ height=crop_size,
101
+ width=crop_size,
102
+ )
103
+ return image
90
104
 
91
105
  def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
92
106
  """Apply the sampled state on the specified image.
@@ -4,6 +4,8 @@ from typing import Any
4
4
 
5
5
  import torch
6
6
 
7
+ from rslearn.train.model_context import RasterImage
8
+
7
9
  from .transform import Transform
8
10
 
9
11
 
@@ -48,17 +50,23 @@ class Flip(Transform):
48
50
  "vertical": vertical,
49
51
  }
50
52
 
51
- def apply_image(self, image: torch.Tensor, state: dict[str, bool]) -> torch.Tensor:
53
+ def apply_image(self, image: RasterImage, state: dict[str, bool]) -> RasterImage:
52
54
  """Apply the sampled state on the specified image.
53
55
 
54
56
  Args:
55
57
  image: the image to transform.
56
58
  state: the sampled state.
57
59
  """
58
- if state["horizontal"]:
59
- image = torch.flip(image, dims=[-1])
60
- if state["vertical"]:
61
- image = torch.flip(image, dims=[-2])
60
+ if isinstance(image, RasterImage):
61
+ if state["horizontal"]:
62
+ image.image = torch.flip(image.image, dims=[-1])
63
+ if state["vertical"]:
64
+ image.image = torch.flip(image.image, dims=[-2])
65
+ elif isinstance(image, torch.Tensor):
66
+ if state["horizontal"]:
67
+ image = torch.flip(image, dims=[-1])
68
+ if state["vertical"]:
69
+ image = torch.flip(image, dims=[-2])
62
70
  return image
63
71
 
64
72
  def apply_boxes(
@@ -2,6 +2,7 @@
2
2
 
3
3
  import torch
4
4
 
5
+ from rslearn.train.model_context import RasterImage
5
6
  from rslearn.train.transforms.transform import Transform, read_selector
6
7
 
7
8
 
@@ -31,7 +32,9 @@ class Mask(Transform):
31
32
  self.mask_selector = mask_selector
32
33
  self.mask_value = mask_value
33
34
 
34
- def apply_image(self, image: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
35
+ def apply_image(
36
+ self, image: torch.Tensor | RasterImage, mask: torch.Tensor | RasterImage
37
+ ) -> torch.Tensor | RasterImage:
35
38
  """Apply the mask on the image.
36
39
 
37
40
  Args:
@@ -42,6 +45,9 @@ class Mask(Transform):
42
45
  masked image
43
46
  """
44
47
  # Tile the mask to have same number of bands as the image.
48
+ if isinstance(mask, RasterImage):
49
+ mask = mask.image
50
+
45
51
  if image.shape[0] != mask.shape[0]:
46
52
  if mask.shape[0] != 1:
47
53
  raise ValueError(
@@ -49,7 +55,10 @@ class Mask(Transform):
49
55
  )
50
56
  mask = mask.repeat(image.shape[0], 1, 1)
51
57
 
52
- image[mask == 0] = self.mask_value
58
+ if isinstance(image, torch.Tensor):
59
+ image[mask == 0] = self.mask_value
60
+ else:
61
+ image.image[mask == 0] = self.mask_value
53
62
  return image
54
63
 
55
64
  def forward(self, input_dict: dict, target_dict: dict) -> tuple[dict, dict]:
@@ -4,6 +4,8 @@ from typing import Any
4
4
 
5
5
  import torch
6
6
 
7
+ from rslearn.train.model_context import RasterImage
8
+
7
9
  from .transform import Transform
8
10
 
9
11
 
@@ -55,7 +57,9 @@ class Normalize(Transform):
55
57
  self.bands = torch.tensor(bands) if bands is not None else None
56
58
  self.num_bands = num_bands
57
59
 
58
- def apply_image(self, image: torch.Tensor) -> torch.Tensor:
60
+ def apply_image(
61
+ self, image: torch.Tensor | RasterImage
62
+ ) -> torch.Tensor | RasterImage:
59
63
  """Normalize the specified image.
60
64
 
61
65
  Args:
@@ -63,7 +67,7 @@ class Normalize(Transform):
63
67
  """
64
68
 
65
69
  def _repeat_mean_and_std(
66
- image_channels: int, num_bands: int | None
70
+ image_channels: int, num_bands: int | None, is_raster_image: bool
67
71
  ) -> tuple[torch.Tensor, torch.Tensor]:
68
72
  """Get mean and std tensor that are suitable for applying on the image."""
69
73
  # We only need to repeat the tensor if both of these are true:
@@ -74,9 +78,16 @@ class Normalize(Transform):
74
78
  if num_bands is None:
75
79
  return self.mean, self.std
76
80
  num_images = image_channels // num_bands
77
- return self.mean.repeat(num_images)[:, None, None], self.std.repeat(
78
- num_images
79
- )[:, None, None]
81
+ if is_raster_image:
82
+ # add an extra T dimension, CTHW
83
+ return self.mean.repeat(num_images)[
84
+ :, None, None, None
85
+ ], self.std.repeat(num_images)[:, None, None, None]
86
+ else:
87
+ # plain tensor branch: no extra T dimension is added here, only two trailing axes (presumably CHW — confirm against callers)
88
+ return self.mean.repeat(num_images)[:, None, None], self.std.repeat(
89
+ num_images
90
+ )[:, None, None]
80
91
 
81
92
  if self.bands is not None:
82
93
  # User has provided band indices to normalize.
@@ -96,20 +107,40 @@ class Normalize(Transform):
96
107
  # We use len(self.bands) here because that is how many bands per timestep
97
108
  # we are actually processing with the mean/std.
98
109
  mean, std = _repeat_mean_and_std(
99
- image_channels=len(band_indices), num_bands=len(self.bands)
110
+ image_channels=len(band_indices),
111
+ num_bands=len(self.bands),
112
+ is_raster_image=isinstance(image, RasterImage),
100
113
  )
101
- image[band_indices] = (image[band_indices] - mean) / std
102
- if self.valid_min is not None:
103
- image[band_indices] = torch.clamp(
104
- image[band_indices], min=self.valid_min, max=self.valid_max
105
- )
114
+ if isinstance(image, torch.Tensor):
115
+ image[band_indices] = (image[band_indices] - mean) / std
116
+ if self.valid_min is not None:
117
+ image[band_indices] = torch.clamp(
118
+ image[band_indices], min=self.valid_min, max=self.valid_max
119
+ )
120
+ else:
121
+ image.image[band_indices] = (image.image[band_indices] - mean) / std
122
+ if self.valid_min is not None:
123
+ image.image[band_indices] = torch.clamp(
124
+ image.image[band_indices],
125
+ min=self.valid_min,
126
+ max=self.valid_max,
127
+ )
106
128
  else:
107
129
  mean, std = _repeat_mean_and_std(
108
- image_channels=image.shape[0], num_bands=self.num_bands
130
+ image_channels=image.shape[0],
131
+ num_bands=self.num_bands,
132
+ is_raster_image=isinstance(image, RasterImage),
109
133
  )
110
- image = (image - mean) / std
111
- if self.valid_min is not None:
112
- image = torch.clamp(image, min=self.valid_min, max=self.valid_max)
134
+ if isinstance(image, torch.Tensor):
135
+ image = (image - mean) / std
136
+ if self.valid_min is not None:
137
+ image = torch.clamp(image, min=self.valid_min, max=self.valid_max)
138
+ else:
139
+ image.image = (image.image - mean) / std
140
+ if self.valid_min is not None:
141
+ image.image = torch.clamp(
142
+ image.image, min=self.valid_min, max=self.valid_max
143
+ )
113
144
  return image
114
145
 
115
146
  def forward(
@@ -5,6 +5,8 @@ from typing import Any
5
5
  import torch
6
6
  import torchvision
7
7
 
8
+ from rslearn.train.model_context import RasterImage
9
+
8
10
  from .transform import Transform
9
11
 
10
12
 
@@ -48,7 +50,9 @@ class Pad(Transform):
48
50
  """
49
51
  return {"size": torch.randint(low=self.size[0], high=self.size[1], size=())}
50
52
 
51
- def apply_image(self, image: torch.Tensor, state: dict[str, bool]) -> torch.Tensor:
53
+ def apply_image(
54
+ self, image: RasterImage | torch.Tensor, state: dict[str, bool]
55
+ ) -> RasterImage | torch.Tensor:
52
56
  """Apply the sampled state on the specified image.
53
57
 
54
58
  Args:
@@ -101,8 +105,16 @@ class Pad(Transform):
101
105
  horizontal_pad = (horizontal_half, horizontal_extra - horizontal_half)
102
106
  vertical_pad = (vertical_half, vertical_extra - vertical_half)
103
107
 
104
- image = apply_padding(image, True, horizontal_pad[0], horizontal_pad[1])
105
- image = apply_padding(image, False, vertical_pad[0], vertical_pad[1])
108
+ if isinstance(image, RasterImage):
109
+ image.image = apply_padding(
110
+ image.image, True, horizontal_pad[0], horizontal_pad[1]
111
+ )
112
+ image.image = apply_padding(
113
+ image.image, False, vertical_pad[0], vertical_pad[1]
114
+ )
115
+ else:
116
+ image = apply_padding(image, True, horizontal_pad[0], horizontal_pad[1])
117
+ image = apply_padding(image, False, vertical_pad[0], vertical_pad[1])
106
118
  return image
107
119
 
108
120
  def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
@@ -6,6 +6,8 @@ import torch
6
6
  import torchvision
7
7
  from torchvision.transforms import InterpolationMode
8
8
 
9
+ from rslearn.train.model_context import RasterImage
10
+
9
11
  from .transform import Transform
10
12
 
11
13
  INTERPOLATION_MODES = {
@@ -38,7 +40,9 @@ class Resize(Transform):
38
40
  self.selectors = selectors
39
41
  self.interpolation = INTERPOLATION_MODES[interpolation]
40
42
 
41
- def apply_resize(self, image: torch.Tensor) -> torch.Tensor:
43
+ def apply_resize(
44
+ self, image: torch.Tensor | RasterImage
45
+ ) -> torch.Tensor | RasterImage:
42
46
  """Apply resizing on the specified image.
43
47
 
44
48
  If the image is 2D, it is unsqueezed to 3D and then squeezed
@@ -47,16 +51,21 @@ class Resize(Transform):
47
51
  Args:
48
52
  image: the image to transform.
49
53
  """
50
- if image.dim() == 2:
51
- image = image.unsqueeze(0) # (H, W) -> (1, H, W)
52
- result = torchvision.transforms.functional.resize(
54
+ if isinstance(image, torch.Tensor):
55
+ if image.dim() == 2:
56
+ image = image.unsqueeze(0) # (H, W) -> (1, H, W)
57
+ result = torchvision.transforms.functional.resize(
58
+ image, self.target_size, self.interpolation
59
+ )
60
+ return result.squeeze(0) # (1, H, W) -> (H, W)
61
+ return torchvision.transforms.functional.resize(
53
62
  image, self.target_size, self.interpolation
54
63
  )
55
- return result.squeeze(0) # (1, H, W) -> (H, W)
56
-
57
- return torchvision.transforms.functional.resize(
58
- image, self.target_size, self.interpolation
59
- )
64
+ else:
65
+ image.image = torchvision.transforms.functional.resize(
66
+ image.image, self.target_size, self.interpolation
67
+ )
68
+ return image
60
69
 
61
70
  def forward(
62
71
  self, input_dict: dict[str, Any], target_dict: dict[str, Any]
@@ -2,6 +2,8 @@
2
2
 
3
3
  from typing import Any
4
4
 
5
+ from rslearn.train.model_context import RasterImage
6
+
5
7
  from .transform import Transform, read_selector, write_selector
6
8
 
7
9
 
@@ -49,6 +51,10 @@ class SelectBands(Transform):
49
51
  if self.num_bands_per_timestep is not None
50
52
  else image.shape[0]
51
53
  )
54
+ if isinstance(image, RasterImage):
55
+ assert num_bands_per_timestep == image.shape[0], (
56
+ "Expect a seperate dimension for timesteps in RasterImages."
57
+ )
52
58
 
53
59
  if image.shape[0] % num_bands_per_timestep != 0:
54
60
  raise ValueError(
@@ -62,6 +68,9 @@ class SelectBands(Transform):
62
68
  [(start_channel_idx + band_idx) for band_idx in self.band_indices]
63
69
  )
64
70
 
65
- result = image[wanted_bands]
66
- write_selector(input_dict, target_dict, self.output_selector, result)
71
+ if isinstance(image, RasterImage):
72
+ image.image = image.image[wanted_bands]
73
+ else:
74
+ image = image[wanted_bands]
75
+ write_selector(input_dict, target_dict, self.output_selector, image)
67
76
  return input_dict, target_dict