rslearn 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/data_sources/local_files.py +20 -3
- rslearn/data_sources/planetary_computer.py +79 -14
- rslearn/dataset/handler_summaries.py +130 -0
- rslearn/dataset/manage.py +159 -24
- rslearn/dataset/materialize.py +21 -2
- rslearn/dataset/remap.py +29 -4
- rslearn/main.py +60 -8
- rslearn/models/clay/clay.py +29 -14
- rslearn/models/copernicusfm.py +37 -25
- rslearn/models/dinov3.py +166 -0
- rslearn/models/galileo/galileo.py +58 -12
- rslearn/models/galileo/single_file_galileo.py +7 -1
- rslearn/models/presto/presto.py +11 -0
- rslearn/models/prithvi.py +139 -52
- rslearn/models/registry.py +19 -2
- rslearn/models/resize_features.py +45 -0
- rslearn/models/simple_time_series.py +65 -10
- rslearn/models/upsample.py +2 -2
- rslearn/tile_stores/default.py +34 -7
- rslearn/train/transforms/normalize.py +34 -5
- rslearn/train/transforms/select_bands.py +67 -0
- rslearn/train/transforms/sentinel1.py +60 -0
- rslearn/train/transforms/transform.py +23 -6
- rslearn/utils/raster_format.py +44 -5
- rslearn/utils/vector_format.py +35 -4
- {rslearn-0.0.7.dist-info → rslearn-0.0.9.dist-info}/METADATA +3 -4
- {rslearn-0.0.7.dist-info → rslearn-0.0.9.dist-info}/RECORD +31 -26
- {rslearn-0.0.7.dist-info → rslearn-0.0.9.dist-info}/WHEEL +0 -0
- {rslearn-0.0.7.dist-info → rslearn-0.0.9.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.7.dist-info → rslearn-0.0.9.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.7.dist-info → rslearn-0.0.9.dist-info}/top_level.txt +0 -0
rslearn/tile_stores/default.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Default TileStore implementation."""
|
|
2
2
|
|
|
3
|
+
import json
|
|
3
4
|
import math
|
|
4
5
|
import shutil
|
|
5
6
|
from typing import Any
|
|
@@ -35,6 +36,9 @@ from .tile_store import TileStore
|
|
|
35
36
|
# Special filename to indicate writing is done.
|
|
36
37
|
COMPLETED_FNAME = "completed"
|
|
37
38
|
|
|
39
|
+
# Special filename to store the bands that are present in a raster.
|
|
40
|
+
BANDS_FNAME = "bands.json"
|
|
41
|
+
|
|
38
42
|
|
|
39
43
|
class DefaultTileStore(TileStore):
|
|
40
44
|
"""Default TileStore implementation.
|
|
@@ -84,7 +88,7 @@ class DefaultTileStore(TileStore):
|
|
|
84
88
|
self.path = join_upath(ds_path, self.path_suffix)
|
|
85
89
|
|
|
86
90
|
def _get_raster_dir(
|
|
87
|
-
self, layer_name: str, item_name: str, bands: list[str]
|
|
91
|
+
self, layer_name: str, item_name: str, bands: list[str], write: bool = False
|
|
88
92
|
) -> UPath:
|
|
89
93
|
"""Get the directory where the specified raster is stored.
|
|
90
94
|
|
|
@@ -92,12 +96,21 @@ class DefaultTileStore(TileStore):
|
|
|
92
96
|
layer_name: the name of the dataset layer.
|
|
93
97
|
item_name: the name of the item from the data source.
|
|
94
98
|
bands: list of band names that are expected to be stored together.
|
|
99
|
+
write: whether to create the directory and write the bands to a file inside
|
|
100
|
+
the directory.
|
|
95
101
|
|
|
96
102
|
Returns:
|
|
97
103
|
the UPath directory where the raster should be stored.
|
|
98
104
|
"""
|
|
99
105
|
assert self.path is not None
|
|
100
|
-
|
|
106
|
+
dir_name = self.path / layer_name / item_name / get_bandset_dirname(bands)
|
|
107
|
+
|
|
108
|
+
if write:
|
|
109
|
+
dir_name.mkdir(parents=True, exist_ok=True)
|
|
110
|
+
with (dir_name / BANDS_FNAME).open("w") as f:
|
|
111
|
+
json.dump(bands, f)
|
|
112
|
+
|
|
113
|
+
return dir_name
|
|
101
114
|
|
|
102
115
|
def _get_raster_fname(
|
|
103
116
|
self, layer_name: str, item_name: str, bands: list[str]
|
|
@@ -117,10 +130,12 @@ class DefaultTileStore(TileStore):
|
|
|
117
130
|
"""
|
|
118
131
|
raster_dir = self._get_raster_dir(layer_name, item_name, bands)
|
|
119
132
|
for fname in raster_dir.iterdir():
|
|
120
|
-
# Ignore completed sentinel files as well as temporary files created by
|
|
133
|
+
# Ignore completed sentinel files, bands files, as well as temporary files created by
|
|
121
134
|
# open_atomic (in case this tile store is on local filesystem).
|
|
122
135
|
if fname.name == COMPLETED_FNAME:
|
|
123
136
|
continue
|
|
137
|
+
if fname.name == BANDS_FNAME:
|
|
138
|
+
continue
|
|
124
139
|
if ".tmp." in fname.name:
|
|
125
140
|
continue
|
|
126
141
|
return fname
|
|
@@ -161,8 +176,20 @@ class DefaultTileStore(TileStore):
|
|
|
161
176
|
|
|
162
177
|
bands: list[list[str]] = []
|
|
163
178
|
for raster_dir in item_dir.iterdir():
|
|
164
|
-
|
|
165
|
-
|
|
179
|
+
if not (raster_dir / BANDS_FNAME).exists():
|
|
180
|
+
# This is likely a legacy directory where the bands are only encoded in
|
|
181
|
+
# the directory name, so we have to rely on that.
|
|
182
|
+
parts = raster_dir.name.split("_")
|
|
183
|
+
bands.append(parts)
|
|
184
|
+
continue
|
|
185
|
+
|
|
186
|
+
# We use the BANDS_FNAME here -- although it is slower to read the file, it
|
|
187
|
+
# is more reliable since sometimes the directory name is a hash of the
|
|
188
|
+
# bands in case there are too many bands (filename too long) or some bands
|
|
189
|
+
# contain the underscore character.
|
|
190
|
+
with (raster_dir / BANDS_FNAME).open() as f:
|
|
191
|
+
bands.append(json.load(f))
|
|
192
|
+
|
|
166
193
|
return bands
|
|
167
194
|
|
|
168
195
|
def get_raster_bounds(
|
|
@@ -248,7 +275,7 @@ class DefaultTileStore(TileStore):
|
|
|
248
275
|
bounds: the bounds of the array.
|
|
249
276
|
array: the raster data.
|
|
250
277
|
"""
|
|
251
|
-
raster_dir = self._get_raster_dir(layer_name, item_name, bands)
|
|
278
|
+
raster_dir = self._get_raster_dir(layer_name, item_name, bands, write=True)
|
|
252
279
|
raster_format = GeotiffRasterFormat(geotiff_options=self.geotiff_options)
|
|
253
280
|
raster_format.encode_raster(raster_dir, projection, bounds, array)
|
|
254
281
|
(raster_dir / COMPLETED_FNAME).touch()
|
|
@@ -264,7 +291,7 @@ class DefaultTileStore(TileStore):
|
|
|
264
291
|
bands: the list of bands in the array.
|
|
265
292
|
fname: the raster file, which must be readable by rasterio.
|
|
266
293
|
"""
|
|
267
|
-
raster_dir = self._get_raster_dir(layer_name, item_name, bands)
|
|
294
|
+
raster_dir = self._get_raster_dir(layer_name, item_name, bands, write=True)
|
|
268
295
|
raster_dir.mkdir(parents=True, exist_ok=True)
|
|
269
296
|
|
|
270
297
|
if self.convert_rasters_to_cogs:
|
|
@@ -27,14 +27,18 @@ class Normalize(Transform):
|
|
|
27
27
|
|
|
28
28
|
Args:
|
|
29
29
|
mean: a single value or one mean per channel
|
|
30
|
-
std: a single value or one std per channel
|
|
30
|
+
std: a single value or one std per channel (must match the shape of mean)
|
|
31
31
|
valid_range: optionally clip to a minimum and maximum value
|
|
32
32
|
selectors: image items to transform
|
|
33
|
-
bands: optionally restrict the normalization to these
|
|
33
|
+
bands: optionally restrict the normalization to these band indices. If set,
|
|
34
|
+
mean and std must either be one value, or have length equal to the
|
|
35
|
+
number of band indices passed here.
|
|
34
36
|
num_bands: the number of bands per image, to distinguish different images
|
|
35
37
|
in a time series. If set, then the bands list is repeated for each
|
|
36
38
|
image, e.g. if bands=[2] then we apply normalization on images[2],
|
|
37
|
-
images[2+num_bands], images[2+num_bands*2], etc.
|
|
39
|
+
images[2+num_bands], images[2+num_bands*2], etc. Or if the bands list
|
|
40
|
+
is not set, then we apply the mean and std on each image in the time
|
|
41
|
+
series.
|
|
38
42
|
"""
|
|
39
43
|
super().__init__()
|
|
40
44
|
self.mean = torch.tensor(mean)
|
|
@@ -57,6 +61,23 @@ class Normalize(Transform):
|
|
|
57
61
|
Args:
|
|
58
62
|
image: the image to transform.
|
|
59
63
|
"""
|
|
64
|
+
|
|
65
|
+
def _repeat_mean_and_std(
|
|
66
|
+
image_channels: int, num_bands: int | None
|
|
67
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
68
|
+
"""Get mean and std tensor that are suitable for applying on the image."""
|
|
69
|
+
# We only need to repeat the tensor if both of these are true:
|
|
70
|
+
# - The mean/std are not just one scalar.
|
|
71
|
+
# - self.num_bands is set, otherwise we treat the input as a single image.
|
|
72
|
+
if len(self.mean.shape) == 0:
|
|
73
|
+
return self.mean, self.std
|
|
74
|
+
if num_bands is None:
|
|
75
|
+
return self.mean, self.std
|
|
76
|
+
num_images = image_channels // num_bands
|
|
77
|
+
return self.mean.repeat(num_images)[:, None, None], self.std.repeat(
|
|
78
|
+
num_images
|
|
79
|
+
)[:, None, None]
|
|
80
|
+
|
|
60
81
|
if self.bands is not None:
|
|
61
82
|
# User has provided band indices to normalize.
|
|
62
83
|
# If num_bands is set, then we repeat these for each image in the input
|
|
@@ -72,13 +93,21 @@ class Normalize(Transform):
|
|
|
72
93
|
dim=0,
|
|
73
94
|
)
|
|
74
95
|
|
|
75
|
-
|
|
96
|
+
# We use len(self.bands) here because that is how many bands per timestep
|
|
97
|
+
# we are actually processing with the mean/std.
|
|
98
|
+
mean, std = _repeat_mean_and_std(
|
|
99
|
+
image_channels=len(band_indices), num_bands=len(self.bands)
|
|
100
|
+
)
|
|
101
|
+
image[band_indices] = (image[band_indices] - mean) / std
|
|
76
102
|
if self.valid_min is not None:
|
|
77
103
|
image[band_indices] = torch.clamp(
|
|
78
104
|
image[band_indices], min=self.valid_min, max=self.valid_max
|
|
79
105
|
)
|
|
80
106
|
else:
|
|
81
|
-
|
|
107
|
+
mean, std = _repeat_mean_and_std(
|
|
108
|
+
image_channels=image.shape[0], num_bands=self.num_bands
|
|
109
|
+
)
|
|
110
|
+
image = (image - mean) / std
|
|
82
111
|
if self.valid_min is not None:
|
|
83
112
|
image = torch.clamp(image, min=self.valid_min, max=self.valid_max)
|
|
84
113
|
return image
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""The SelectBands transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from .transform import Transform, read_selector, write_selector
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SelectBands(Transform):
|
|
9
|
+
"""Select a subset of bands from an image."""
|
|
10
|
+
|
|
11
|
+
def __init__(
|
|
12
|
+
self,
|
|
13
|
+
band_indices: list[int],
|
|
14
|
+
input_selector: str = "image",
|
|
15
|
+
output_selector: str = "image",
|
|
16
|
+
num_bands_per_timestep: int | None = None,
|
|
17
|
+
):
|
|
18
|
+
"""Initialize a new SelectBands.
|
|
19
|
+
|
|
20
|
+
Args:
|
|
21
|
+
band_indices: the bands to select.
|
|
22
|
+
input_selector: the selector to read the input image.
|
|
23
|
+
output_selector: the output selector under which to save the output image.
|
|
24
|
+
num_bands_per_timestep: the number of bands per image, to distinguish
|
|
25
|
+
between stacked images in an image time series. If set, then the
|
|
26
|
+
band_indices are selected for each image in the time series.
|
|
27
|
+
"""
|
|
28
|
+
super().__init__()
|
|
29
|
+
self.input_selector = input_selector
|
|
30
|
+
self.output_selector = output_selector
|
|
31
|
+
self.band_indices = band_indices
|
|
32
|
+
self.num_bands_per_timestep = num_bands_per_timestep
|
|
33
|
+
|
|
34
|
+
def forward(
|
|
35
|
+
self, input_dict: dict[str, Any], target_dict: dict[str, Any]
|
|
36
|
+
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
37
|
+
"""Select the configured bands from the input image.
|
|
38
|
+
|
|
39
|
+
Args:
|
|
40
|
+
input_dict: the input
|
|
41
|
+
target_dict: the target
|
|
42
|
+
|
|
43
|
+
Returns:
|
|
44
|
+
(input_dict, target_dict) tuple with the selected bands written to the output selector
|
|
45
|
+
"""
|
|
46
|
+
image = read_selector(input_dict, target_dict, self.input_selector)
|
|
47
|
+
num_bands_per_timestep = (
|
|
48
|
+
self.num_bands_per_timestep
|
|
49
|
+
if self.num_bands_per_timestep is not None
|
|
50
|
+
else image.shape[0]
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
if image.shape[0] % num_bands_per_timestep != 0:
|
|
54
|
+
raise ValueError(
|
|
55
|
+
f"channel dimension {image.shape[0]} is not multiple of bands per timestep {num_bands_per_timestep}"
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
# Copy the band indices for each timestep in the input.
|
|
59
|
+
wanted_bands: list[int] = []
|
|
60
|
+
for start_channel_idx in range(0, image.shape[0], num_bands_per_timestep):
|
|
61
|
+
wanted_bands.extend(
|
|
62
|
+
[(start_channel_idx + band_idx) for band_idx in self.band_indices]
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
result = image[wanted_bands]
|
|
66
|
+
write_selector(input_dict, target_dict, self.output_selector, result)
|
|
67
|
+
return input_dict, target_dict
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Transforms related to Sentinel-1 data."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from .transform import Transform
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Sentinel1ToDecibels(Transform):
|
|
11
|
+
"""Convert Sentinel-1 data between raw intensity and decibels."""
|
|
12
|
+
|
|
13
|
+
def __init__(
|
|
14
|
+
self,
|
|
15
|
+
selectors: list[str] = ["image"],
|
|
16
|
+
from_decibels: bool = False,
|
|
17
|
+
epsilon: float = 1e-6,
|
|
18
|
+
):
|
|
19
|
+
"""Initialize a new Sentinel1ToDecibels.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
selectors: the input selectors to apply the transform on.
|
|
23
|
+
from_decibels: convert from decibels to intensities instead of intensity to
|
|
24
|
+
decibels.
|
|
25
|
+
epsilon: when converting to decibels, clip the intensities to this minimum
|
|
26
|
+
value to avoid log issues. This is mostly to avoid pixels that have no
|
|
27
|
+
data with no data value being 0.
|
|
28
|
+
"""
|
|
29
|
+
super().__init__()
|
|
30
|
+
self.selectors = selectors
|
|
31
|
+
self.from_decibels = from_decibels
|
|
32
|
+
self.epsilon = epsilon
|
|
33
|
+
|
|
34
|
+
def apply_image(self, image: torch.Tensor) -> torch.Tensor:
|
|
35
|
+
"""Convert the specified image to or from decibels.
|
|
36
|
+
|
|
37
|
+
Args:
|
|
38
|
+
image: the image to transform.
|
|
39
|
+
"""
|
|
40
|
+
if self.from_decibels:
|
|
41
|
+
# Decibels to linear scale.
|
|
42
|
+
return torch.pow(10.0, image / 10.0)
|
|
43
|
+
else:
|
|
44
|
+
# Linear scale to decibels.
|
|
45
|
+
return 10 * torch.log10(torch.clamp(image, min=self.epsilon))
|
|
46
|
+
|
|
47
|
+
def forward(
|
|
48
|
+
self, input_dict: dict[str, Any], target_dict: dict[str, Any]
|
|
49
|
+
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
50
|
+
"""Apply the decibel conversion over the inputs and targets.
|
|
51
|
+
|
|
52
|
+
Args:
|
|
53
|
+
input_dict: the input
|
|
54
|
+
target_dict: the target
|
|
55
|
+
|
|
56
|
+
Returns:
|
|
57
|
+
converted (input_dict, target_dict) tuple
|
|
58
|
+
"""
|
|
59
|
+
self.apply_fn(self.apply_image, input_dict, target_dict, self.selectors)
|
|
60
|
+
return input_dict, target_dict
|
|
@@ -54,7 +54,7 @@ def read_selector(
|
|
|
54
54
|
the item specified by the selector
|
|
55
55
|
"""
|
|
56
56
|
d, selector = get_dict_and_subselector(input_dict, target_dict, selector)
|
|
57
|
-
parts = selector.split("/")
|
|
57
|
+
parts = selector.split("/") if selector else []
|
|
58
58
|
cur = d
|
|
59
59
|
for part in parts:
|
|
60
60
|
cur = cur[part]
|
|
@@ -76,11 +76,28 @@ def write_selector(
|
|
|
76
76
|
v: the value to write
|
|
77
77
|
"""
|
|
78
78
|
d, selector = get_dict_and_subselector(input_dict, target_dict, selector)
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
79
|
+
if selector:
|
|
80
|
+
parts = selector.split("/")
|
|
81
|
+
cur = d
|
|
82
|
+
for part in parts[:-1]:
|
|
83
|
+
cur = cur[part]
|
|
84
|
+
cur[parts[-1]] = v
|
|
85
|
+
else:
|
|
86
|
+
# If the selector references the input or target dictionary directly, then we
|
|
87
|
+
# have a special case where instead of overwriting with v, we replace the keys
|
|
88
|
+
# with those in v. v must be a dictionary here, not a tensor, since otherwise
|
|
89
|
+
# it wouldn't match the type of the input or target dictionary.
|
|
90
|
+
if not isinstance(v, dict):
|
|
91
|
+
raise ValueError(
|
|
92
|
+
"when directly specifying the input or target dict, expected the value to be a dict"
|
|
93
|
+
)
|
|
94
|
+
if d == v:
|
|
95
|
+
# This may happen if the writer did not make a copy of the dictionary. In
|
|
96
|
+
# this case the code below would not update d correctly since it would also
|
|
97
|
+
# clear v.
|
|
98
|
+
return
|
|
99
|
+
d.clear()
|
|
100
|
+
d.update(v)
|
|
84
101
|
|
|
85
102
|
|
|
86
103
|
class Transform(torch.nn.Module):
|
rslearn/utils/raster_format.py
CHANGED
|
@@ -2,13 +2,13 @@
|
|
|
2
2
|
|
|
3
3
|
import hashlib
|
|
4
4
|
import json
|
|
5
|
-
from
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from typing import Any, BinaryIO, TypeVar
|
|
6
7
|
|
|
7
8
|
import affine
|
|
8
9
|
import numpy as np
|
|
9
10
|
import numpy.typing as npt
|
|
10
11
|
import rasterio
|
|
11
|
-
from class_registry import ClassRegistry
|
|
12
12
|
from PIL import Image
|
|
13
13
|
from rasterio.crs import CRS
|
|
14
14
|
from rasterio.enums import Resampling
|
|
@@ -21,18 +21,44 @@ from rslearn.utils.fsspec import open_rasterio_upath_reader, open_rasterio_upath
|
|
|
21
21
|
|
|
22
22
|
from .geometry import PixelBounds, Projection
|
|
23
23
|
|
|
24
|
-
|
|
24
|
+
_RasterFormatT = TypeVar("_RasterFormatT", bound="RasterFormat")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class _RasterFormatRegistry(dict[str, type["RasterFormat"]]):
|
|
28
|
+
"""Registry for RasterFormat classes."""
|
|
29
|
+
|
|
30
|
+
def register(
|
|
31
|
+
self, name: str
|
|
32
|
+
) -> Callable[[type[_RasterFormatT]], type[_RasterFormatT]]:
|
|
33
|
+
"""Decorator to register a raster format class."""
|
|
34
|
+
|
|
35
|
+
def decorator(cls: type[_RasterFormatT]) -> type[_RasterFormatT]:
|
|
36
|
+
self[name] = cls
|
|
37
|
+
return cls
|
|
38
|
+
|
|
39
|
+
return decorator
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
RasterFormats = _RasterFormatRegistry()
|
|
43
|
+
|
|
44
|
+
|
|
25
45
|
logger = get_logger(__name__)
|
|
26
46
|
|
|
27
47
|
|
|
28
48
|
def get_bandset_dirname(bands: list[str]) -> str:
|
|
29
49
|
"""Get the directory name that should be used to store the given group of bands."""
|
|
50
|
+
# We try to use a human-readable name with underscore as the delimiter, but if that
|
|
51
|
+
# isn't straightforward then we use hash instead.
|
|
30
52
|
if any(["_" in band for band in bands]):
|
|
31
|
-
|
|
53
|
+
# In this case we hash the JSON representation of the bands.
|
|
54
|
+
return hashlib.sha256(json.dumps(bands).encode()).hexdigest()
|
|
32
55
|
dirname = "_".join(bands)
|
|
33
56
|
if len(dirname) > 64:
|
|
34
57
|
# Previously we simply joined the bands, but this can result in directory name
|
|
35
58
|
# that is too long. In this case, now we use hash instead.
|
|
59
|
+
# We use a different code path here where we hash the initial directory name
|
|
60
|
+
# instead of the JSON, for historical reasons (to maintain backwards
|
|
61
|
+
# compatibility).
|
|
36
62
|
dirname = hashlib.sha256(dirname.encode()).hexdigest()
|
|
37
63
|
return dirname
|
|
38
64
|
|
|
@@ -141,6 +167,19 @@ class RasterFormat:
|
|
|
141
167
|
"""
|
|
142
168
|
raise NotImplementedError
|
|
143
169
|
|
|
170
|
+
@staticmethod
|
|
171
|
+
def from_config(name: str, config: dict[str, Any]) -> "RasterFormat":
|
|
172
|
+
"""Create a RasterFormat from a config dict.
|
|
173
|
+
|
|
174
|
+
Args:
|
|
175
|
+
name: the name of this format
|
|
176
|
+
config: the config dict
|
|
177
|
+
|
|
178
|
+
Returns:
|
|
179
|
+
the RasterFormat instance
|
|
180
|
+
"""
|
|
181
|
+
raise NotImplementedError
|
|
182
|
+
|
|
144
183
|
|
|
145
184
|
@RasterFormats.register("image_tile")
|
|
146
185
|
class ImageTileRasterFormat(RasterFormat):
|
|
@@ -710,5 +749,5 @@ def load_raster_format(config: RasterFormatConfig) -> RasterFormat:
|
|
|
710
749
|
Returns:
|
|
711
750
|
the loaded RasterFormat implementation
|
|
712
751
|
"""
|
|
713
|
-
cls = RasterFormats
|
|
752
|
+
cls = RasterFormats[config.name]
|
|
714
753
|
return cls.from_config(config.name, config.config_dict)
|
rslearn/utils/vector_format.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
"""Classes for writing vector data to a UPath."""
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
+
from collections.abc import Callable
|
|
4
5
|
from enum import Enum
|
|
5
|
-
from typing import Any
|
|
6
|
+
from typing import Any, TypeVar
|
|
6
7
|
|
|
7
8
|
import shapely
|
|
8
|
-
from class_registry import ClassRegistry
|
|
9
9
|
from rasterio.crs import CRS
|
|
10
10
|
from upath import UPath
|
|
11
11
|
|
|
@@ -18,7 +18,25 @@ from .feature import Feature
|
|
|
18
18
|
from .geometry import PixelBounds, Projection, STGeometry, safely_reproject_and_clip
|
|
19
19
|
|
|
20
20
|
logger = get_logger(__name__)
|
|
21
|
-
|
|
21
|
+
_VectorFormatT = TypeVar("_VectorFormatT", bound="VectorFormat")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class _VectorFormatRegistry(dict[str, type["VectorFormat"]]):
|
|
25
|
+
"""Registry for VectorFormat classes."""
|
|
26
|
+
|
|
27
|
+
def register(
|
|
28
|
+
self, name: str
|
|
29
|
+
) -> Callable[[type[_VectorFormatT]], type[_VectorFormatT]]:
|
|
30
|
+
"""Decorator to register a vector format class."""
|
|
31
|
+
|
|
32
|
+
def decorator(cls: type[_VectorFormatT]) -> type[_VectorFormatT]:
|
|
33
|
+
self[name] = cls
|
|
34
|
+
return cls
|
|
35
|
+
|
|
36
|
+
return decorator
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
VectorFormats = _VectorFormatRegistry()
|
|
22
40
|
|
|
23
41
|
|
|
24
42
|
class VectorFormat:
|
|
@@ -53,6 +71,19 @@ class VectorFormat:
|
|
|
53
71
|
"""
|
|
54
72
|
raise NotImplementedError
|
|
55
73
|
|
|
74
|
+
@staticmethod
|
|
75
|
+
def from_config(name: str, config: dict[str, Any]) -> "VectorFormat":
|
|
76
|
+
"""Create a VectorFormat from a config dict.
|
|
77
|
+
|
|
78
|
+
Args:
|
|
79
|
+
name: the name of this format
|
|
80
|
+
config: the config dict
|
|
81
|
+
|
|
82
|
+
Returns:
|
|
83
|
+
the VectorFormat instance
|
|
84
|
+
"""
|
|
85
|
+
raise NotImplementedError
|
|
86
|
+
|
|
56
87
|
|
|
57
88
|
@VectorFormats.register("tile")
|
|
58
89
|
class TileVectorFormat(VectorFormat):
|
|
@@ -410,5 +441,5 @@ def load_vector_format(config: VectorFormatConfig) -> VectorFormat:
|
|
|
410
441
|
Returns:
|
|
411
442
|
the loaded VectorFormat implementation
|
|
412
443
|
"""
|
|
413
|
-
cls = VectorFormats
|
|
444
|
+
cls = VectorFormats[config.name]
|
|
414
445
|
return cls.from_config(config.name, config.config_dict)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: rslearn
|
|
3
|
-
Version: 0.0.7
|
|
3
|
+
Version: 0.0.9
|
|
4
4
|
Summary: A library for developing remote sensing datasets and models
|
|
5
5
|
Author: OlmoEarth Team
|
|
6
6
|
License: Apache License
|
|
@@ -212,7 +212,6 @@ Requires-Python: >=3.11
|
|
|
212
212
|
Description-Content-Type: text/markdown
|
|
213
213
|
License-File: LICENSE
|
|
214
214
|
Requires-Dist: boto3>=1.39
|
|
215
|
-
Requires-Dist: class_registry>=2.1
|
|
216
215
|
Requires-Dist: fiona>=1.10
|
|
217
216
|
Requires-Dist: fsspec>=2025.9.0
|
|
218
217
|
Requires-Dist: jsonargparse>=4.35.0
|
|
@@ -233,7 +232,7 @@ Requires-Dist: cdsapi>=0.7.6; extra == "extra"
|
|
|
233
232
|
Requires-Dist: earthdaily[platform]>=1.0.7; extra == "extra"
|
|
234
233
|
Requires-Dist: earthengine-api>=1.6.3; extra == "extra"
|
|
235
234
|
Requires-Dist: einops>=0.8; extra == "extra"
|
|
236
|
-
Requires-Dist:
|
|
235
|
+
Requires-Dist: fsspec[gcs,s3]; extra == "extra"
|
|
237
236
|
Requires-Dist: google-cloud-bigquery>=3.35; extra == "extra"
|
|
238
237
|
Requires-Dist: google-cloud-storage>=2.18; extra == "extra"
|
|
239
238
|
Requires-Dist: huggingface_hub>=0.34.4; extra == "extra"
|
|
@@ -244,7 +243,6 @@ Requires-Dist: planetary_computer>=1.0; extra == "extra"
|
|
|
244
243
|
Requires-Dist: pycocotools>=2.0; extra == "extra"
|
|
245
244
|
Requires-Dist: pystac_client>=0.9; extra == "extra"
|
|
246
245
|
Requires-Dist: rtree>=1.4; extra == "extra"
|
|
247
|
-
Requires-Dist: s3fs>=2025.9.0; extra == "extra"
|
|
248
246
|
Requires-Dist: satlaspretrain_models>=0.3; extra == "extra"
|
|
249
247
|
Requires-Dist: scipy>=1.16; extra == "extra"
|
|
250
248
|
Requires-Dist: terratorch>=1.0.2; extra == "extra"
|
|
@@ -285,6 +283,7 @@ Quick links:
|
|
|
285
283
|
- [Examples](docs/Examples.md) contains more examples, including customizing different
|
|
286
284
|
stages of rslearn with additional code.
|
|
287
285
|
- [DatasetConfig](docs/DatasetConfig.md) documents the dataset configuration file.
|
|
286
|
+
- [ModelConfig](docs/ModelConfig.md) documents the model configuration file.
|
|
288
287
|
|
|
289
288
|
|
|
290
289
|
Setup
|