rslearn 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. rslearn/config/dataset.py +30 -23
  2. rslearn/data_sources/__init__.py +2 -0
  3. rslearn/data_sources/aws_landsat.py +44 -161
  4. rslearn/data_sources/aws_open_data.py +2 -4
  5. rslearn/data_sources/aws_sentinel1.py +1 -3
  6. rslearn/data_sources/aws_sentinel2_element84.py +54 -165
  7. rslearn/data_sources/climate_data_store.py +1 -3
  8. rslearn/data_sources/copernicus.py +1 -2
  9. rslearn/data_sources/data_source.py +1 -1
  10. rslearn/data_sources/direct_materialize_data_source.py +336 -0
  11. rslearn/data_sources/earthdaily.py +52 -155
  12. rslearn/data_sources/earthdatahub.py +425 -0
  13. rslearn/data_sources/eurocrops.py +1 -2
  14. rslearn/data_sources/gcp_public_data.py +1 -2
  15. rslearn/data_sources/google_earth_engine.py +1 -2
  16. rslearn/data_sources/hf_srtm.py +595 -0
  17. rslearn/data_sources/local_files.py +3 -3
  18. rslearn/data_sources/openstreetmap.py +1 -1
  19. rslearn/data_sources/planet.py +1 -2
  20. rslearn/data_sources/planet_basemap.py +1 -2
  21. rslearn/data_sources/planetary_computer.py +183 -186
  22. rslearn/data_sources/soilgrids.py +3 -3
  23. rslearn/data_sources/stac.py +1 -2
  24. rslearn/data_sources/usda_cdl.py +1 -3
  25. rslearn/data_sources/usgs_landsat.py +7 -254
  26. rslearn/data_sources/utils.py +204 -64
  27. rslearn/data_sources/worldcereal.py +1 -1
  28. rslearn/data_sources/worldcover.py +1 -1
  29. rslearn/data_sources/worldpop.py +1 -1
  30. rslearn/data_sources/xyz_tiles.py +5 -9
  31. rslearn/dataset/materialize.py +5 -1
  32. rslearn/models/clay/clay.py +3 -3
  33. rslearn/models/concatenate_features.py +6 -1
  34. rslearn/models/detr/detr.py +4 -1
  35. rslearn/models/dinov3.py +0 -1
  36. rslearn/models/olmoearth_pretrain/model.py +3 -1
  37. rslearn/models/pooling_decoder.py +1 -1
  38. rslearn/models/prithvi.py +0 -1
  39. rslearn/models/simple_time_series.py +97 -35
  40. rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
  41. rslearn/train/data_module.py +32 -27
  42. rslearn/train/dataset.py +260 -117
  43. rslearn/train/dataset_index.py +156 -0
  44. rslearn/train/lightning_module.py +1 -1
  45. rslearn/train/model_context.py +19 -3
  46. rslearn/train/prediction_writer.py +69 -41
  47. rslearn/train/tasks/classification.py +1 -1
  48. rslearn/train/tasks/detection.py +5 -5
  49. rslearn/train/tasks/per_pixel_regression.py +13 -13
  50. rslearn/train/tasks/regression.py +1 -1
  51. rslearn/train/tasks/segmentation.py +26 -13
  52. rslearn/train/transforms/concatenate.py +17 -27
  53. rslearn/train/transforms/crop.py +8 -19
  54. rslearn/train/transforms/flip.py +4 -10
  55. rslearn/train/transforms/mask.py +9 -15
  56. rslearn/train/transforms/normalize.py +31 -82
  57. rslearn/train/transforms/pad.py +7 -13
  58. rslearn/train/transforms/resize.py +5 -22
  59. rslearn/train/transforms/select_bands.py +16 -36
  60. rslearn/train/transforms/sentinel1.py +4 -16
  61. rslearn/utils/__init__.py +2 -0
  62. rslearn/utils/geometry.py +21 -0
  63. rslearn/utils/m2m_api.py +251 -0
  64. rslearn/utils/retry_session.py +43 -0
  65. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
  66. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/RECORD +71 -66
  67. rslearn/data_sources/earthdata_srtm.py +0 -282
  68. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
  69. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
  70. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
  71. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
  72. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,5 @@
1
1
  """Mask transform."""
2
2
 
3
- import torch
4
-
5
3
  from rslearn.train.model_context import RasterImage
6
4
  from rslearn.train.transforms.transform import Transform, read_selector
7
5
 
@@ -32,9 +30,7 @@ class Mask(Transform):
32
30
  self.mask_selector = mask_selector
