rslearn 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +30 -23
- rslearn/data_sources/__init__.py +2 -0
- rslearn/data_sources/aws_landsat.py +44 -161
- rslearn/data_sources/aws_open_data.py +2 -4
- rslearn/data_sources/aws_sentinel1.py +1 -3
- rslearn/data_sources/aws_sentinel2_element84.py +54 -165
- rslearn/data_sources/climate_data_store.py +1 -3
- rslearn/data_sources/copernicus.py +1 -2
- rslearn/data_sources/data_source.py +1 -1
- rslearn/data_sources/direct_materialize_data_source.py +336 -0
- rslearn/data_sources/earthdaily.py +52 -155
- rslearn/data_sources/earthdatahub.py +425 -0
- rslearn/data_sources/eurocrops.py +1 -2
- rslearn/data_sources/gcp_public_data.py +1 -2
- rslearn/data_sources/google_earth_engine.py +1 -2
- rslearn/data_sources/hf_srtm.py +595 -0
- rslearn/data_sources/local_files.py +3 -3
- rslearn/data_sources/openstreetmap.py +1 -1
- rslearn/data_sources/planet.py +1 -2
- rslearn/data_sources/planet_basemap.py +1 -2
- rslearn/data_sources/planetary_computer.py +183 -186
- rslearn/data_sources/soilgrids.py +3 -3
- rslearn/data_sources/stac.py +1 -2
- rslearn/data_sources/usda_cdl.py +1 -3
- rslearn/data_sources/usgs_landsat.py +7 -254
- rslearn/data_sources/utils.py +204 -64
- rslearn/data_sources/worldcereal.py +1 -1
- rslearn/data_sources/worldcover.py +1 -1
- rslearn/data_sources/worldpop.py +1 -1
- rslearn/data_sources/xyz_tiles.py +5 -9
- rslearn/dataset/materialize.py +5 -1
- rslearn/models/clay/clay.py +3 -3
- rslearn/models/concatenate_features.py +6 -1
- rslearn/models/detr/detr.py +4 -1
- rslearn/models/dinov3.py +0 -1
- rslearn/models/olmoearth_pretrain/model.py +3 -1
- rslearn/models/pooling_decoder.py +1 -1
- rslearn/models/prithvi.py +0 -1
- rslearn/models/simple_time_series.py +97 -35
- rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
- rslearn/train/data_module.py +32 -27
- rslearn/train/dataset.py +260 -117
- rslearn/train/dataset_index.py +156 -0
- rslearn/train/lightning_module.py +1 -1
- rslearn/train/model_context.py +19 -3
- rslearn/train/prediction_writer.py +69 -41
- rslearn/train/tasks/classification.py +1 -1
- rslearn/train/tasks/detection.py +5 -5
- rslearn/train/tasks/per_pixel_regression.py +13 -13
- rslearn/train/tasks/regression.py +1 -1
- rslearn/train/tasks/segmentation.py +26 -13
- rslearn/train/transforms/concatenate.py +17 -27
- rslearn/train/transforms/crop.py +8 -19
- rslearn/train/transforms/flip.py +4 -10
- rslearn/train/transforms/mask.py +9 -15
- rslearn/train/transforms/normalize.py +31 -82
- rslearn/train/transforms/pad.py +7 -13
- rslearn/train/transforms/resize.py +5 -22
- rslearn/train/transforms/select_bands.py +16 -36
- rslearn/train/transforms/sentinel1.py +4 -16
- rslearn/utils/__init__.py +2 -0
- rslearn/utils/geometry.py +21 -0
- rslearn/utils/m2m_api.py +251 -0
- rslearn/utils/retry_session.py +43 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/RECORD +71 -66
- rslearn/data_sources/earthdata_srtm.py +0 -282
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
rslearn/data_sources/worldpop.py
CHANGED

@@ -80,7 +80,7 @@ class WorldPop(LocalFiles):
         worldpop_upath.mkdir(parents=True, exist_ok=True)
         self.download_worldpop_data(worldpop_upath, timeout)
         super().__init__(
-            src_dir=worldpop_upath,
+            src_dir=worldpop_upath.absolute().as_uri(),
             layer_type=LayerType.RASTER,
             context=context,
         )
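The change passes LocalFiles a file:// URI string instead of a bare path object. A minimal sketch of the conversion with plain pathlib (rslearn uses UPath, which exposes the same methods; the cache directory name is illustrative):

    from pathlib import Path

    worldpop_path = Path("cache/worldpop")  # hypothetical local cache directory
    worldpop_path.mkdir(parents=True, exist_ok=True)

    # as_uri() requires an absolute path, hence the absolute() call first.
    src_dir = worldpop_path.absolute().as_uri()
    print(src_dir)  # e.g. file:///home/user/cache/worldpop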
rslearn/data_sources/xyz_tiles.py
CHANGED

@@ -19,7 +19,7 @@ from rslearn.config import LayerConfig, QueryConfig
 from rslearn.dataset import Window
 from rslearn.dataset.materialize import RasterMaterializer
 from rslearn.tile_stores import TileStore, TileStoreWithLayer
-from rslearn.utils import PixelBounds, Projection, STGeometry
+from rslearn.utils import PixelBounds, Projection, STGeometry, get_global_raster_bounds
 from rslearn.utils.array import copy_spatial_array
 from rslearn.utils.raster_format import get_transform_from_projection_and_bounds

@@ -184,7 +184,7 @@ class XyzTiles(DataSource, TileStore):
         groups.append(cur_groups)
         return groups

-    def deserialize_item(self, serialized_item:
+    def deserialize_item(self, serialized_item: dict) -> Item:
         """Deserializes an item from JSON-decoded data."""
         return Item.deserialize(serialized_item)

@@ -278,13 +278,9 @@ class XyzTiles(DataSource, TileStore):
         Returns:
             the bounds of the raster in the projection.
         """
-
-        return (
-            int(geom.shp.bounds[0]),
-            int(geom.shp.bounds[1]),
-            int(geom.shp.bounds[2]),
-            int(geom.shp.bounds[3]),
-        )
+        # XyzTiles is a global data source, so we return global raster bounds based on
+        # the projection.
+        return get_global_raster_bounds(projection)

     def read_raster(
         self,
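get_global_raster_bounds is newly exported from rslearn.utils (see rslearn/utils/geometry.py +21 in the file list); its implementation is not shown in this diff. A hedged sketch of the idea, deriving pixel bounds for a CRS's full valid extent with pyproj (the function name and signature here are illustrative, not rslearn's API):

    from pyproj import CRS, Transformer

    def global_raster_bounds(crs: CRS, x_res: float, y_res: float) -> tuple[int, int, int, int]:
        # The CRS's area of use is given in lon/lat degrees; project it first.
        aou = crs.area_of_use
        to_crs = Transformer.from_crs(CRS.from_epsg(4326), crs, always_xy=True)
        x1, y1 = to_crs.transform(aou.west, aou.south)
        x2, y2 = to_crs.transform(aou.east, aou.north)
        # Convert projection units to pixel units, ordered (minx, miny, maxx, maxy).
        return (
            int(min(x1, x2) / x_res),
            int(min(y1, y2) / y_res),
            int(max(x1, x2) / x_res),
            int(max(y1, y2) / y_res),
        )

    print(global_raster_bounds(CRS.from_epsg(3857), 10.0, 10.0))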
rslearn/dataset/materialize.py
CHANGED

@@ -236,7 +236,11 @@ def read_and_stack_raster_windows(
     band_dtype: npt.DTypeLike,
     resampling_method: Resampling = Resampling.bilinear,
 ) -> npt.NDArray[np.generic]:
-    """Create a stack of
+    """Create a stack of raster images, with one per item in the group.
+
+    We read the portion of each raster item corresponding to the window extent, and
+    stack the resulting images. This is used for the MEAN and MEDIAN compositing
+    methods so it can compute aggregate statistics across the stack.

     Args:
         group: Iterable of items (e.g., scene metadata objects) to read data from.
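A quick sketch of how such a stack feeds the MEAN and MEDIAN compositing methods (plain numpy; the nodata value of 0 is a hypothetical choice):

    import numpy as np

    # Stack of N per-item reads over the same window, shape (N, bands, H, W).
    stack = np.random.randint(0, 256, size=(4, 3, 64, 64)).astype(np.float32)

    # Mask the (assumed) nodata pixels so they do not skew per-pixel statistics.
    masked = np.ma.masked_equal(stack, 0)
    mean_composite = masked.mean(axis=0).filled(0)
    median_composite = np.ma.median(masked, axis=0).filled(0)
    print(mean_composite.shape, median_composite.shape)  # (3, 64, 64) (3, 64, 64)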
rslearn/models/clay/clay.py
CHANGED

@@ -105,7 +105,7 @@ class Clay(FeatureExtractor):

     def _resize_image(self, image: torch.Tensor, original_hw: int) -> torch.Tensor:
         """Resize the image to the input resolution."""
-        new_hw =
+        new_hw = PATCH_SIZE if original_hw == 1 else DEFAULT_IMAGE_RESOLUTION
         return F.interpolate(
             image, size=(new_hw, new_hw), mode="bilinear", align_corners=False
         )

@@ -123,7 +123,8 @@ class Clay(FeatureExtractor):
         device = param.device

         chips = torch.stack(
-            [inp[self.modality] for inp in context.inputs],
+            [inp[self.modality].single_ts_to_chw_tensor() for inp in context.inputs],
+            dim=0,
         )  # (B, C, H, W)
         if self.do_resizing:
             chips = self._resize_image(chips, chips.shape[2])

@@ -203,7 +204,6 @@ class ClayNormalize(Transform):
                 mean=means,
                 std=stds,
                 selectors=[modality],
-                num_bands=len(means),
             )
         self.normalizers = torch.nn.ModuleDict(normalizers)
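single_ts_to_chw_tensor() converts each input's raster tensor to CHW before the batch is stacked. A rough equivalent of the stack-and-resize path with raw tensors (the squeeze stands in for that conversion, assuming a CTHW layout with T=1; sizes are arbitrary):

    import torch
    import torch.nn.functional as F

    rasters = [torch.rand(4, 1, 32, 32) for _ in range(2)]  # CTHW with T=1
    chips = torch.stack([r.squeeze(1) for r in rasters], dim=0)  # (B, C, H, W)

    # Same bilinear resize as Clay._resize_image, here to a sample 224x224.
    chips = F.interpolate(chips, size=(224, 224), mode="bilinear", align_corners=False)
    print(chips.shape)  # torch.Size([2, 4, 224, 224])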
rslearn/models/concatenate_features.py
CHANGED

@@ -3,6 +3,7 @@
 from typing import Any

 import torch
+from einops import rearrange

 from rslearn.train.model_context import ModelContext

@@ -79,7 +80,11 @@ class ConcatenateFeatures(IntermediateComponent):
         )

         add_data = torch.stack(
-            [
+            [
+                rearrange(input_data[self.key].image, "c t h w -> (c t) h w")
+                for input_data in context.inputs
+            ],
+            dim=0,
         )
         add_features = self.conv_layers(add_data)
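The rearrange folds the time axis into channels so the 2D conv layers can consume a multi-temporal input. The pattern in isolation, with illustrative shapes:

    import torch
    from einops import rearrange

    image = torch.rand(4, 3, 16, 16)  # (C, T, H, W): 4 bands, 3 timesteps
    flat = rearrange(image, "c t h w -> (c t) h w")
    print(flat.shape)  # torch.Size([12, 16, 16])

    # One such tensor per example, stacked into the (B, C*T, H, W) conv input.
    batch = torch.stack([flat, flat], dim=0)
    print(batch.shape)  # torch.Size([2, 12, 16, 16])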
rslearn/models/detr/detr.py
CHANGED

@@ -468,7 +468,10 @@ class Detr(Predictor):

         # Get image sizes.
         image_sizes = torch.tensor(
-            [
+            [
+                [inp["image"].image.shape[2], inp["image"].image.shape[1]]
+                for inp in context.inputs
+            ],
             dtype=torch.int32,
             device=features.device,
         )
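With a CHW image tensor, shape[2] is width and shape[1] is height, so the rows come out as (width, height) pairs. A quick check of the indexing, with bare tensors standing in for the RasterImage wrappers:

    import torch

    inputs = [{"image": torch.rand(3, 480, 640)} for _ in range(2)]  # (C, H, W)
    image_sizes = torch.tensor(
        [[inp["image"].shape[2], inp["image"].shape[1]] for inp in inputs],
        dtype=torch.int32,
    )
    print(image_sizes)  # tensor([[640, 480], [640, 480]], dtype=torch.int32)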
rslearn/models/olmoearth_pretrain/model.py
CHANGED

@@ -95,7 +95,9 @@ class OlmoEarth(FeatureExtractor):
         """
         if use_legacy_timestamps:
             warnings.warn(
-                "For new projects, don't use legacy timesteps."
+                "For new projects, don't use legacy timesteps. "
+                "Support will be removed after 2026-04-01.",
+                FutureWarning,
             )

         if (
rslearn/models/pooling_decoder.py
CHANGED

@@ -124,6 +124,6 @@ class SegmentationPoolingDecoder(PoolingDecoder):
         """
         output_probs = super().forward(intermediates, context)
         # BC -> BCHW
-        h, w = context.inputs[0][self.image_key].shape[1:3]
+        h, w = context.inputs[0][self.image_key].image.shape[1:3]
         feat_map = output_probs.feature_vector[:, :, None, None].repeat([1, 1, h, w])
         return FeatureMaps([feat_map])
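The BC -> BCHW broadcast in isolation: insert singleton spatial axes, then repeat across them.

    import torch

    feature_vector = torch.rand(2, 5)  # (B, C): 2 examples, 5 classes
    h, w = 8, 8
    feat_map = feature_vector[:, :, None, None].repeat([1, 1, h, w])
    print(feat_map.shape)  # torch.Size([2, 5, 8, 8])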
rslearn/models/simple_time_series.py
CHANGED
@@ -1,5 +1,6 @@
 """SimpleTimeSeries encoder."""

+import warnings
 from typing import Any

 import torch
@@ -25,13 +26,14 @@ class SimpleTimeSeries(FeatureExtractor):
     def __init__(
         self,
         encoder: FeatureExtractor,
-
+        num_timesteps_per_forward_pass: int = 1,
         op: str = "max",
         groups: list[list[int]] | None = None,
         num_layers: int | None = None,
         image_key: str = "image",
         backbone_channels: list[tuple[int, int]] | None = None,
-        image_keys: dict[str, int] | None = None,
+        image_keys: list[str] | dict[str, int] | None = None,
+        image_channels: int | None = None,
     ) -> None:
         """Create a new SimpleTimeSeries.

@@ -39,9 +41,11 @@ class SimpleTimeSeries(FeatureExtractor):
             encoder: the underlying FeatureExtractor. It must provide get_backbone_channels
                 function that returns the output channels, or backbone_channels must be set.
                 It must output a FeatureMaps.
-
-
-
+            num_timesteps_per_forward_pass: how many timesteps to pass to the encoder
+                in each forward pass. Defaults to 1 (one timestep per forward pass).
+                Set to a higher value to batch multiple timesteps together, e.g. for
+                pre/post change detection where you want 4 pre and 4 post images
+                processed together.
             op: one of max, mean, convrnn, conv3d, or conv1d
             groups: sets of images for which to combine features. Within each set,
                 features are combined using the specified operation; then, across sets,
@@ -51,28 +55,53 @@ class SimpleTimeSeries(FeatureExtractor):
                 combined before features and the combined after features. groups is a
                 list of sets, and each set is a list of image indices.
             num_layers: the number of layers for convrnn, conv3d, and conv1d ops.
-            image_key: the key to access the images.
+            image_key: the key to access the images (used when image_keys is not set).
             backbone_channels: manually specify the backbone channels. Can be set if
                 the encoder does not provide get_backbone_channels function.
-            image_keys:
-
-
-
+            image_keys: list of keys in input dict to process as multimodal inputs.
+                All keys use the same num_timesteps_per_forward_pass. If not set,
+                only the single image_key is used. Passing a dict[str, int] is
+                deprecated and will be removed on 2026-04-01.
+            image_channels: Deprecated, use num_timesteps_per_forward_pass instead.
+                Will be removed on 2026-04-01.
         """
-
-
-
-
-                "
+        # Handle deprecated image_channels parameter.
+        if image_channels is not None:
+            warnings.warn(
+                "image_channels is deprecated and will be removed on 2026-04-01. "
+                "Use num_timesteps_per_forward_pass instead. The new parameter directly "
+                "specifies the number of timesteps per forward pass rather than requiring "
+                "image_channels // actual_channels.",
+                FutureWarning,
+                stacklevel=2,
             )

+        # Handle deprecated dict form of image_keys.
+        deprecated_image_keys_dict: dict[str, int] | None = None
+        if isinstance(image_keys, dict):
+            warnings.warn(
+                "Passing image_keys as a dict is deprecated and will be removed on "
+                "2026-04-01. Use image_keys as a list[str] and set "
+                "num_timesteps_per_forward_pass instead.",
+                FutureWarning,
+                stacklevel=2,
+            )
+            deprecated_image_keys_dict = image_keys
+            image_keys = None  # Will use deprecated path in forward.
+
         super().__init__()
         self.encoder = encoder
-        self.
+        self.num_timesteps_per_forward_pass = num_timesteps_per_forward_pass
+        # Store deprecated parameters for runtime conversion.
+        self._deprecated_image_channels = image_channels
+        self._deprecated_image_keys_dict = deprecated_image_keys_dict
         self.op = op
         self.groups = groups
-
-
+        # Normalize image_key to image_keys list form.
+        if image_keys is not None:
+            self.image_keys = image_keys
+        else:
+            self.image_keys = [image_key]

         if backbone_channels is not None:
             out_channels = backbone_channels
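The deprecation message spells out the conversion: the old image_channels value was divided by the actual per-timestep channel count at runtime. The migration arithmetic, using hypothetical numbers:

    # Old API: image_channels=12 with 3-band imagery meant 12 // 3 = 4 timesteps
    # were batched into each encoder forward pass.
    image_channels = 12   # deprecated parameter value
    actual_channels = 3   # bands per timestep (C in CTHW)
    print(image_channels // actual_channels)  # 4 -- now passed directly as
                                              # num_timesteps_per_forward_pass=4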
@@ -163,24 +192,25 @@ class SimpleTimeSeries(FeatureExtractor):
         return out_channels

     def _get_batched_images(
-        self, input_dicts: list[dict[str, Any]], image_key: str,
+        self, input_dicts: list[dict[str, Any]], image_key: str, num_timesteps: int
     ) -> list[RasterImage]:
         """Collect and reshape images across input dicts.

         The BTCHW image time series are reshaped to (B*T)CHW so they can be passed to
         the forward pass of a per-image (unitemporal) model.
+
+        Args:
+            input_dicts: list of input dictionaries containing RasterImage objects.
+            image_key: the key to access the RasterImage in each input dict.
+            num_timesteps: how many timesteps to batch together per forward pass.
         """
         images = torch.stack(
             [input_dict[image_key].image for input_dict in input_dicts], dim=0
         )  # B, C, T, H, W
         timestamps = [input_dict[image_key].timestamps for input_dict in input_dicts]
-        #
-        #
-        #
-        # want to pass 2 timesteps to the model.
-        # TODO is probably to make this behaviour clearer but lets leave it like
-        # this for now to not break things.
-        num_timesteps = image_channels // images.shape[1]
+        # num_timesteps specifies how many timesteps to batch together per forward pass.
+        # For example, if the input has 8 timesteps and num_timesteps=4, we do 2
+        # forward passes, each with 4 timesteps batched together.
         batched_timesteps = images.shape[2] // num_timesteps
         images = rearrange(
             images,
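The reshape that _get_batched_images performs can be demonstrated with einops directly (shapes are illustrative, and the exact rearrange pattern in rslearn may differ):

    import torch
    from einops import rearrange

    images = torch.rand(2, 3, 8, 16, 16)  # (B, C, T, H, W): 8 timesteps
    num_timesteps = 4  # num_timesteps_per_forward_pass

    # 8 timesteps / 4 per pass = 2 encoder passes per example; each pass sees
    # a (C * num_timesteps)-channel image.
    batched = rearrange(images, "b c (p n) h w -> (b p) (c n) h w", n=num_timesteps)
    print(batched.shape)  # torch.Size([4, 12, 16, 16])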
@@ -222,10 +252,22 @@ class SimpleTimeSeries(FeatureExtractor):
         n_batch = len(context.inputs)
         n_images: int | None = None

-        if self.
-
+        if self._deprecated_image_keys_dict is not None:
+            # Deprecated dict form: each key has its own channels_per_timestep.
+            # The channels_per_timestep could be used to group multiple timesteps
+            # together, so we need to divide by the actual image channel count to get
+            # the number of timesteps to be grouped.
+            for (
+                image_key,
+                channels_per_timestep,
+            ) in self._deprecated_image_keys_dict.items():
+                # For deprecated image_keys dict, the value is channels per timestep,
+                # so we need to compute num_timesteps from the actual image channels.
+                sample_image = context.inputs[0][image_key].image
+                actual_channels = sample_image.shape[0]  # C in CTHW
+                num_timesteps = channels_per_timestep // actual_channels
                 batched_images = self._get_batched_images(
-                    context.inputs, image_key,
+                    context.inputs, image_key, num_timesteps
                 )

                 if batched_inputs is None:
@@ -240,12 +282,32 @@ class SimpleTimeSeries(FeatureExtractor):
                 batched_inputs[i][image_key] = image

         else:
-
-
-
-
-
-
+            # Determine num_timesteps - either from deprecated image_channels or
+            # directly from num_timesteps_per_forward_pass.
+            if self._deprecated_image_channels is not None:
+                # Backwards compatibility: compute num_timesteps from image_channels
+                # (which should be a multiple of the actual per-timestep channels).
+                sample_image = context.inputs[0][self.image_keys[0]].image
+                actual_channels = sample_image.shape[0]  # C in CTHW
+                num_timesteps = self._deprecated_image_channels // actual_channels
+            else:
+                num_timesteps = self.num_timesteps_per_forward_pass
+
+            for image_key in self.image_keys:
+                batched_images = self._get_batched_images(
+                    context.inputs, image_key, num_timesteps
+                )
+
+                if batched_inputs is None:
+                    batched_inputs = [{} for _ in batched_images]
+                    n_images = len(batched_images) // n_batch
+                elif n_images != len(batched_images) // n_batch:
+                    raise ValueError(
+                        "expected all modalities to have the same number of timesteps"
+                    )
+
+                for i, image in enumerate(batched_images):
+                    batched_inputs[i][image_key] = image

         assert n_images is not None
         # Now we can apply the underlying FeatureExtractor.