rslearn 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,6 @@
 """Default TileStore implementation."""
 
+import json
 import math
 import shutil
 from typing import Any
@@ -35,6 +36,9 @@ from .tile_store import TileStore
 # Special filename to indicate writing is done.
 COMPLETED_FNAME = "completed"
 
+# Special filename to store the bands that are present in a raster.
+BANDS_FNAME = "bands.json"
+
 
 class DefaultTileStore(TileStore):
     """Default TileStore implementation.
@@ -84,7 +88,7 @@ class DefaultTileStore(TileStore):
         self.path = join_upath(ds_path, self.path_suffix)
 
     def _get_raster_dir(
-        self, layer_name: str, item_name: str, bands: list[str]
+        self, layer_name: str, item_name: str, bands: list[str], write: bool = False
     ) -> UPath:
         """Get the directory where the specified raster is stored.
 
@@ -92,12 +96,21 @@ class DefaultTileStore(TileStore):
             layer_name: the name of the dataset layer.
             item_name: the name of the item from the data source.
             bands: list of band names that are expected to be stored together.
+            write: whether to create the directory and write the bands to a file inside
+                the directory.
 
         Returns:
             the UPath directory where the raster should be stored.
         """
         assert self.path is not None
-        return self.path / layer_name / item_name / get_bandset_dirname(bands)
+        dir_name = self.path / layer_name / item_name / get_bandset_dirname(bands)
+
+        if write:
+            dir_name.mkdir(parents=True, exist_ok=True)
+            with (dir_name / BANDS_FNAME).open("w") as f:
+                json.dump(bands, f)
+
+        return dir_name
 
     def _get_raster_fname(
         self, layer_name: str, item_name: str, bands: list[str]
@@ -117,10 +130,12 @@ class DefaultTileStore(TileStore):
         """
         raster_dir = self._get_raster_dir(layer_name, item_name, bands)
         for fname in raster_dir.iterdir():
-            # Ignore completed sentinel files as well as temporary files created by
+            # Ignore completed sentinel files, bands files, as well as temporary files created by
             # open_atomic (in case this tile store is on local filesystem).
             if fname.name == COMPLETED_FNAME:
                 continue
+            if fname.name == BANDS_FNAME:
+                continue
             if ".tmp." in fname.name:
                 continue
             return fname
@@ -161,8 +176,20 @@ class DefaultTileStore(TileStore):
 
         bands: list[list[str]] = []
         for raster_dir in item_dir.iterdir():
-            parts = raster_dir.name.split("_")
-            bands.append(parts)
+            if not (raster_dir / BANDS_FNAME).exists():
+                # This is likely a legacy directory where the bands are only encoded in
+                # the directory name, so we have to rely on that.
+                parts = raster_dir.name.split("_")
+                bands.append(parts)
+                continue
+
+            # We use the BANDS_FNAME here -- although it is slower to read the file, it
+            # is more reliable since sometimes the directory name is a hash of the
+            # bands in case there are too many bands (filename too long) or some bands
+            # contain the underscore character.
+            with (raster_dir / BANDS_FNAME).open() as f:
+                bands.append(json.load(f))
+
         return bands
 
     def get_raster_bounds(
@@ -248,7 +275,7 @@ class DefaultTileStore(TileStore):
             bounds: the bounds of the array.
             array: the raster data.
         """
-        raster_dir = self._get_raster_dir(layer_name, item_name, bands)
+        raster_dir = self._get_raster_dir(layer_name, item_name, bands, write=True)
         raster_format = GeotiffRasterFormat(geotiff_options=self.geotiff_options)
         raster_format.encode_raster(raster_dir, projection, bounds, array)
         (raster_dir / COMPLETED_FNAME).touch()
@@ -264,7 +291,7 @@ class DefaultTileStore(TileStore):
             bands: the list of bands in the array.
             fname: the raster file, which must be readable by rasterio.
         """
-        raster_dir = self._get_raster_dir(layer_name, item_name, bands)
+        raster_dir = self._get_raster_dir(layer_name, item_name, bands, write=True)
         raster_dir.mkdir(parents=True, exist_ok=True)
 
         if self.convert_rasters_to_cogs:
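
To illustrate the new bands.json sidecar end to end, here is a minimal standalone sketch of the same write/read logic using plain pathlib and json (it is not the rslearn API; the directory layout and path below are made up for illustration):

    import json
    from pathlib import Path

    def read_bands(raster_dir: Path) -> list[str]:
        # Prefer the bands.json sidecar that newer writes create.
        bands_file = raster_dir / "bands.json"
        if bands_file.exists():
            return json.loads(bands_file.read_text())
        # Legacy directories only encode the bands in the directory name, which
        # breaks when a band name contains "_" or the name was hashed.
        return raster_dir.name.split("_")

    # Simulate a directory produced by a write with write=True.
    raster_dir = Path("/tmp/tile_store/layer/item/B02_B03_B04")
    raster_dir.mkdir(parents=True, exist_ok=True)
    (raster_dir / "bands.json").write_text(json.dumps(["B02", "B03", "B04"]))
    print(read_bands(raster_dir))  # ['B02', 'B03', 'B04']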
@@ -27,14 +27,18 @@ class Normalize(Transform):
 
         Args:
             mean: a single value or one mean per channel
-            std: a single value or one std per channel
+            std: a single value or one std per channel (must match the shape of mean)
             valid_range: optionally clip to a minimum and maximum value
             selectors: image items to transform
-            bands: optionally restrict the normalization to these bands
+            bands: optionally restrict the normalization to these band indices. If set,
+                mean and std must either be one value, or have length equal to the
+                number of band indices passed here.
             num_bands: the number of bands per image, to distinguish different images
                 in a time series. If set, then the bands list is repeated for each
                 image, e.g. if bands=[2] then we apply normalization on images[2],
-                images[2+num_bands], images[2+num_bands*2], etc.
+                images[2+num_bands], images[2+num_bands*2], etc. Or if the bands list
+                is not set, then we apply the mean and std on each image in the time
+                series.
         """
         super().__init__()
         self.mean = torch.tensor(mean)
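
Based only on the arguments documented above, a Normalize configured with a per-band mean/std restricted to three band indices of a four-band time series would look roughly like the following (the values and band layout are hypothetical, and the constructor is assumed to accept the documented arguments as keywords):

    # Hypothetical per-band normalization repeated across every timestep.
    transform = Normalize(
        mean=[1000.0, 900.0, 800.0],   # one value per entry in bands
        std=[500.0, 500.0, 400.0],
        bands=[0, 1, 2],               # band indices within each timestep
        num_bands=4,                   # channels per timestep in the stacked image
    )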
@@ -57,6 +61,23 @@ class Normalize(Transform):
         Args:
             image: the image to transform.
         """
+
+        def _repeat_mean_and_std(
+            image_channels: int, num_bands: int | None
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            """Get mean and std tensor that are suitable for applying on the image."""
+            # We only need to repeat the tensor if both of these are true:
+            # - The mean/std are not just one scalar.
+            # - self.num_bands is set, otherwise we treat the input as a single image.
+            if len(self.mean.shape) == 0:
+                return self.mean, self.std
+            if num_bands is None:
+                return self.mean, self.std
+            num_images = image_channels // num_bands
+            return self.mean.repeat(num_images)[:, None, None], self.std.repeat(
+                num_images
+            )[:, None, None]
+
         if self.bands is not None:
             # User has provided band indices to normalize.
             # If num_bands is set, then we repeat these for each image in the input
@@ -72,13 +93,21 @@ class Normalize(Transform):
                 dim=0,
             )
 
-            image[band_indices] = (image[band_indices] - self.mean) / self.std
+            # We use len(self.bands) here because that is how many bands per timestep
+            # we are actually processing with the mean/std.
+            mean, std = _repeat_mean_and_std(
+                image_channels=len(band_indices), num_bands=len(self.bands)
+            )
+            image[band_indices] = (image[band_indices] - mean) / std
             if self.valid_min is not None:
                 image[band_indices] = torch.clamp(
                     image[band_indices], min=self.valid_min, max=self.valid_max
                 )
         else:
-            image = (image - self.mean) / self.std
+            mean, std = _repeat_mean_and_std(
+                image_channels=image.shape[0], num_bands=self.num_bands
+            )
+            image = (image - mean) / std
             if self.valid_min is not None:
                 image = torch.clamp(image, min=self.valid_min, max=self.valid_max)
         return image
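
The repeat logic in _repeat_mean_and_std amounts to tiling a per-band mean/std across the timesteps stacked on the channel axis. A standalone torch sketch with made-up shapes and values:

    import torch

    num_bands = 3                      # bands per timestep
    image = torch.rand(6, 32, 32)      # 2 timesteps x 3 bands stacked on dim 0
    mean = torch.tensor([0.2, 0.3, 0.4])
    std = torch.tensor([0.1, 0.1, 0.2])

    num_images = image.shape[0] // num_bands
    mean_r = mean.repeat(num_images)[:, None, None]  # shape (6, 1, 1)
    std_r = std.repeat(num_images)[:, None, None]
    normalized = (image - mean_r) / std_r
    print(normalized.shape)  # torch.Size([6, 32, 32])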
@@ -0,0 +1,67 @@
+"""The SelectBands transform."""
+
+from typing import Any
+
+from .transform import Transform, read_selector, write_selector
+
+
+class SelectBands(Transform):
+    """Select a subset of bands from an image."""
+
+    def __init__(
+        self,
+        band_indices: list[int],
+        input_selector: str = "image",
+        output_selector: str = "image",
+        num_bands_per_timestep: int | None = None,
+    ):
+        """Initialize a new Concatenate.
+
+        Args:
+            band_indices: the bands to select.
+            input_selector: the selector to read the input image.
+            output_selector: the output selector under which to save the output image.
+            num_bands_per_timestep: the number of bands per image, to distinguish
+                between stacked images in an image time series. If set, then the
+                band_indices are selected for each image in the time series.
+        """
+        super().__init__()
+        self.input_selector = input_selector
+        self.output_selector = output_selector
+        self.band_indices = band_indices
+        self.num_bands_per_timestep = num_bands_per_timestep
+
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Apply concatenation over the inputs and targets.
+
+        Args:
+            input_dict: the input
+            target_dict: the target
+
+        Returns:
+            normalized (input_dicts, target_dicts) tuple
+        """
+        image = read_selector(input_dict, target_dict, self.input_selector)
+        num_bands_per_timestep = (
+            self.num_bands_per_timestep
+            if self.num_bands_per_timestep is not None
+            else image.shape[0]
+        )
+
+        if image.shape[0] % num_bands_per_timestep != 0:
+            raise ValueError(
+                f"channel dimension {image.shape[0]} is not multiple of bands per timestep {num_bands_per_timestep}"
+            )
+
+        # Copy the band indices for each timestep in the input.
+        wanted_bands: list[int] = []
+        for start_channel_idx in range(0, image.shape[0], num_bands_per_timestep):
+            wanted_bands.extend(
+                [(start_channel_idx + band_idx) for band_idx in self.band_indices]
+            )
+
+        result = image[wanted_bands]
+        write_selector(input_dict, target_dict, self.output_selector, result)
+        return input_dict, target_dict
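
The index expansion in SelectBands.forward can be checked with plain Python (the band counts below are hypothetical; the transform itself operates on the tensor read via input_selector):

    band_indices = [2, 1, 0]        # e.g. reorder a BGR-ordered stack to RGB
    num_bands_per_timestep = 4      # channels per timestep
    num_channels = 8                # two timesteps stacked on the channel axis

    wanted_bands = []
    for start in range(0, num_channels, num_bands_per_timestep):
        wanted_bands.extend(start + idx for idx in band_indices)
    print(wanted_bands)  # [2, 1, 0, 6, 5, 4]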
@@ -0,0 +1,60 @@
+"""Transforms related to Sentinel-1 data."""
+
+from typing import Any
+
+import torch
+
+from .transform import Transform
+
+
+class Sentinel1ToDecibels(Transform):
+    """Convert Sentinel-1 data from raw intensity to or from decibels."""
+
+    def __init__(
+        self,
+        selectors: list[str] = ["image"],
+        from_decibels: bool = False,
+        epsilon: float = 1e-6,
+    ):
+        """Initialize a new Sentinel1ToDecibels.
+
+        Args:
+            selectors: the input selectors to apply the transform on.
+            from_decibels: convert from decibels to intensities instead of intensity to
+                decibels.
+            epsilon: when converting to decibels, clip the intensities to this minimum
+                value to avoid log issues. This is mostly to avoid pixels that have no
+                data with no data value being 0.
+        """
+        super().__init__()
+        self.selectors = selectors
+        self.from_decibels = from_decibels
+        self.epsilon = epsilon
+
+    def apply_image(self, image: torch.Tensor) -> torch.Tensor:
+        """Normalize the specified image.
+
+        Args:
+            image: the image to transform.
+        """
+        if self.from_decibels:
+            # Decibels to linear scale.
+            return torch.pow(10.0, image / 10.0)
+        else:
+            # Linear scale to decibels.
+            return 10 * torch.log10(torch.clamp(image, min=self.epsilon))
+
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Apply normalization over the inputs and targets.
+
+        Args:
+            input_dict: the input
+            target_dict: the target
+
+        Returns:
+            normalized (input_dicts, target_dicts) tuple
+        """
+        self.apply_fn(self.apply_image, input_dict, target_dict, self.selectors)
+        return input_dict, target_dict
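
The conversion is the standard 10 * log10 mapping, with epsilon guarding zero-valued (nodata) pixels. A quick check of both directions with arbitrary values:

    import torch

    intensity = torch.tensor([0.0, 0.01, 1.0])
    db = 10 * torch.log10(torch.clamp(intensity, min=1e-6))
    print(db)    # approximately tensor([-60., -20., 0.]); the 0.0 pixel is clipped to epsilon

    back = torch.pow(10.0, db / 10.0)
    print(back)  # approximately tensor([1e-06, 1e-02, 1.0])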
@@ -54,7 +54,7 @@ def read_selector(
         the item specified by the selector
     """
     d, selector = get_dict_and_subselector(input_dict, target_dict, selector)
-    parts = selector.split("/")
+    parts = selector.split("/") if selector else []
     cur = d
     for part in parts:
         cur = cur[part]
@@ -76,11 +76,28 @@ def write_selector(
         v: the value to write
     """
     d, selector = get_dict_and_subselector(input_dict, target_dict, selector)
-    parts = selector.split("/")
-    cur = d
-    for part in parts[:-1]:
-        cur = cur[part]
-    cur[parts[-1]] = v
+    if selector:
+        parts = selector.split("/")
+        cur = d
+        for part in parts[:-1]:
+            cur = cur[part]
+        cur[parts[-1]] = v
+    else:
+        # If the selector references the input or target dictionary directly, then we
+        # have a special case where instead of overwriting with v, we replace the keys
+        # with those in v. v must be a dictionary here, not a tensor, since otherwise
+        # it wouldn't match the type of the input or target dictionary.
+        if not isinstance(v, dict):
+            raise ValueError(
+                "when directly specifying the input or target dict, expected the value to be a dict"
+            )
+        if d == v:
+            # This may happen if the writer did not make a copy of the dictionary. In
+            # this case the code below would not update d correctly since it would also
+            # clear v.
+            return
+        d.clear()
+        d.update(v)
 
 
 class Transform(torch.nn.Module):
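
The new empty-selector branch means that writing to the root of the input or target dict replaces its keys in place rather than rebinding the variable. A standalone sketch of the same dictionary logic, using plain dicts instead of the selector routing in get_dict_and_subselector:

    def write(d: dict, selector: str, v):
        if selector:
            parts = selector.split("/")
            cur = d
            for part in parts[:-1]:
                cur = cur[part]
            cur[parts[-1]] = v
        else:
            # Root write: v must itself be a dict and d is updated in place.
            assert isinstance(v, dict)
            if d == v:
                return  # same (or equal) dict passed in; clearing d would also clear v
            d.clear()
            d.update(v)

    input_dict = {"image": "old", "meta": {"id": 1}}
    write(input_dict, "meta/id", 2)
    write(input_dict, "", {"image": "new"})
    print(input_dict)  # {'image': 'new'}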
@@ -2,13 +2,13 @@
 
 import hashlib
 import json
-from typing import Any, BinaryIO
+from collections.abc import Callable
+from typing import Any, BinaryIO, TypeVar
 
 import affine
 import numpy as np
 import numpy.typing as npt
 import rasterio
-from class_registry import ClassRegistry
 from PIL import Image
 from rasterio.crs import CRS
 from rasterio.enums import Resampling
@@ -21,18 +21,44 @@ from rslearn.utils.fsspec import open_rasterio_upath_reader, open_rasterio_upath
 
 from .geometry import PixelBounds, Projection
 
-RasterFormats = ClassRegistry()
+_RasterFormatT = TypeVar("_RasterFormatT", bound="RasterFormat")
+
+
+class _RasterFormatRegistry(dict[str, type["RasterFormat"]]):
+    """Registry for RasterFormat classes."""
+
+    def register(
+        self, name: str
+    ) -> Callable[[type[_RasterFormatT]], type[_RasterFormatT]]:
+        """Decorator to register a raster format class."""
+
+        def decorator(cls: type[_RasterFormatT]) -> type[_RasterFormatT]:
+            self[name] = cls
+            return cls
+
+        return decorator
+
+
+RasterFormats = _RasterFormatRegistry()
+
+
 logger = get_logger(__name__)
 
 
 def get_bandset_dirname(bands: list[str]) -> str:
     """Get the directory name that should be used to store the given group of bands."""
+    # We try to use a human-readable name with underscore as the delimiter, but if that
+    # isn't straightforward then we use hash instead.
     if any(["_" in band for band in bands]):
-        raise ValueError("band names must not contain '_'")
+        # In this case we hash the JSON representation of the bands.
+        return hashlib.sha256(json.dumps(bands).encode()).hexdigest()
     dirname = "_".join(bands)
     if len(dirname) > 64:
         # Previously we simply joined the bands, but this can result in directory name
         # that is too long. In this case, now we use hash instead.
+        # We use a different code path here where we hash the initial directory name
+        # instead of the JSON, for historical reasons (to maintain backwards
+        # compatibility).
         dirname = hashlib.sha256(dirname.encode()).hexdigest()
     return dirname
 
@@ -141,6 +167,19 @@ class RasterFormat:
         """
         raise NotImplementedError
 
+    @staticmethod
+    def from_config(name: str, config: dict[str, Any]) -> "RasterFormat":
+        """Create a RasterFormat from a config dict.
+
+        Args:
+            name: the name of this format
+            config: the config dict
+
+        Returns:
+            the RasterFormat instance
+        """
+        raise NotImplementedError
+
 
 @RasterFormats.register("image_tile")
 class ImageTileRasterFormat(RasterFormat):
@@ -710,5 +749,5 @@ def load_raster_format(config: RasterFormatConfig) -> RasterFormat:
     Returns:
         the loaded RasterFormat implementation
     """
-    cls = RasterFormats.get_class(config.name)
+    cls = RasterFormats[config.name]
     return cls.from_config(config.name, config.config_dict)
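
With ClassRegistry gone, the registry is a small dict subclass: registration still uses a decorator, and lookup is plain indexing. A schematic example of how a custom format would plug in (the import path is assumed, and only from_config is overridden since that is the signature shown above):

    from typing import Any

    from rslearn.utils.raster_format import RasterFormat, RasterFormats  # assumed path


    @RasterFormats.register("my_format")
    class MyRasterFormat(RasterFormat):
        """Minimal stand-in used only to show the registration pattern."""

        @staticmethod
        def from_config(name: str, config: dict[str, Any]) -> "MyRasterFormat":
            return MyRasterFormat()


    # load_raster_format now resolves classes by indexing instead of get_class():
    cls = RasterFormats["my_format"]
    print(cls is MyRasterFormat)  # True

The vector format registry change below mirrors this exactly with VectorFormats and VectorFormat.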
@@ -1,11 +1,11 @@
 """Classes for writing vector data to a UPath."""
 
 import json
+from collections.abc import Callable
 from enum import Enum
-from typing import Any
+from typing import Any, TypeVar
 
 import shapely
-from class_registry import ClassRegistry
 from rasterio.crs import CRS
 from upath import UPath
 
@@ -18,7 +18,25 @@ from .feature import Feature
 from .geometry import PixelBounds, Projection, STGeometry, safely_reproject_and_clip
 
 logger = get_logger(__name__)
-VectorFormats = ClassRegistry()
+_VectorFormatT = TypeVar("_VectorFormatT", bound="VectorFormat")
+
+
+class _VectorFormatRegistry(dict[str, type["VectorFormat"]]):
+    """Registry for VectorFormat classes."""
+
+    def register(
+        self, name: str
+    ) -> Callable[[type[_VectorFormatT]], type[_VectorFormatT]]:
+        """Decorator to register a vector format class."""
+
+        def decorator(cls: type[_VectorFormatT]) -> type[_VectorFormatT]:
+            self[name] = cls
+            return cls
+
+        return decorator
+
+
+VectorFormats = _VectorFormatRegistry()
 
 
 class VectorFormat:
@@ -53,6 +71,19 @@ class VectorFormat:
         """
         raise NotImplementedError
 
+    @staticmethod
+    def from_config(name: str, config: dict[str, Any]) -> "VectorFormat":
+        """Create a VectorFormat from a config dict.
+
+        Args:
+            name: the name of this format
+            config: the config dict
+
+        Returns:
+            the VectorFormat instance
+        """
+        raise NotImplementedError
+
 
 @VectorFormats.register("tile")
 class TileVectorFormat(VectorFormat):
@@ -410,5 +441,5 @@ def load_vector_format(config: VectorFormatConfig) -> VectorFormat:
     Returns:
         the loaded VectorFormat implementation
     """
-    cls = VectorFormats.get_class(config.name)
+    cls = VectorFormats[config.name]
     return cls.from_config(config.name, config.config_dict)
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: rslearn
-Version: 0.0.7
+Version: 0.0.9
 Summary: A library for developing remote sensing datasets and models
 Author: OlmoEarth Team
 License: Apache License
@@ -212,7 +212,6 @@ Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: boto3>=1.39
-Requires-Dist: class_registry>=2.1
 Requires-Dist: fiona>=1.10
 Requires-Dist: fsspec>=2025.9.0
 Requires-Dist: jsonargparse>=4.35.0
@@ -233,7 +232,7 @@ Requires-Dist: cdsapi>=0.7.6; extra == "extra"
 Requires-Dist: earthdaily[platform]>=1.0.7; extra == "extra"
 Requires-Dist: earthengine-api>=1.6.3; extra == "extra"
 Requires-Dist: einops>=0.8; extra == "extra"
-Requires-Dist: gcsfs>=2025.9.0; extra == "extra"
+Requires-Dist: fsspec[gcs,s3]; extra == "extra"
 Requires-Dist: google-cloud-bigquery>=3.35; extra == "extra"
 Requires-Dist: google-cloud-storage>=2.18; extra == "extra"
 Requires-Dist: huggingface_hub>=0.34.4; extra == "extra"
@@ -244,7 +243,6 @@ Requires-Dist: planetary_computer>=1.0; extra == "extra"
 Requires-Dist: pycocotools>=2.0; extra == "extra"
 Requires-Dist: pystac_client>=0.9; extra == "extra"
 Requires-Dist: rtree>=1.4; extra == "extra"
-Requires-Dist: s3fs>=2025.9.0; extra == "extra"
 Requires-Dist: satlaspretrain_models>=0.3; extra == "extra"
 Requires-Dist: scipy>=1.16; extra == "extra"
 Requires-Dist: terratorch>=1.0.2; extra == "extra"
@@ -285,6 +283,7 @@ Quick links:
 - [Examples](docs/Examples.md) contains more examples, including customizing different
   stages of rslearn with additional code.
 - [DatasetConfig](docs/DatasetConfig.md) documents the dataset configuration file.
+- [ModelConfig](docs/ModelConfig.md) documents the model configuration file.
 
 
 Setup