33
31
  self.mask_value = mask_value
34
32
 
35
- def apply_image(
36
- self, image: torch.Tensor | RasterImage, mask: torch.Tensor | RasterImage
37
- ) -> torch.Tensor | RasterImage:
33
+ def apply_image(self, image: RasterImage, mask: RasterImage) -> RasterImage:
38
34
  """Apply the mask on the image.
39
35
 
40
36
  Args:
@@ -44,21 +40,19 @@ class Mask(Transform):
44
40
  Returns:
45
41
  masked image
46
42
  """
47
- # Tile the mask to have same number of bands as the image.
48
- if isinstance(mask, RasterImage):
49
- mask = mask.image
43
+ # Extract the mask tensor (CTHW format)
44
+ mask_tensor = mask.image
50
45
 
51
- if image.shape[0] != mask.shape[0]:
52
- if mask.shape[0] != 1:
46
+ # Tile the mask to have same number of bands (C dimension) as the image.
47
+ if image.shape[0] != mask_tensor.shape[0]:
48
+ if mask_tensor.shape[0] != 1:
53
49
  raise ValueError(
54
50
  "expected mask to either have same bands as image, or one band"
55
51
  )
56
- mask = mask.repeat(image.shape[0], 1, 1)
52
+ # Repeat along C dimension, keep T, H, W the same
53
+ mask_tensor = mask_tensor.repeat(image.shape[0], 1, 1, 1)
57
54
 
58
- if isinstance(image, torch.Tensor):
59
- image[mask == 0] = self.mask_value
60
- else:
61
- image.image[mask == 0] = self.mask_value
55
+ image.image[mask_tensor == 0] = self.mask_value
62
56
  return image
63
57
 
64
58
  def forward(self, input_dict: dict, target_dict: dict) -> tuple[dict, dict]:
@@ -1,5 +1,6 @@
1
1
  """Normalization transforms."""
2
2
 
3
+ import warnings
3
4
  from typing import Any
4
5
 
5
6
  import torch
@@ -35,14 +36,17 @@ class Normalize(Transform):
35
36
  bands: optionally restrict the normalization to these band indices. If set,
36
37
  mean and std must either be one value, or have length equal to the
37
38
  number of band indices passed here.
38
- num_bands: the number of bands per image, to distinguish different images
39
- in a time series. If set, then the bands list is repeated for each
40
- image, e.g. if bands=[2] then we apply normalization on images[2],
41
- images[2+num_bands], images[2+num_bands*2], etc. Or if the bands list
42
- is not set, then we apply the mean and std on each image in the time
43
- series.
39
+ num_bands: deprecated, no longer used. Will be removed after 2026-04-01.
44
40
  """
45
41
  super().__init__()
42
+
43
+ if num_bands is not None:
44
+ warnings.warn(
45
+ "num_bands is deprecated and no longer used. "
46
+ "It will be removed after 2026-04-01.",
47
+ FutureWarning,
48
+ )
49
+
46
50
  self.mean = torch.tensor(mean)
47
51
  self.std = torch.tensor(std)
48
52
 
@@ -55,92 +59,37 @@ class Normalize(Transform):
55
59
 
56
60
  self.selectors = selectors
57
61
  self.bands = torch.tensor(bands) if bands is not None else None
58
- self.num_bands = num_bands
59
62
 
60
- def apply_image(
61
- self, image: torch.Tensor | RasterImage
62
- ) -> torch.Tensor | RasterImage:
63
+ def apply_image(self, image: RasterImage) -> RasterImage:
63
64
  """Normalize the specified image.
64
65
 
65
66
  Args:
66
67
  image: the image to transform.
