rslearn 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. rslearn/config/dataset.py +30 -23
  2. rslearn/data_sources/__init__.py +2 -0
  3. rslearn/data_sources/aws_landsat.py +44 -161
  4. rslearn/data_sources/aws_open_data.py +2 -4
  5. rslearn/data_sources/aws_sentinel1.py +1 -3
  6. rslearn/data_sources/aws_sentinel2_element84.py +54 -165
  7. rslearn/data_sources/climate_data_store.py +1 -3
  8. rslearn/data_sources/copernicus.py +1 -2
  9. rslearn/data_sources/data_source.py +1 -1
  10. rslearn/data_sources/direct_materialize_data_source.py +336 -0
  11. rslearn/data_sources/earthdaily.py +52 -155
  12. rslearn/data_sources/earthdatahub.py +425 -0
  13. rslearn/data_sources/eurocrops.py +1 -2
  14. rslearn/data_sources/gcp_public_data.py +1 -2
  15. rslearn/data_sources/google_earth_engine.py +1 -2
  16. rslearn/data_sources/hf_srtm.py +595 -0
  17. rslearn/data_sources/local_files.py +3 -3
  18. rslearn/data_sources/openstreetmap.py +1 -1
  19. rslearn/data_sources/planet.py +1 -2
  20. rslearn/data_sources/planet_basemap.py +1 -2
  21. rslearn/data_sources/planetary_computer.py +183 -186
  22. rslearn/data_sources/soilgrids.py +3 -3
  23. rslearn/data_sources/stac.py +1 -2
  24. rslearn/data_sources/usda_cdl.py +1 -3
  25. rslearn/data_sources/usgs_landsat.py +7 -254
  26. rslearn/data_sources/utils.py +204 -64
  27. rslearn/data_sources/worldcereal.py +1 -1
  28. rslearn/data_sources/worldcover.py +1 -1
  29. rslearn/data_sources/worldpop.py +1 -1
  30. rslearn/data_sources/xyz_tiles.py +5 -9
  31. rslearn/dataset/materialize.py +5 -1
  32. rslearn/models/clay/clay.py +3 -3
  33. rslearn/models/concatenate_features.py +6 -1
  34. rslearn/models/detr/detr.py +4 -1
  35. rslearn/models/dinov3.py +0 -1
  36. rslearn/models/olmoearth_pretrain/model.py +3 -1
  37. rslearn/models/pooling_decoder.py +1 -1
  38. rslearn/models/prithvi.py +0 -1
  39. rslearn/models/simple_time_series.py +97 -35
  40. rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
  41. rslearn/train/data_module.py +32 -27
  42. rslearn/train/dataset.py +260 -117
  43. rslearn/train/dataset_index.py +156 -0
  44. rslearn/train/lightning_module.py +1 -1
  45. rslearn/train/model_context.py +19 -3
  46. rslearn/train/prediction_writer.py +69 -41
  47. rslearn/train/tasks/classification.py +1 -1
  48. rslearn/train/tasks/detection.py +5 -5
  49. rslearn/train/tasks/per_pixel_regression.py +13 -13
  50. rslearn/train/tasks/regression.py +1 -1
  51. rslearn/train/tasks/segmentation.py +26 -13
  52. rslearn/train/transforms/concatenate.py +17 -27
  53. rslearn/train/transforms/crop.py +8 -19
  54. rslearn/train/transforms/flip.py +4 -10
  55. rslearn/train/transforms/mask.py +9 -15
  56. rslearn/train/transforms/normalize.py +31 -82
  57. rslearn/train/transforms/pad.py +7 -13
  58. rslearn/train/transforms/resize.py +5 -22
  59. rslearn/train/transforms/select_bands.py +16 -36
  60. rslearn/train/transforms/sentinel1.py +4 -16
  61. rslearn/utils/__init__.py +2 -0
  62. rslearn/utils/geometry.py +21 -0
  63. rslearn/utils/m2m_api.py +251 -0
  64. rslearn/utils/retry_session.py +43 -0
  65. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
  66. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/RECORD +71 -66
  67. rslearn/data_sources/earthdata_srtm.py +0 -282
  68. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
  69. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
  70. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
  71. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
  72. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
rslearn/data_sources/worldpop.py CHANGED
@@ -80,7 +80,7 @@ class WorldPop(LocalFiles):
         worldpop_upath.mkdir(parents=True, exist_ok=True)
         self.download_worldpop_data(worldpop_upath, timeout)
         super().__init__(
-            src_dir=worldpop_upath,
+            src_dir=worldpop_upath.absolute().as_uri(),
             layer_type=LayerType.RASTER,
             context=context,
         )
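Note: src_dir is now passed as a URI string rather than a path object, presumably because LocalFiles resolves the string itself. A minimal sketch of the conversion, assuming a local filesystem path wrapped in upath's UPath:

    from upath import UPath

    worldpop_upath = UPath("/tmp/worldpop")  # hypothetical local directory
    print(worldpop_upath.absolute().as_uri())  # file:///tmp/worldpop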
rslearn/data_sources/xyz_tiles.py CHANGED
@@ -19,7 +19,7 @@ from rslearn.config import LayerConfig, QueryConfig
 from rslearn.dataset import Window
 from rslearn.dataset.materialize import RasterMaterializer
 from rslearn.tile_stores import TileStore, TileStoreWithLayer
-from rslearn.utils import PixelBounds, Projection, STGeometry
+from rslearn.utils import PixelBounds, Projection, STGeometry, get_global_raster_bounds
 from rslearn.utils.array import copy_spatial_array
 from rslearn.utils.raster_format import get_transform_from_projection_and_bounds

@@ -184,7 +184,7 @@ class XyzTiles(DataSource, TileStore):
             groups.append(cur_groups)
         return groups

-    def deserialize_item(self, serialized_item: Any) -> Item:
+    def deserialize_item(self, serialized_item: dict) -> Item:
         """Deserializes an item from JSON-decoded data."""
         return Item.deserialize(serialized_item)

@@ -278,13 +278,9 @@ class XyzTiles(DataSource, TileStore):
         Returns:
             the bounds of the raster in the projection.
         """
-        geom = STGeometry(self.projection, self.shp, None).to_projection(projection)
-        return (
-            int(geom.shp.bounds[0]),
-            int(geom.shp.bounds[1]),
-            int(geom.shp.bounds[2]),
-            int(geom.shp.bounds[3]),
-        )
+        # XyzTiles is a global data source, so we return global raster bounds based on
+        # the projection.
+        return get_global_raster_bounds(projection)

     def read_raster(
         self,
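The removed code shows the general pattern a global-bounds helper needs: project a world-covering geometry and take integer pixel bounds. A hedged sketch of what get_global_raster_bounds (added in rslearn/utils/geometry.py above) plausibly does; WGS84_PROJECTION is assumed to be the library's WGS84 constant, and the latitude clamp is an assumption to keep Mercator-style projections finite:

    import shapely

    from rslearn.const import WGS84_PROJECTION  # assumed constant
    from rslearn.utils import PixelBounds, Projection, STGeometry

    def global_raster_bounds_sketch(projection: Projection) -> PixelBounds:
        # Project an (approximately) world-covering rectangle into the target
        # projection and round its bounds to integer pixel coordinates.
        world = STGeometry(WGS84_PROJECTION, shapely.box(-180, -85, 180, 85), None)
        b = world.to_projection(projection).shp.bounds
        return (int(b[0]), int(b[1]), int(b[2]), int(b[3]))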
rslearn/data_sources/utils.py CHANGED
@@ -236,7 +236,11 @@ def read_and_stack_raster_windows(
     band_dtype: npt.DTypeLike,
     resampling_method: Resampling = Resampling.bilinear,
 ) -> npt.NDArray[np.generic]:
-    """Create a stack of extent aligned raster windows.
+    """Create a stack of raster images, with one per item in the group.
+
+    We read the portion of each raster item corresponding to the window extent, and
+    stack the resulting images. This is used for the MEAN and MEDIAN compositing
+    methods so they can compute aggregate statistics across the stack.

     Args:
         group: Iterable of items (e.g., scene metadata objects) to read data from.
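For intuition, a minimal sketch of the aggregation the stack feeds, assuming axis 0 of the returned array is the item dimension and nodata has been mapped to NaN:

    import numpy as np

    stack = np.random.rand(4, 3, 64, 64).astype(np.float32)  # (items, C, H, W)
    stack[stack < 0.05] = np.nan  # pretend some pixels are nodata
    mean_composite = np.nanmean(stack, axis=0)      # (C, H, W)
    median_composite = np.nanmedian(stack, axis=0)  # (C, H, W)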
rslearn/models/clay/clay.py CHANGED
@@ -105,7 +105,7 @@ class Clay(FeatureExtractor):

     def _resize_image(self, image: torch.Tensor, original_hw: int) -> torch.Tensor:
         """Resize the image to the input resolution."""
-        new_hw = self.patch_size if original_hw == 1 else DEFAULT_IMAGE_RESOLUTION
+        new_hw = PATCH_SIZE if original_hw == 1 else DEFAULT_IMAGE_RESOLUTION
         return F.interpolate(
             image, size=(new_hw, new_hw), mode="bilinear", align_corners=False
         )
@@ -123,7 +123,8 @@ class Clay(FeatureExtractor):
         device = param.device

         chips = torch.stack(
-            [inp[self.modality] for inp in context.inputs], dim=0
+            [inp[self.modality].single_ts_to_chw_tensor() for inp in context.inputs],
+            dim=0,
         )  # (B, C, H, W)
         if self.do_resizing:
             chips = self._resize_image(chips, chips.shape[2])
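The inputs are now RasterImage wrappers rather than raw tensors; single_ts_to_chw_tensor() presumably squeezes a single-timestep (C, T, H, W) image down to the (C, H, W) layout Clay expects. An illustration of the assumed conversion:

    import torch

    cthw = torch.zeros(4, 1, 256, 256)  # (C, T=1, H, W)
    chw = cthw[:, 0, :, :]  # what single_ts_to_chw_tensor() plausibly returns
    assert chw.shape == (4, 256, 256)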
@@ -203,7 +204,6 @@ class ClayNormalize(Transform):
                 mean=means,
                 std=stds,
                 selectors=[modality],
-                num_bands=len(means),
             )
         self.normalizers = torch.nn.ModuleDict(normalizers)

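This release drops num_bands from the Normalize call sites here and in dinov3.py and prithvi.py below (normalize.py itself shrinks by 82 lines), which suggests the transform now infers the band count from the mean/std lengths. A hedged usage sketch with hypothetical statistics:

    from rslearn.train.transforms.normalize import Normalize

    # Band count presumably inferred from len(mean) == len(std) == 3.
    norm = Normalize(mean=[0.2, 0.3, 0.4], std=[0.1, 0.1, 0.1], selectors=["image"])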
rslearn/models/concatenate_features.py CHANGED
@@ -3,6 +3,7 @@
 from typing import Any

 import torch
+from einops import rearrange

 from rslearn.train.model_context import ModelContext

@@ -79,7 +80,11 @@ class ConcatenateFeatures(IntermediateComponent):
         )

         add_data = torch.stack(
-            [input_data[self.key] for input_data in context.inputs], dim=0
+            [
+                rearrange(input_data[self.key].image, "c t h w -> (c t) h w")
+                for input_data in context.inputs
+            ],
+            dim=0,
         )
         add_features = self.conv_layers(add_data)

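The rearrange pattern flattens the time axis into channels so the conv layers see a single (C*T, H, W) image per input. A quick check of the layout it produces:

    import torch
    from einops import rearrange

    x = torch.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4)  # (C=2, T=3, H=4, W=4)
    flat = rearrange(x, "c t h w -> (c t) h w")  # (6, 4, 4)
    # Timesteps of channel 0 come first: flat[0] == x[0, 0], flat[1] == x[0, 1], ...
    assert torch.equal(flat[1], x[0, 1])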
rslearn/models/detr/detr.py CHANGED
@@ -468,7 +468,10 @@ class Detr(Predictor):

         # Get image sizes.
         image_sizes = torch.tensor(
-            [[inp["image"].shape[2], inp["image"].shape[1]] for inp in context.inputs],
+            [
+                [inp["image"].image.shape[2], inp["image"].image.shape[1]]
+                for inp in context.inputs
+            ],
             dtype=torch.int32,
             device=features.device,
         )
rslearn/models/dinov3.py CHANGED
@@ -159,7 +159,6 @@ class DinoV3Normalize(Transform):
         self.normalize = Normalize(
             [value * 255 for value in mean],
             [value * 255 for value in std],
-            num_bands=3,
         )

     def forward(
rslearn/models/olmoearth_pretrain/model.py CHANGED
@@ -95,7 +95,9 @@ class OlmoEarth(FeatureExtractor):
         """
         if use_legacy_timestamps:
             warnings.warn(
-                "For new projects, don't use legacy timesteps.", DeprecationWarning
+                "For new projects, don't use legacy timesteps. "
+                "Support will be removed after 2026-04-01.",
+                FutureWarning,
             )

         if (
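The switch from DeprecationWarning to FutureWarning matters because Python's default filters hide DeprecationWarning outside code running in __main__, while FutureWarning is always shown, so downstream users actually see the notice:

    import warnings

    # Shown by default, even when triggered from library code:
    warnings.warn("don't use legacy timesteps", FutureWarning)
    # Hidden by default outside __main__ (unless warning filters are changed):
    warnings.warn("don't use legacy timesteps", DeprecationWarning)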
rslearn/models/pooling_decoder.py CHANGED
@@ -124,6 +124,6 @@ class SegmentationPoolingDecoder(PoolingDecoder):
         """
         output_probs = super().forward(intermediates, context)
         # BC -> BCHW
-        h, w = context.inputs[0][self.image_key].shape[1:3]
+        h, w = context.inputs[0][self.image_key].image.shape[1:3]
         feat_map = output_probs.feature_vector[:, :, None, None].repeat([1, 1, h, w])
         return FeatureMaps([feat_map])
rslearn/models/prithvi.py CHANGED
@@ -230,7 +230,6 @@ class PrithviNormalize(Transform):
         self.normalizer = Normalize(
             mean=config["mean"],
             std=config["std"],
-            num_bands=len(config["mean"]),
             selectors=[PrithviV2.INPUT_KEY],
         )

rslearn/models/simple_time_series.py CHANGED
@@ -1,5 +1,6 @@
 """SimpleTimeSeries encoder."""

+import warnings
 from typing import Any

 import torch
@@ -25,13 +26,14 @@ class SimpleTimeSeries(FeatureExtractor):
     def __init__(
         self,
         encoder: FeatureExtractor,
-        image_channels: int | None = None,
+        num_timesteps_per_forward_pass: int = 1,
         op: str = "max",
         groups: list[list[int]] | None = None,
         num_layers: int | None = None,
         image_key: str = "image",
         backbone_channels: list[tuple[int, int]] | None = None,
-        image_keys: dict[str, int] | None = None,
+        image_keys: list[str] | dict[str, int] | None = None,
+        image_channels: int | None = None,
     ) -> None:
         """Create a new SimpleTimeSeries.

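A migration sketch for the renamed parameter (encoder stands in for any per-image FeatureExtractor; band counts are hypothetical):

    from rslearn.models.simple_time_series import SimpleTimeSeries

    def build(encoder):  # encoder: any per-image FeatureExtractor
        # Old (deprecated): 12 = 3 bands x 4 timesteps per forward pass.
        old = SimpleTimeSeries(encoder, image_channels=12, op="max")
        # New: state the timestep grouping directly.
        new = SimpleTimeSeries(encoder, num_timesteps_per_forward_pass=4, op="max")
        return old, new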
@@ -39,9 +41,11 @@ class SimpleTimeSeries(FeatureExtractor):
             encoder: the underlying FeatureExtractor. It must provide get_backbone_channels
                 function that returns the output channels, or backbone_channels must be set.
                 It must output a FeatureMaps.
-            image_channels: the number of channels per image of the time series. The
-                input should have multiple images concatenated on the channel axis, so
-                this parameter is used to distinguish the different images.
+            num_timesteps_per_forward_pass: how many timesteps to pass to the encoder
+                in each forward pass. Defaults to 1 (one timestep per forward pass).
+                Set to a higher value to batch multiple timesteps together, e.g. for
+                pre/post change detection where you want 4 pre and 4 post images
+                processed together.
             op: one of max, mean, convrnn, conv3d, or conv1d
             groups: sets of images for which to combine features. Within each set,
                 features are combined using the specified operation; then, across sets,
@@ -51,28 +55,53 @@ class SimpleTimeSeries(FeatureExtractor):
                 combined before features and the combined after features. groups is a
                 list of sets, and each set is a list of image indices.
             num_layers: the number of layers for convrnn, conv3d, and conv1d ops.
-            image_key: the key to access the images.
+            image_key: the key to access the images (used when image_keys is not set).
             backbone_channels: manually specify the backbone channels. Can be set if
                 the encoder does not provide get_backbone_channels function.
-            image_keys: as an alternative to setting image_channels, map from the key
-                in input dict to the number of channels per timestep for that modality.
-                This way SimpleTimeSeries can be used with multimodal inputs. One of
-                image_channels or image_keys must be specified.
+            image_keys: list of keys in input dict to process as multimodal inputs.
+                All keys use the same num_timesteps_per_forward_pass. If not set,
+                only the single image_key is used. Passing a dict[str, int] is
+                deprecated and will be removed on 2026-04-01.
+            image_channels: Deprecated, use num_timesteps_per_forward_pass instead.
+                Will be removed on 2026-04-01.
         """
-        if (image_channels is None and image_keys is None) or (
-            image_channels is not None and image_keys is not None
-        ):
-            raise ValueError(
-                "exactly one of image_channels and image_keys must be specified"
+        # Handle deprecated image_channels parameter
+        if image_channels is not None:
+            warnings.warn(
+                "image_channels is deprecated and will be removed on 2026-04-01. "
+                "Use num_timesteps_per_forward_pass instead. The new parameter directly "
+                "specifies the number of timesteps per forward pass rather than requiring "
+                "image_channels // actual_channels.",
+                FutureWarning,
+                stacklevel=2,
             )

+        # Handle deprecated dict form of image_keys
+        deprecated_image_keys_dict: dict[str, int] | None = None
+        if isinstance(image_keys, dict):
+            warnings.warn(
+                "Passing image_keys as a dict is deprecated and will be removed on "
+                "2026-04-01. Use image_keys as a list[str] and set "
+                "num_timesteps_per_forward_pass instead.",
+                FutureWarning,
+                stacklevel=2,
+            )
+            deprecated_image_keys_dict = image_keys
+            image_keys = None  # Will use deprecated path in forward
+
         super().__init__()
         self.encoder = encoder
-        self.image_channels = image_channels
+        self.num_timesteps_per_forward_pass = num_timesteps_per_forward_pass
+        # Store deprecated parameters for runtime conversion
+        self._deprecated_image_channels = image_channels
+        self._deprecated_image_keys_dict = deprecated_image_keys_dict
         self.op = op
         self.groups = groups
-        self.image_key = image_key
-        self.image_keys = image_keys
+        # Normalize image_key to image_keys list form
+        if image_keys is not None:
+            self.image_keys = image_keys
+        else:
+            self.image_keys = [image_key]

         if backbone_channels is not None:
             out_channels = backbone_channels
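Likewise for the multimodal form, a hedged migration sketch (keys and channel counts hypothetical; note both dict entries must now imply the same timestep grouping):

    from rslearn.models.simple_time_series import SimpleTimeSeries

    def build_multimodal(encoder):  # encoder: any per-image FeatureExtractor
        # Old (deprecated): values were channels per forward pass, e.g. 9-band
        # sentinel2 and 2-band sentinel1 each grouped 4 timesteps (36 // 9, 8 // 2).
        old = SimpleTimeSeries(encoder, image_keys={"sentinel2": 36, "sentinel1": 8})
        # New: list the keys and state the shared grouping once.
        new = SimpleTimeSeries(
            encoder,
            image_keys=["sentinel2", "sentinel1"],
            num_timesteps_per_forward_pass=4,
        )
        return old, new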
@@ -163,24 +192,25 @@ class SimpleTimeSeries(FeatureExtractor):
         return out_channels

     def _get_batched_images(
-        self, input_dicts: list[dict[str, Any]], image_key: str, image_channels: int
+        self, input_dicts: list[dict[str, Any]], image_key: str, num_timesteps: int
     ) -> list[RasterImage]:
         """Collect and reshape images across input dicts.

         The BTCHW image time series are reshaped to (B*T)CHW so they can be passed to
         the forward pass of a per-image (unitemporal) model.
+
+        Args:
+            input_dicts: list of input dictionaries containing RasterImage objects.
+            image_key: the key to access the RasterImage in each input dict.
+            num_timesteps: how many timesteps to batch together per forward pass.
         """
         images = torch.stack(
             [input_dict[image_key].image for input_dict in input_dicts], dim=0
         )  # B, C, T, H, W
         timestamps = [input_dict[image_key].timestamps for input_dict in input_dicts]
-        # if image channels is not equal to the actual number of channels, then
-        # then every N images should be batched together. For example, if the
-        # number of input channels c == 2, and image_channels == 4, then we
-        # want to pass 2 timesteps to the model.
-        # TODO is probably to make this behaviour clearer but lets leave it like
-        # this for now to not break things.
-        num_timesteps = image_channels // images.shape[1]
+        # num_timesteps specifies how many timesteps to batch together per forward pass.
+        # For example, if the input has 8 timesteps and num_timesteps=4, we do 2
+        # forward passes, each with 4 timesteps batched together.
         batched_timesteps = images.shape[2] // num_timesteps
         images = rearrange(
             images,
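The reshape that follows (cut off above) folds timestep groups into the batch axis. An illustration of the grouping arithmetic under the assumed (B, C, T, H, W) layout; the exact einops pattern is an assumption:

    import torch
    from einops import rearrange

    images = torch.zeros(3, 4, 8, 16, 16)  # B=3, C=4, T=8
    num_timesteps = 2                      # timesteps per forward pass
    grouped = rearrange(images, "b c (g t) h w -> (b g) c t h w", t=num_timesteps)
    assert grouped.shape == (12, 4, 2, 16, 16)  # 8 // 2 = 4 groups folded into batch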
@@ -222,10 +252,22 @@
         n_batch = len(context.inputs)
         n_images: int | None = None

-        if self.image_keys is not None:
-            for image_key, image_channels in self.image_keys.items():
+        if self._deprecated_image_keys_dict is not None:
+            # Deprecated dict form: each key has its own channels_per_timestep.
+            # The channels_per_timestep could be used to group multiple timesteps
+            # together, so we need to divide by the actual image channel count to get
+            # the number of timesteps to be grouped.
+            for (
+                image_key,
+                channels_per_timestep,
+            ) in self._deprecated_image_keys_dict.items():
+                # For deprecated image_keys dict, the value is channels per timestep,
+                # so we need to compute num_timesteps from the actual image channels
+                sample_image = context.inputs[0][image_key].image
+                actual_channels = sample_image.shape[0]  # C in CTHW
+                num_timesteps = channels_per_timestep // actual_channels
                 batched_images = self._get_batched_images(
-                    context.inputs, image_key, image_channels
+                    context.inputs, image_key, num_timesteps
                 )

                 if batched_inputs is None:
@@ -240,12 +282,32 @@
                 batched_inputs[i][image_key] = image

         else:
-            assert self.image_channels is not None
-            batched_images = self._get_batched_images(
-                context.inputs, self.image_key, self.image_channels
-            )
-            batched_inputs = [{self.image_key: image} for image in batched_images]
-            n_images = len(batched_images) // n_batch
+            # Determine num_timesteps - either from deprecated image_channels or
+            # directly from num_timesteps_per_forward_pass
+            if self._deprecated_image_channels is not None:
+                # Backwards compatibility: compute num_timesteps from image_channels
+                # (which should be a multiple of the actual per-timestep channels).
+                sample_image = context.inputs[0][self.image_keys[0]].image
+                actual_channels = sample_image.shape[0]  # C in CTHW
+                num_timesteps = self._deprecated_image_channels // actual_channels
+            else:
+                num_timesteps = self.num_timesteps_per_forward_pass
+
+            for image_key in self.image_keys:
+                batched_images = self._get_batched_images(
+                    context.inputs, image_key, num_timesteps
+                )
+
+                if batched_inputs is None:
+                    batched_inputs = [{} for _ in batched_images]
+                    n_images = len(batched_images) // n_batch
+                elif n_images != len(batched_images) // n_batch:
+                    raise ValueError(
+                        "expected all modalities to have the same number of timesteps"
+                    )
+
+                for i, image in enumerate(batched_images):
+                    batched_inputs[i][image_key] = image

         assert n_images is not None
         # Now we can apply the underlying FeatureExtractor.
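The n_images check above enforces that every modality yields the same number of forward passes per sample. A worked mismatch under hypothetical timestep counts:

    t_a, t_b, group = 8, 6, 2
    n_images_a = t_a // group  # 4 forward passes per sample for modality A
    n_images_b = t_b // group  # 3 for modality B
    assert n_images_a != n_images_b  # forward raises ValueError in this case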