rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0

rslearn/train/transforms/concatenate.py

@@ -1,8 +1,21 @@
-"""Normalization transforms."""
+"""Concatenate bands across multiple image inputs."""
+
+from datetime import datetime
+from enum import Enum
+from typing import Any
 
 import torch
 
-from .transform import Transform
+from rslearn.train.model_context import RasterImage
+
+from .transform import Transform, read_selector, write_selector
+
+
+class ConcatenateDim(Enum):
+    """Enum for concatenation dimensions."""
+
+    CHANNEL = 0
+    TIME = 1
 
 
 class Concatenate(Transform):
@@ -12,6 +25,7 @@ class Concatenate(Transform):
         self,
         selections: dict[str, list[int]],
         output_selector: str,
+        concatenate_dim: ConcatenateDim | int = ConcatenateDim.TIME,
     ):
         """Initialize a new Concatenate.
 
@@ -19,12 +33,20 @@ class Concatenate(Transform):
             selections: map from selector to list of band indices in that input to
                 retain, or empty list to use all bands.
             output_selector: the output selector under which to save the concatenate image.
+            concatenate_dim: the dimension along which to concatenate the inputs.
         """
         super().__init__()
         self.selections = selections
         self.output_selector = output_selector
+        self.concatenate_dim = (
+            concatenate_dim.value
+            if isinstance(concatenate_dim, ConcatenateDim)
+            else concatenate_dim
+        )
 
-    def forward(self, input_dict, target_dict):
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Apply concatenation over the inputs and targets.
 
         Args:
@@ -32,14 +54,36 @@ class Concatenate(Transform):
             target_dict: the target
 
         Returns:
-            normalized (input_dicts, target_dicts) tuple
+            concatenated (input_dicts, target_dicts) tuple. If one of the
+                specified inputs is a RasterImage, a RasterImage will be returned.
+                Otherwise it will be a torch.Tensor.
         """
         images = []
+        return_raster_image: bool = False
+        timestamps: list[tuple[datetime, datetime]] | None = None
         for selector, wanted_bands in self.selections.items():
-            image = self.read_selector(input_dict, target_dict, selector)
-            if wanted_bands:
-                image = image[wanted_bands, :, :]
-            images.append(image)
-        result = torch.concatenate(images, dim=0)
-        self.write_selector(input_dict, target_dict, self.output_selector, result)
+            image = read_selector(input_dict, target_dict, selector)
+            if isinstance(image, torch.Tensor):
+                if wanted_bands:
+                    image = image[wanted_bands, :, :]
+                images.append(image)
+            elif isinstance(image, RasterImage):
+                return_raster_image = True
+                if wanted_bands:
+                    images.append(image.image[wanted_bands, :, :])
+                else:
+                    images.append(image.image)
+                if timestamps is None:
+                    if image.timestamps is not None:
+                        # assume all concatenated modalities have the same
+                        # number of timestamps
+                        timestamps = image.timestamps
+        if return_raster_image:
+            result = RasterImage(
+                torch.concatenate(images, dim=self.concatenate_dim),
+                timestamps=timestamps,
+            )
+        else:
+            result = torch.concatenate(images, dim=self.concatenate_dim)
+        write_selector(input_dict, target_dict, self.output_selector, result)
         return input_dict, target_dict
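
To make the new API concrete, here is a minimal usage sketch (not part of the diff). It assumes the selector convention from rslearn.train.transforms.transform: the keys "s1" and "s2" below are hypothetical input_dict entries, and Transform's __call__ is assumed to dispatch to forward.

    import torch

    from rslearn.train.transforms.concatenate import Concatenate, ConcatenateDim

    # Two single-band inputs stacked along the channel dimension. Note the
    # default is ConcatenateDim.TIME (dim=1), aimed at RasterImage CTHW data.
    input_dict = {"s1": torch.ones(1, 32, 32), "s2": torch.zeros(1, 32, 32)}
    transform = Concatenate(
        selections={"s1": [], "s2": []},  # empty list keeps all bands
        output_selector="image",
        concatenate_dim=ConcatenateDim.CHANNEL,
    )
    input_dict, target_dict = transform(input_dict, {})
    assert input_dict["image"].shape == (2, 32, 32)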

rslearn/train/transforms/crop.py

@@ -5,7 +5,9 @@ from typing import Any
 import torch
 import torchvision
 
-from .transform import Transform
+from rslearn.train.model_context import RasterImage
+
+from .transform import Transform, read_selector
 
 
 class Crop(Transform):
@@ -69,7 +71,9 @@ class Crop(Transform):
             "remove_from_top": remove_from_top,
         }
 
-    def apply_image(self, image: torch.Tensor, state: dict[str, bool]) -> torch.Tensor:
+    def apply_image(
+        self, image: RasterImage | torch.Tensor, state: dict[str, Any]
+    ) -> RasterImage | torch.Tensor:
         """Apply the sampled state on the specified image.
 
         Args:
@@ -80,13 +84,23 @@ class Crop(Transform):
         crop_size = state["crop_size"] * image.shape[-1] // image_shape[1]
         remove_from_left = state["remove_from_left"] * image.shape[-1] // image_shape[1]
         remove_from_top = state["remove_from_top"] * image.shape[-2] // image_shape[0]
-        return torchvision.transforms.functional.crop(
-            image,
-            top=remove_from_top,
-            left=remove_from_left,
-            height=crop_size,
-            width=crop_size,
-        )
+        if isinstance(image, RasterImage):
+            image.image = torchvision.transforms.functional.crop(
+                image.image,
+                top=remove_from_top,
+                left=remove_from_left,
+                height=crop_size,
+                width=crop_size,
+            )
+        else:
+            image = torchvision.transforms.functional.crop(
+                image,
+                top=remove_from_top,
+                left=remove_from_left,
+                height=crop_size,
+                width=crop_size,
+            )
+        return image
 
     def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
         """Apply the sampled state on the specified image.
@@ -97,7 +111,9 @@ class Crop(Transform):
         """
         raise NotImplementedError
 
-    def forward(self, input_dict, target_dict):
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Apply transform over the inputs and targets.
 
         Args:
@@ -109,13 +125,15 @@ class Crop(Transform):
         """
         smallest_image_shape = None
         for selector in self.image_selectors:
-            image = self.read_selector(input_dict, target_dict, selector)
+            image = read_selector(input_dict, target_dict, selector)
             if (
                 smallest_image_shape is None
                 or image.shape[-1] < smallest_image_shape[1]
             ):
                 smallest_image_shape = image.shape[-2:]
 
+        if smallest_image_shape is None:
+            raise ValueError("No image found to crop")
         state = self.sample_state(smallest_image_shape)
 
         self.apply_fn(
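
A hedged usage sketch follows. Crop's constructor is unchanged and therefore absent from this hunk, so the crop_size and image_selectors arguments below are assumptions about its signature, not something this diff shows.

    import torch

    from rslearn.train.transforms.crop import Crop

    input_dict = {"image": torch.rand(3, 64, 64)}
    # crop_size/image_selectors are assumed constructor arguments.
    transform = Crop(crop_size=32, image_selectors=["image"])
    input_dict, target_dict = transform(input_dict, {})
    # A randomly positioned 32x32 window has been cropped from the 64x64 input.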

rslearn/train/transforms/flip.py

@@ -1,7 +1,11 @@
 """Flip transform."""
 
+from typing import Any
+
 import torch
 
+from rslearn.train.model_context import RasterImage
+
 from .transform import Transform
 
 
@@ -46,17 +50,23 @@ class Flip(Transform):
             "vertical": vertical,
         }
 
-    def apply_image(self, image: torch.Tensor, state: dict[str, bool]) -> torch.Tensor:
+    def apply_image(self, image: RasterImage, state: dict[str, bool]) -> RasterImage:
         """Apply the sampled state on the specified image.
 
         Args:
             image: the image to transform.
             state: the sampled state.
         """
-        if state["horizontal"]:
-            image = torch.flip(image, dims=[-1])
-        if state["vertical"]:
-            image = torch.flip(image, dims=[-2])
+        if isinstance(image, RasterImage):
+            if state["horizontal"]:
+                image.image = torch.flip(image.image, dims=[-1])
+            if state["vertical"]:
+                image.image = torch.flip(image.image, dims=[-2])
+        elif isinstance(image, torch.Tensor):
+            if state["horizontal"]:
+                image = torch.flip(image, dims=[-1])
+            if state["vertical"]:
+                image = torch.flip(image, dims=[-2])
         return image
 
     def apply_boxes(
@@ -90,7 +100,9 @@ class Flip(Transform):
         )
         return boxes
 
-    def forward(self, input_dict, target_dict):
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Apply transform over the inputs and targets.
 
         Args:
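
For orientation, a small sketch. Flip's constructor is not part of this hunk, so constructing it with no arguments below assumes its defaults (randomly flip the "image" selector horizontally and/or vertically).

    import torch

    from rslearn.train.transforms.flip import Flip

    input_dict = {"image": torch.rand(3, 16, 16)}
    transform = Flip()  # assumed defaults: random horizontal/vertical flips
    input_dict, target_dict = transform(input_dict, {})
    assert input_dict["image"].shape == (3, 16, 16)  # shape is unchanged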

rslearn/train/transforms/mask.py

@@ -0,0 +1,78 @@
+"""Mask transform."""
+
+import torch
+
+from rslearn.train.model_context import RasterImage
+from rslearn.train.transforms.transform import Transform, read_selector
+
+
+class Mask(Transform):
+    """Apply a mask to one or more images.
+
+    This uses one (mask) image input to mask another (target) image input. The value of
+    the target image is set to the mask value everywhere where the mask image is 0.
+    """
+
+    def __init__(
+        self,
+        selectors: list[str] = ["image"],
+        mask_selector: str = "mask",
+        mask_value: int = 0,
+    ):
+        """Initialize a new Mask.
+
+        Args:
+            selectors: images to mask.
+            mask_selector: the selector for the mask image to apply.
+            mask_value: set each image in selectors to this value where the image
+                corresponding to the mask_selector is 0.
+        """
+        super().__init__()
+        self.selectors = selectors
+        self.mask_selector = mask_selector
+        self.mask_value = mask_value
+
+    def apply_image(
+        self, image: torch.Tensor | RasterImage, mask: torch.Tensor | RasterImage
+    ) -> torch.Tensor | RasterImage:
+        """Apply the mask on the image.
+
+        Args:
+            image: the image
+            mask: the mask
+
+        Returns:
+            masked image
+        """
+        # Tile the mask to have same number of bands as the image.
+        if isinstance(mask, RasterImage):
+            mask = mask.image
+
+        if image.shape[0] != mask.shape[0]:
+            if mask.shape[0] != 1:
+                raise ValueError(
+                    "expected mask to either have same bands as image, or one band"
+                )
+            mask = mask.repeat(image.shape[0], 1, 1)
+
+        if isinstance(image, torch.Tensor):
+            image[mask == 0] = self.mask_value
+        else:
+            image.image[mask == 0] = self.mask_value
+        return image
+
+    def forward(self, input_dict: dict, target_dict: dict) -> tuple[dict, dict]:
+        """Apply mask.
+
+        Args:
+            input_dict: the input
+            target_dict: the target
+
+        Returns:
+            masked (input_dicts, target_dicts) tuple
+        """
+        mask = read_selector(input_dict, target_dict, self.mask_selector)
+        self.apply_fn(
+            self.apply_image, input_dict, target_dict, self.selectors, mask=mask
+        )
+        return input_dict, target_dict
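
Since mask.py is a new file, a short usage sketch may help. It assumes the "image" and "mask" selectors read straight from input_dict (the selector convention lives in transform.py, outside this diff).

    import torch

    from rslearn.train.transforms.mask import Mask

    input_dict = {
        "image": torch.ones(3, 8, 8),
        "mask": torch.zeros(1, 8, 8),  # one band, tiled across the image bands
    }
    input_dict["mask"][0, :4, :] = 1  # keep the top half
    transform = Mask(selectors=["image"], mask_selector="mask", mask_value=-1)
    input_dict, target_dict = transform(input_dict, {})
    # Every band is now -1 wherever the mask is 0 (the bottom half).
    assert (input_dict["image"][:, 4:, :] == -1).all()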

rslearn/train/transforms/normalize.py

@@ -1,7 +1,11 @@
 """Normalization transforms."""
 
+from typing import Any
+
 import torch
 
+from rslearn.train.model_context import RasterImage
+
 from .transform import Transform
 
 
@@ -12,22 +16,31 @@ class Normalize(Transform):
         self,
         mean: float | list[float],
         std: float | list[float],
-        valid_range: tuple[float, float]
-        | tuple[list[float], list[float]]
-        | None = None,
+        valid_range: (
+            tuple[float, float] | tuple[list[float], list[float]] | None
+        ) = None,
         selectors: list[str] = ["image"],
         bands: list[int] | None = None,
-    ):
+        num_bands: int | None = None,
+    ) -> None:
         """Initialize a new Normalize.
 
         Result will be (input - mean) / std.
 
         Args:
             mean: a single value or one mean per channel
-            std: a single value or one std per channel
+            std: a single value or one std per channel (must match the shape of mean)
             valid_range: optionally clip to a minimum and maximum value
             selectors: image items to transform
-            bands: optionally restrict the normalization to these bands
+            bands: optionally restrict the normalization to these band indices. If set,
+                mean and std must either be one value, or have length equal to the
+                number of band indices passed here.
+            num_bands: the number of bands per image, to distinguish different images
+                in a time series. If set, then the bands list is repeated for each
+                image, e.g. if bands=[2] then we apply normalization on images[2],
+                images[2+num_bands], images[2+num_bands*2], etc. Or if the bands list
+                is not set, then we apply the mean and std on each image in the time
+                series.
         """
         super().__init__()
         self.mean = torch.tensor(mean)
@@ -41,27 +54,98 @@ class Normalize(Transform):
         self.valid_max = None
 
         self.selectors = selectors
-        self.bands = bands
+        self.bands = torch.tensor(bands) if bands is not None else None
+        self.num_bands = num_bands
 
-    def apply_image(self, image: torch.Tensor) -> torch.Tensor:
+    def apply_image(
+        self, image: torch.Tensor | RasterImage
+    ) -> torch.Tensor | RasterImage:
         """Normalize the specified image.
 
         Args:
             image: the image to transform.
         """
-        if self.bands:
-            image[self.bands] = (image[self.bands] - self.mean) / self.std
-            if self.valid_min is not None:
-                image[self.bands] = torch.clamp(
-                    image[self.bands], min=self.valid_min, max=self.valid_max
+
+        def _repeat_mean_and_std(
+            image_channels: int, num_bands: int | None, is_raster_image: bool
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            """Get mean and std tensors that are suitable for applying on the image."""
+            # We only need to repeat the tensor if both of these are true:
+            # - The mean/std are not just one scalar.
+            # - self.num_bands is set, otherwise we treat the input as a single image.
+            if len(self.mean.shape) == 0:
+                return self.mean, self.std
+            if num_bands is None:
+                return self.mean, self.std
+            num_images = image_channels // num_bands
+            if is_raster_image:
+                # add an extra T dimension, CTHW
+                return self.mean.repeat(num_images)[
+                    :, None, None, None
+                ], self.std.repeat(num_images)[:, None, None, None]
+            else:
+                # add an extra T dimension, CTHW
+                return self.mean.repeat(num_images)[:, None, None], self.std.repeat(
+                    num_images
+                )[:, None, None]
+
+        if self.bands is not None:
+            # User has provided band indices to normalize.
+            # If num_bands is set, then we repeat these for each image in the input
+            # image time series.
+            band_indices = self.bands
+            if self.num_bands:
+                num_images = image.shape[0] // self.num_bands
+                band_indices = torch.cat(
+                    [
+                        band_indices + image_idx * self.num_bands
+                        for image_idx in range(num_images)
+                    ],
+                    dim=0,
                 )
+
+            # We use len(self.bands) here because that is how many bands per timestep
+            # we are actually processing with the mean/std.
+            mean, std = _repeat_mean_and_std(
+                image_channels=len(band_indices),
+                num_bands=len(self.bands),
+                is_raster_image=isinstance(image, RasterImage),
+            )
+            if isinstance(image, torch.Tensor):
+                image[band_indices] = (image[band_indices] - mean) / std
+                if self.valid_min is not None:
+                    image[band_indices] = torch.clamp(
+                        image[band_indices], min=self.valid_min, max=self.valid_max
+                    )
+            else:
+                image.image[band_indices] = (image.image[band_indices] - mean) / std
+                if self.valid_min is not None:
+                    image.image[band_indices] = torch.clamp(
+                        image.image[band_indices],
+                        min=self.valid_min,
+                        max=self.valid_max,
+                    )
         else:
-            image = (image - self.mean) / self.std
-            if self.valid_min is not None:
-                image = torch.clamp(image, min=self.valid_min, max=self.valid_max)
+            mean, std = _repeat_mean_and_std(
+                image_channels=image.shape[0],
+                num_bands=self.num_bands,
+                is_raster_image=isinstance(image, RasterImage),
+            )
+            if isinstance(image, torch.Tensor):
+                image = (image - mean) / std
+                if self.valid_min is not None:
+                    image = torch.clamp(image, min=self.valid_min, max=self.valid_max)
+            else:
+                image.image = (image.image - mean) / std
+                if self.valid_min is not None:
+                    image.image = torch.clamp(
+                        image.image, min=self.valid_min, max=self.valid_max
+                    )
         return image
 
-    def forward(self, input_dict, target_dict):
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Apply normalization over the inputs and targets.
 
         Args:
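
A sketch of the new num_bands behavior on a time series flattened to (C*T, H, W); the per-band statistics below are illustrative values, not anything this diff prescribes.

    import torch

    from rslearn.train.transforms.normalize import Normalize

    # Two timesteps of a 3-band sensor, flattened to (3*2, H, W).
    input_dict = {"image": torch.rand(6, 16, 16)}
    transform = Normalize(
        mean=[0.2, 0.3, 0.4],  # illustrative per-band statistics
        std=[0.5, 0.5, 0.5],
        num_bands=3,  # repeat the stats for each image in the series
    )
    input_dict, target_dict = transform(input_dict, {})
    assert input_dict["image"].shape == (6, 16, 16)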

rslearn/train/transforms/pad.py

@@ -5,6 +5,8 @@ from typing import Any
 import torch
 import torchvision
 
+from rslearn.train.model_context import RasterImage
+
 from .transform import Transform
 
 
@@ -25,8 +27,8 @@ class Pad(Transform):
         Args:
             size: the size to pad to, or a min/max range of pad sizes. If the image is
                 larger than this size, then it is cropped instead.
-            mode: "center" (default) to apply padding equally on all sides, or
-                "topleft" to only apply it on the bottom and right.
+            mode: "topleft" (default) to only apply padding on the bottom and right
+                sides, or "center" to apply padding equally on all sides.
             image_selectors: image items to transform.
             box_selectors: boxes items to transform.
         """
@@ -48,7 +50,9 @@ class Pad(Transform):
         """
         return {"size": torch.randint(low=self.size[0], high=self.size[1], size=())}
 
-    def apply_image(self, image: torch.Tensor, state: dict[str, bool]) -> torch.Tensor:
+    def apply_image(
+        self, image: RasterImage | torch.Tensor, state: dict[str, bool]
+    ) -> RasterImage | torch.Tensor:
         """Apply the sampled state on the specified image.
 
         Args:
@@ -64,11 +68,11 @@ class Pad(Transform):
         ) -> torch.Tensor:
             # Before/after must either be both non-negative or both negative.
             # >=0 indicates padding while <0 indicates cropping.
-            assert (before < 0 and after < 0) or (before >= 0 and after >= 0)
+            assert (before < 0 and after <= 0) or (before >= 0 and after >= 0)
             if before > 0:
                 # Padding.
                 if horizontal:
-                    padding_tuple = (before, after)
+                    padding_tuple: tuple = (before, after)
                 else:
                     padding_tuple = (before, after, 0, 0)
                 return torch.nn.functional.pad(im, padding_tuple)
@@ -101,8 +105,16 @@ class Pad(Transform):
         horizontal_pad = (horizontal_half, horizontal_extra - horizontal_half)
         vertical_pad = (vertical_half, vertical_extra - vertical_half)
 
-        image = apply_padding(image, True, horizontal_pad[0], horizontal_pad[1])
-        image = apply_padding(image, False, vertical_pad[0], vertical_pad[1])
+        if isinstance(image, RasterImage):
+            image.image = apply_padding(
+                image.image, True, horizontal_pad[0], horizontal_pad[1]
+            )
+            image.image = apply_padding(
+                image.image, False, vertical_pad[0], vertical_pad[1]
+            )
+        else:
+            image = apply_padding(image, True, horizontal_pad[0], horizontal_pad[1])
+            image = apply_padding(image, False, vertical_pad[0], vertical_pad[1])
         return image
 
     def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
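
A sketch of the corrected "topleft" default; it assumes Pad's unchanged constructor accepts a fixed integer size and that image_selectors defaults to ["image"], neither of which is shown in this hunk.

    import torch

    from rslearn.train.transforms.pad import Pad

    input_dict = {"image": torch.ones(1, 20, 20)}
    transform = Pad(size=32, mode="topleft")  # pad only the bottom and right
    input_dict, target_dict = transform(input_dict, {})
    assert input_dict["image"].shape == (1, 32, 32)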

rslearn/train/transforms/resize.py

@@ -0,0 +1,83 @@
+"""Resize transform."""
+
+from typing import Any
+
+import torch
+import torchvision
+from torchvision.transforms import InterpolationMode
+
+from rslearn.train.model_context import RasterImage
+
+from .transform import Transform
+
+INTERPOLATION_MODES = {
+    "nearest": InterpolationMode.NEAREST,
+    "nearest_exact": InterpolationMode.NEAREST_EXACT,
+    "bilinear": InterpolationMode.BILINEAR,
+    "bicubic": InterpolationMode.BICUBIC,
+}
+
+
+class Resize(Transform):
+    """Resizes inputs to a target size."""
+
+    def __init__(
+        self,
+        target_size: tuple[int, int],
+        selectors: list[str] = [],
+        interpolation: str = "nearest",
+    ):
+        """Initialize a resize transform.
+
+        Args:
+            target_size: the (height, width) to resize to.
+            selectors: items to transform.
+            interpolation: the interpolation mode to use for resizing.
+                Must be one of "nearest", "nearest_exact", "bilinear", or "bicubic".
+        """
+        super().__init__()
+        self.target_size = target_size
+        self.selectors = selectors
+        self.interpolation = INTERPOLATION_MODES[interpolation]
+
+    def apply_resize(
+        self, image: torch.Tensor | RasterImage
+    ) -> torch.Tensor | RasterImage:
+        """Apply resizing on the specified image.
+
+        If the image is 2D, it is unsqueezed to 3D and then squeezed
+        back after resizing.
+
+        Args:
+            image: the image to transform.
+        """
+        if isinstance(image, torch.Tensor):
+            if image.dim() == 2:
+                image = image.unsqueeze(0)  # (H, W) -> (1, H, W)
+                result = torchvision.transforms.functional.resize(
+                    image, self.target_size, self.interpolation
+                )
+                return result.squeeze(0)  # (1, H, W) -> (H, W)
+            return torchvision.transforms.functional.resize(
+                image, self.target_size, self.interpolation
+            )
+        else:
+            image.image = torchvision.transforms.functional.resize(
+                image.image, self.target_size, self.interpolation
+            )
+            return image
+
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Apply transform over the inputs and targets.
+
+        Args:
+            input_dict: the input
+            target_dict: the target
+
+        Returns:
+            transformed (input_dicts, target_dicts) tuple
+        """
+        self.apply_fn(self.apply_resize, input_dict, target_dict, self.selectors)
+        return input_dict, target_dict
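
Since resize.py is a new file, a brief usage sketch; note that selectors defaults to an empty list, so the items to resize must be named explicitly. The "image" key is assumed to read from input_dict per the selector convention in transform.py.

    import torch

    from rslearn.train.transforms.resize import Resize

    input_dict = {"image": torch.rand(3, 64, 64)}
    transform = Resize(
        target_size=(32, 32),
        selectors=["image"],  # the default [] would transform nothing
        interpolation="bilinear",
    )
    input_dict, target_dict = transform(input_dict, {})
    assert input_dict["image"].shape == (3, 32, 32)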