67
68
  """
68
-
69
- def _repeat_mean_and_std(
70
- image_channels: int, num_bands: int | None, is_raster_image: bool
71
- ) -> tuple[torch.Tensor, torch.Tensor]:
72
- """Get mean and std tensor that are suitable for applying on the image."""
73
- # We only need to repeat the tensor if both of these are true:
74
- # - The mean/std are not just one scalar.
75
- # - self.num_bands is set, otherwise we treat the input as a single image.
76
- if len(self.mean.shape) == 0:
77
- return self.mean, self.std
78
- if num_bands is None:
79
- return self.mean, self.std
80
- num_images = image_channels // num_bands
81
- if is_raster_image:
82
- # add an extra T dimension, CTHW
83
- return self.mean.repeat(num_images)[
84
- :, None, None, None
85
- ], self.std.repeat(num_images)[:, None, None, None]
86
- else:
87
- # add an extra T dimension, CTHW
88
- return self.mean.repeat(num_images)[:, None, None], self.std.repeat(
89
- num_images
90
- )[:, None, None]
69
+ # Get mean/std with singleton dims for broadcasting over CTHW.
70
+ if len(self.mean.shape) == 0:
71
+ # Scalar - broadcasts naturally.
72
+ mean, std = self.mean, self.std
73
+ else:
74
+ # Vector of length C - add singleton dims for T, H, W.
75
+ mean = self.mean[:, None, None, None]
76
+ std = self.std[:, None, None, None]
91
77
 
92
78
  if self.bands is not None:
93
- # User has provided band indices to normalize.
94
- # If num_bands is set, then we repeat these for each image in the input
95
- # image time series.
96
- band_indices = self.bands
97
- if self.num_bands:
98
- num_images = image.shape[0] // self.num_bands
99
- band_indices = torch.cat(
100
- [
101
- band_indices + image_idx * self.num_bands
102
- for image_idx in range(num_images)
103
- ],
104
- dim=0,
79
+ # Normalize only specific band indices.
80
+ image.image[self.bands] = (image.image[self.bands] - mean) / std
81
+ if self.valid_min is not None:
82
+ image.image[self.bands] = torch.clamp(
83
+ image.image[self.bands],
84
+ min=self.valid_min,
85
+ max=self.valid_max,
105
86
  )
106
-
107
- # We use len(self.bands) here because that is how many bands per timestep
108
- # we are actually processing with the mean/std.
109
- mean, std = _repeat_mean_and_std(
110
- image_channels=len(band_indices),
111
- num_bands=len(self.bands),
112
- is_raster_image=isinstance(image, RasterImage),
113
- )
114
- if isinstance(image, torch.Tensor):
115
- image[band_indices] = (image[band_indices] - mean) / std
116
- if self.valid_min is not None:
117
- image[band_indices] = torch.clamp(
118
- image[band_indices], min=self.valid_min, max=self.valid_max
119
- )
120
- else:
121
- image.image[band_indices] = (image.image[band_indices] - mean) / std
122
- if self.valid_min is not None:
123
- image.image[band_indices] = torch.clamp(
124
- image.image[band_indices],
125
- min=self.valid_min,
126
- max=self.valid_max,
127
- )
128
87
  else:
129
- mean, std = _repeat_mean_and_std(
130
- image_channels=image.shape[0],
131
- num_bands=self.num_bands,
132
- is_raster_image=isinstance(image, RasterImage),
133
- )
134
- if isinstance(image, torch.Tensor):
135
- image = (image - mean) / std
136
- if self.valid_min is not None:
137
- image = torch.clamp(image, min=self.valid_min, max=self.valid_max)
138
- else:
139
- image.image = (image.image - mean) / std
140
- if self.valid_min is not None:
141
- image.image = torch.clamp(
142
- image.image, min=self.valid_min, max=self.valid_max
143
- )
88
+ image.image = (image.image - mean) / std
89
+ if self.valid_min is not None:
90
+ image.image = torch.clamp(
91
+ image.image, min=self.valid_min, max=self.valid_max
92
+ )
144
93
  return image
145
94
 
146
95
  def forward(
@@ -50,9 +50,7 @@ class Pad(Transform):
50
50
  """
51
51
  return {"size": torch.randint(low=self.size[0], high=self.size[1], size=())}
52
52
 
53
- def apply_image(
54
- self, image: RasterImage | torch.Tensor, state: dict[str, bool]
55
- ) -> RasterImage | torch.Tensor:
53
+ def apply_image(self, image: RasterImage, state: dict[str, bool]) -> RasterImage:
56
54
  """Apply the sampled state on the specified image.
57
55
 
58
56
  Args:
@@ -105,16 +103,12 @@ class Pad(Transform):
105
103
  horizontal_pad = (horizontal_half, horizontal_extra - horizontal_half)
106
104
  vertical_pad = (vertical_half, vertical_extra - vertical_half)
107
105
 
108
- if isinstance(image, RasterImage):
109
- image.image = apply_padding(
110
- image.image, True, horizontal_pad[0], horizontal_pad[1]
111
- )
112
- image.image = apply_padding(
113
- image.image, False, vertical_pad[0], vertical_pad[1]
114
- )
115
- else:
116
- image = apply_padding(image, True, horizontal_pad[0], horizontal_pad[1])
117
- image = apply_padding(image, False, vertical_pad[0], vertical_pad[1])
106
+ image.image = apply_padding(
107
+ image.image, True, horizontal_pad[0], horizontal_pad[1]
108
+ )
109
+ image.image = apply_padding(
110
+ image.image, False, vertical_pad[0], vertical_pad[1]
111
+ )
118
112
  return image
119
113
 
120
114
  def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
@@ -2,7 +2,6 @@
2
2
 
3
3
  from typing import Any
4
4
 
5
- import torch
6
5
  import torchvision
7
6
  from torchvision.transforms import InterpolationMode
8
7
 
@@ -40,32 +39,16 @@ class Resize(Transform):
40
39
  self.selectors = selectors
41
40
  self.interpolation = INTERPOLATION_MODES[interpolation]
42
41
 
43
- def apply_resize(
44
- self, image: torch.Tensor | RasterImage
45
- ) -> torch.Tensor | RasterImage:
42
+ def apply_resize(self, image: RasterImage) -> RasterImage:
46
43
  """Apply resizing on the specified image.
47
44
 
48
- If the image is 2D, it is unsqueezed to 3D and then squeezed
49
- back after resizing.
50
-
51
45
  Args:
52
46
  image: the image to transform.
53
47
  """
54
- if isinstance(image, torch.Tensor):
55
- if image.dim() == 2:
56
- image = image.unsqueeze(0) # (H, W) -> (1, H, W)
57
- result = torchvision.transforms.functional.resize(
58
- image, self.target_size, self.interpolation
59
- )
60
- return result.squeeze(0) # (1, H, W) -> (H, W)
61
- return torchvision.transforms.functional.resize(
62
- image, self.target_size, self.interpolation
63
- )
64
- else:
65
- image.image = torchvision.transforms.functional.resize(
66
- image.image, self.target_size, self.interpolation
67
- )
68
- return image
48
+ image.image = torchvision.transforms.functional.resize(
49
+ image.image, self.target_size, self.interpolation
50
+ )
51
+ return image
69
52
 
70
53
  def forward(
71
54
  self, input_dict: dict[str, Any], target_dict: dict[str, Any]
@@ -1,9 +1,8 @@
1
1
  """The SelectBands transform."""
2
2
 
3
+ import warnings
3
4
  from typing import Any
4
5
 
5
- from rslearn.train.model_context import RasterImage
6
-
7
6
  from .transform import Transform, read_selector, write_selector
8
7
 
9
8
 
@@ -17,60 +16,41 @@ class SelectBands(Transform):
17
16
  output_selector: str = "image",
18
17
  num_bands_per_timestep: int | None = None,
19
18
  ):
20
- """Initialize a new Concatenate.
19
+ """Initialize a new SelectBands.
21
20
 
22
21
  Args:
23
- band_indices: the bands to select.
22
+ band_indices: the bands to select from the channel dimension.
24
23
  input_selector: the selector to read the input image.
25
24
  output_selector: the output selector under which to save the output image.
26
- num_bands_per_timestep: the number of bands per image, to distinguish
27
- between stacked images in an image time series. If set, then the
28
- band_indices are selected for each image in the time series.
25
+ num_bands_per_timestep: deprecated, no longer used. Will be removed after
26
+ 2026-04-01.
29
27
  """
30
28
  super().__init__()
29
+
30
+ if num_bands_per_timestep is not None:
31
+ warnings.warn(
32
+ "num_bands_per_timestep is deprecated and no longer used. "
33
+ "It will be removed after 2026-04-01.",
34
+ FutureWarning,
35
+ )
36
+
31
37
  self.input_selector = input_selector
32
38
  self.output_selector = output_selector
33
39
  self.band_indices = band_indices
34
- self.num_bands_per_timestep = num_bands_per_timestep
35
40
 
36
41
  def forward(
37
42
  self, input_dict: dict[str, Any], target_dict: dict[str, Any]
38
43
  ) -> tuple[dict[str, Any], dict[str, Any]]:
39
- """Apply concatenation over the inputs and targets.
44
+ """Apply band selection over the inputs and targets.
40
45
 
41
46
  Args:
42
47
  input_dict: the input
43
48
  target_dict: the target
44
49
 
45
50
  Returns:
46
- normalized (input_dicts, target_dicts) tuple
51
+ (input_dicts, target_dicts) tuple with selected bands
47
52
  """
48
53
  image = read_selector(input_dict, target_dict, self.input_selector)
49
- num_bands_per_timestep = (
50
- self.num_bands_per_timestep
51
- if self.num_bands_per_timestep is not None
52
- else image.shape[0]
53
- )
54
- if isinstance(image, RasterImage):
55
- assert num_bands_per_timestep == image.shape[0], (
56
- "Expect a seperate dimension for timesteps in RasterImages."
57
- )
58
-
59
- if image.shape[0] % num_bands_per_timestep != 0:
60
- raise ValueError(
61
- f"channel dimension {image.shape[0]} is not multiple of bands per timestep {num_bands_per_timestep}"
62
- )
63
-
64
- # Copy the band indices for each timestep in the input.
65
- wanted_bands: list[int] = []
66
- for start_channel_idx in range(0, image.shape[0], num_bands_per_timestep):
67
- wanted_bands.extend(
68
- [(start_channel_idx + band_idx) for band_idx in self.band_indices]
69
- )
70
-
71
- if isinstance(image, RasterImage):
72
- image.image = image.image[wanted_bands]
73
- else:
74
- image = image[wanted_bands]
54
+ image.image = image.image[self.band_indices]
75
55
  write_selector(input_dict, target_dict, self.output_selector, image)
76
56
  return input_dict, target_dict
@@ -33,31 +33,19 @@ class Sentinel1ToDecibels(Transform):
33
33
  self.from_decibels = from_decibels
34
34
  self.epsilon = epsilon
35
35
 
36
- def apply_image(
37
- self, image: torch.Tensor | RasterImage
38
- ) -> torch.Tensor | RasterImage:
36
+ def apply_image(self, image: RasterImage) -> RasterImage:
39
37
  """Normalize the specified image.
40
38
 
41
39
  Args:
42
40
  image: the image to transform.
43
41
  """
44
- if isinstance(image, torch.Tensor):
45
- image_to_process = image
46
- else:
47
- image_to_process = image.image
48
42
  if self.from_decibels:
49
43
  # Decibels to linear scale.
50
- image_to_process = torch.pow(10.0, image_to_process / 10.0)
44
+ image.image = torch.pow(10.0, image.image / 10.0)
51
45
  else:
52
46
  # Linear scale to decibels.
53
- image_to_process = 10 * torch.log10(
54
- torch.clamp(image_to_process, min=self.epsilon)
55
- )
56
- if isinstance(image, torch.Tensor):
57
- return image_to_process
58
- else:
59
- image.image = image_to_process
60
- return image
47
+ image.image = 10 * torch.log10(torch.clamp(image.image, min=self.epsilon))
48
+ return image
61
49
 
62
50
  def forward(
63
51
  self, input_dict: dict[str, Any], target_dict: dict[str, Any]
rslearn/utils/__init__.py CHANGED
@@ -7,6 +7,7 @@ from .geometry import (
7
7
  PixelBounds,
8
8
  Projection,
9
9
  STGeometry,
10
+ get_global_raster_bounds,
10
11
  is_same_resolution,
11
12
  shp_intersects,
12
13
  )
@@ -23,6 +24,7 @@ __all__ = (
23
24
  "Projection",
24
25
  "STGeometry",
25
26
  "daterange",
27
+ "get_global_raster_bounds",
26
28
  "get_utm_ups_crs",
27
29
  "is_same_resolution",
28
30
  "logger",
rslearn/utils/geometry.py CHANGED
@@ -116,6 +116,27 @@ class Projection:
116
116
  WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)
117
117
 
118
118
 
119
+ def get_global_raster_bounds(projection: Projection) -> PixelBounds:
120
+ """Get very large pixel bounds for a global raster in the given projection.
121
+
122
+ This is useful for data sources that cover the entire world and don't want to
123
+ compute exact bounds in arbitrary projections (which can fail for projections
124
+ like UTM that only cover part of the world).
125
+
126
+ Args:
127
+ projection: the projection to get bounds in.
128
+
129
+ Returns:
130
+ Pixel bounds that will intersect with any reasonable window. We assume that the
131
+ absolute value of CRS coordinates is at most 2^32, and adjust it based on the
132
+ resolution in the Projection in case very fine-grained resolutions are used.
133
+ """
134
+ crs_bound = 2**32
135
+ pixel_bound_x = int(crs_bound / abs(projection.x_resolution))
136
+ pixel_bound_y = int(crs_bound / abs(projection.y_resolution))
137
+ return (-pixel_bound_x, -pixel_bound_y, pixel_bound_x, pixel_bound_y)
138
+
139
+
119
140
  class ResolutionFactor:
120
141
  """Multiplier for the resolution in a Projection.
121
142