rslearn 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -10
- rslearn/config/dataset.py +414 -420
- rslearn/data_sources/__init__.py +8 -31
- rslearn/data_sources/aws_landsat.py +13 -24
- rslearn/data_sources/aws_open_data.py +21 -46
- rslearn/data_sources/aws_sentinel1.py +3 -14
- rslearn/data_sources/climate_data_store.py +21 -40
- rslearn/data_sources/copernicus.py +30 -91
- rslearn/data_sources/data_source.py +26 -0
- rslearn/data_sources/earthdaily.py +13 -38
- rslearn/data_sources/earthdata_srtm.py +14 -32
- rslearn/data_sources/eurocrops.py +5 -9
- rslearn/data_sources/gcp_public_data.py +46 -43
- rslearn/data_sources/google_earth_engine.py +31 -44
- rslearn/data_sources/local_files.py +91 -100
- rslearn/data_sources/openstreetmap.py +21 -51
- rslearn/data_sources/planet.py +12 -30
- rslearn/data_sources/planet_basemap.py +4 -25
- rslearn/data_sources/planetary_computer.py +58 -141
- rslearn/data_sources/usda_cdl.py +15 -26
- rslearn/data_sources/usgs_landsat.py +4 -29
- rslearn/data_sources/utils.py +9 -0
- rslearn/data_sources/worldcereal.py +47 -54
- rslearn/data_sources/worldcover.py +16 -14
- rslearn/data_sources/worldpop.py +15 -18
- rslearn/data_sources/xyz_tiles.py +11 -30
- rslearn/dataset/dataset.py +6 -6
- rslearn/dataset/manage.py +28 -26
- rslearn/dataset/materialize.py +9 -45
- rslearn/lightning_cli.py +370 -1
- rslearn/main.py +3 -3
- rslearn/models/clay/clay.py +14 -1
- rslearn/models/concatenate_features.py +93 -0
- rslearn/models/croma.py +26 -3
- rslearn/models/satlaspretrain.py +18 -4
- rslearn/models/terramind.py +19 -0
- rslearn/tile_stores/__init__.py +0 -11
- rslearn/train/dataset.py +4 -12
- rslearn/train/prediction_writer.py +16 -32
- rslearn/train/tasks/classification.py +2 -1
- rslearn/utils/fsspec.py +20 -0
- rslearn/utils/jsonargparse.py +79 -0
- rslearn/utils/raster_format.py +1 -41
- rslearn/utils/vector_format.py +1 -38
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/METADATA +1 -1
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/RECORD +51 -52
- rslearn/data_sources/geotiff.py +0 -1
- rslearn/data_sources/raster_source.py +0 -23
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/WHEEL +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/top_level.txt +0 -0
rslearn/models/satlaspretrain.py
CHANGED
|
@@ -4,15 +4,14 @@ from typing import Any
|
|
|
4
4
|
|
|
5
5
|
import satlaspretrain_models
|
|
6
6
|
import torch
|
|
7
|
+
import torch.nn.functional as F
|
|
7
8
|
|
|
8
9
|
|
|
9
10
|
class SatlasPretrain(torch.nn.Module):
|
|
10
11
|
"""SatlasPretrain backbones."""
|
|
11
12
|
|
|
12
13
|
def __init__(
|
|
13
|
-
self,
|
|
14
|
-
model_identifier: str,
|
|
15
|
-
fpn: bool = False,
|
|
14
|
+
self, model_identifier: str, fpn: bool = False, resize_to_pretrain: bool = False
|
|
16
15
|
) -> None:
|
|
17
16
|
"""Instantiate a new SatlasPretrain instance.
|
|
18
17
|
|
|
@@ -21,6 +20,8 @@ class SatlasPretrain(torch.nn.Module):
|
|
|
21
20
|
https://github.com/allenai/satlaspretrain_models
|
|
22
21
|
fpn: whether to include the feature pyramid network, otherwise only the
|
|
23
22
|
Swin-v2-Transformer is used.
|
|
23
|
+
resize_to_pretrain: whether to resize inputs to the pretraining input
|
|
24
|
+
size (512 x 512)
|
|
24
25
|
"""
|
|
25
26
|
super().__init__()
|
|
26
27
|
weights_manager = satlaspretrain_models.Weights()
|
|
@@ -49,6 +50,19 @@ class SatlasPretrain(torch.nn.Module):
|
|
|
49
50
|
[16, 1024],
|
|
50
51
|
[32, 2048],
|
|
51
52
|
]
|
|
53
|
+
self.resize_to_pretrain = resize_to_pretrain
|
|
54
|
+
|
|
55
|
+
def maybe_resize(self, data: torch.Tensor) -> list[torch.Tensor]:
|
|
56
|
+
"""Resize to pretraining sizes if resize_to_pretrain == True."""
|
|
57
|
+
if self.resize_to_pretrain:
|
|
58
|
+
return F.interpolate(
|
|
59
|
+
data,
|
|
60
|
+
size=(512, 512),
|
|
61
|
+
mode="bilinear",
|
|
62
|
+
align_corners=False,
|
|
63
|
+
)
|
|
64
|
+
else:
|
|
65
|
+
return data
|
|
52
66
|
|
|
53
67
|
def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
|
|
54
68
|
"""Compute feature maps from the SatlasPretrain backbone.
|
|
@@ -58,7 +72,7 @@ class SatlasPretrain(torch.nn.Module):
|
|
|
58
72
|
process.
|
|
59
73
|
"""
|
|
60
74
|
images = torch.stack([inp["image"] for inp in inputs], dim=0)
|
|
61
|
-
return self.model(images)
|
|
75
|
+
return self.model(self.maybe_resize(images))
|
|
62
76
|
|
|
63
77
|
def get_backbone_channels(self) -> list:
|
|
64
78
|
"""Returns the output channels of this model when used as a backbone.
|
rslearn/models/terramind.py
CHANGED
|
@@ -4,6 +4,7 @@ from enum import Enum
|
|
|
4
4
|
from typing import Any
|
|
5
5
|
|
|
6
6
|
import torch
|
|
7
|
+
import torch.nn.functional as F
|
|
7
8
|
from einops import rearrange
|
|
8
9
|
from terratorch.registry import BACKBONE_REGISTRY
|
|
9
10
|
|
|
@@ -18,6 +19,8 @@ class TerramindSize(str, Enum):
|
|
|
18
19
|
LARGE = "large"
|
|
19
20
|
|
|
20
21
|
|
|
22
|
+
# Pretraining image size for Terramind
|
|
23
|
+
IMAGE_SIZE = 224
|
|
21
24
|
# Default patch size for Terramind
|
|
22
25
|
PATCH_SIZE = 16
|
|
23
26
|
|
|
@@ -89,12 +92,14 @@ class Terramind(torch.nn.Module):
|
|
|
89
92
|
self,
|
|
90
93
|
model_size: TerramindSize,
|
|
91
94
|
modalities: list[str] = ["S2L2A"],
|
|
95
|
+
do_resizing: bool = False,
|
|
92
96
|
) -> None:
|
|
93
97
|
"""Initialize the Terramind model.
|
|
94
98
|
|
|
95
99
|
Args:
|
|
96
100
|
model_size: The size of the Terramind model.
|
|
97
101
|
modalities: The modalities to use.
|
|
102
|
+
do_resizing: Whether to resize the input images to the pretraining resolution.
|
|
98
103
|
"""
|
|
99
104
|
super().__init__()
|
|
100
105
|
|
|
@@ -116,6 +121,7 @@ class Terramind(torch.nn.Module):
|
|
|
116
121
|
|
|
117
122
|
self.model_size = model_size
|
|
118
123
|
self.modalities = modalities
|
|
124
|
+
self.do_resizing = do_resizing
|
|
119
125
|
|
|
120
126
|
def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
|
|
121
127
|
"""Forward pass for the Terramind model.
|
|
@@ -132,6 +138,19 @@ class Terramind(torch.nn.Module):
|
|
|
132
138
|
if modality not in inputs[0]:
|
|
133
139
|
continue
|
|
134
140
|
cur = torch.stack([inp[modality] for inp in inputs], dim=0) # (B, C, H, W)
|
|
141
|
+
if self.do_resizing and (
|
|
142
|
+
cur.shape[2] != IMAGE_SIZE or cur.shape[3] != IMAGE_SIZE
|
|
143
|
+
):
|
|
144
|
+
if cur.shape[2] == 1 and cur.shape[3] == 1:
|
|
145
|
+
new_height, new_width = PATCH_SIZE, PATCH_SIZE
|
|
146
|
+
else:
|
|
147
|
+
new_height, new_width = IMAGE_SIZE, IMAGE_SIZE
|
|
148
|
+
cur = F.interpolate(
|
|
149
|
+
cur,
|
|
150
|
+
size=(new_height, new_width),
|
|
151
|
+
mode="bilinear",
|
|
152
|
+
align_corners=False,
|
|
153
|
+
)
|
|
135
154
|
model_inputs[modality] = cur
|
|
136
155
|
|
|
137
156
|
# By default, the patch embeddings are averaged over all modalities to reduce output tokens
|
rslearn/tile_stores/__init__.py
CHANGED
|
@@ -22,17 +22,6 @@ def load_tile_store(config: dict[str, Any], ds_path: UPath) -> TileStore:
|
|
|
22
22
|
Returns:
|
|
23
23
|
the TileStore
|
|
24
24
|
"""
|
|
25
|
-
if config is None:
|
|
26
|
-
tile_store = DefaultTileStore()
|
|
27
|
-
tile_store.set_dataset_path(ds_path)
|
|
28
|
-
return tile_store
|
|
29
|
-
|
|
30
|
-
# Backwards compatability.
|
|
31
|
-
if "name" in config and "root_dir" in config and config["name"] == "file":
|
|
32
|
-
tile_store = DefaultTileStore(config["root_dir"])
|
|
33
|
-
tile_store.set_dataset_path(ds_path)
|
|
34
|
-
return tile_store
|
|
35
|
-
|
|
36
25
|
init_jsonargparse()
|
|
37
26
|
parser = jsonargparse.ArgumentParser()
|
|
38
27
|
parser.add_argument("--tile_store", type=TileStore)
|
rslearn/train/dataset.py
CHANGED
|
@@ -17,9 +17,7 @@ from rasterio.warp import Resampling
|
|
|
17
17
|
import rslearn.train.transforms.transform
|
|
18
18
|
from rslearn.config import (
|
|
19
19
|
DType,
|
|
20
|
-
|
|
21
|
-
RasterLayerConfig,
|
|
22
|
-
VectorLayerConfig,
|
|
20
|
+
LayerConfig,
|
|
23
21
|
)
|
|
24
22
|
from rslearn.dataset.dataset import Dataset
|
|
25
23
|
from rslearn.dataset.window import Window, get_layer_and_group_from_dir_name
|
|
@@ -28,8 +26,6 @@ from rslearn.train.tasks import Task
|
|
|
28
26
|
from rslearn.utils.feature import Feature
|
|
29
27
|
from rslearn.utils.geometry import PixelBounds
|
|
30
28
|
from rslearn.utils.mp import star_imap_unordered
|
|
31
|
-
from rslearn.utils.raster_format import load_raster_format
|
|
32
|
-
from rslearn.utils.vector_format import load_vector_format
|
|
33
29
|
|
|
34
30
|
from .transforms import Sequential
|
|
35
31
|
|
|
@@ -185,7 +181,7 @@ def read_raster_layer_for_data_input(
|
|
|
185
181
|
bounds: PixelBounds,
|
|
186
182
|
layer_name: str,
|
|
187
183
|
group_idx: int,
|
|
188
|
-
layer_config:
|
|
184
|
+
layer_config: LayerConfig,
|
|
189
185
|
data_input: DataInput,
|
|
190
186
|
) -> torch.Tensor:
|
|
191
187
|
"""Read a raster layer for a DataInput.
|
|
@@ -246,9 +242,7 @@ def read_raster_layer_for_data_input(
|
|
|
246
242
|
)
|
|
247
243
|
if band_set.format is None:
|
|
248
244
|
raise ValueError(f"No format specified for {layer_name}")
|
|
249
|
-
raster_format =
|
|
250
|
-
RasterFormatConfig(band_set.format["name"], band_set.format)
|
|
251
|
-
)
|
|
245
|
+
raster_format = band_set.instantiate_raster_format()
|
|
252
246
|
raster_dir = window.get_raster_dir(
|
|
253
247
|
layer_name, band_set.bands, group_idx=group_idx
|
|
254
248
|
)
|
|
@@ -349,7 +343,6 @@ def read_data_input(
|
|
|
349
343
|
images: list[torch.Tensor] = []
|
|
350
344
|
for layer_name, group_idx in layers_to_read:
|
|
351
345
|
layer_config = dataset.layers[layer_name]
|
|
352
|
-
assert isinstance(layer_config, RasterLayerConfig)
|
|
353
346
|
images.append(
|
|
354
347
|
read_raster_layer_for_data_input(
|
|
355
348
|
window, bounds, layer_name, group_idx, layer_config, data_input
|
|
@@ -363,8 +356,7 @@ def read_data_input(
|
|
|
363
356
|
features: list[Feature] = []
|
|
364
357
|
for layer_name, group_idx in layers_to_read:
|
|
365
358
|
layer_config = dataset.layers[layer_name]
|
|
366
|
-
|
|
367
|
-
vector_format = load_vector_format(layer_config.format)
|
|
359
|
+
vector_format = layer_config.instantiate_vector_format()
|
|
368
360
|
layer_dir = window.get_layer_dir(layer_name, group_idx=group_idx)
|
|
369
361
|
cur_features = vector_format.decode_vector(
|
|
370
362
|
layer_dir, window.projection, window.bounds
|
|
@@ -12,10 +12,8 @@ from lightning.pytorch.callbacks import BasePredictionWriter
|
|
|
12
12
|
from upath import UPath
|
|
13
13
|
|
|
14
14
|
from rslearn.config import (
|
|
15
|
+
LayerConfig,
|
|
15
16
|
LayerType,
|
|
16
|
-
RasterFormatConfig,
|
|
17
|
-
RasterLayerConfig,
|
|
18
|
-
VectorLayerConfig,
|
|
19
17
|
)
|
|
20
18
|
from rslearn.dataset import Dataset, Window
|
|
21
19
|
from rslearn.log_utils import get_logger
|
|
@@ -25,9 +23,8 @@ from rslearn.utils.geometry import PixelBounds
|
|
|
25
23
|
from rslearn.utils.raster_format import (
|
|
26
24
|
RasterFormat,
|
|
27
25
|
adjust_projection_and_bounds_for_array,
|
|
28
|
-
load_raster_format,
|
|
29
26
|
)
|
|
30
|
-
from rslearn.utils.vector_format import VectorFormat
|
|
27
|
+
from rslearn.utils.vector_format import VectorFormat
|
|
31
28
|
|
|
32
29
|
from .lightning_module import RslearnLightningModule
|
|
33
30
|
from .tasks.task import Task
|
|
@@ -150,7 +147,7 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
150
147
|
selector: list[str] | None = None,
|
|
151
148
|
merger: PatchPredictionMerger | None = None,
|
|
152
149
|
output_path: str | Path | None = None,
|
|
153
|
-
layer_config:
|
|
150
|
+
layer_config: LayerConfig | None = None,
|
|
154
151
|
):
|
|
155
152
|
"""Create a new RslearnWriter.
|
|
156
153
|
|
|
@@ -178,43 +175,31 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
178
175
|
)
|
|
179
176
|
|
|
180
177
|
# Handle dataset and layer config
|
|
181
|
-
self.layer_config:
|
|
178
|
+
self.layer_config: LayerConfig
|
|
182
179
|
if layer_config:
|
|
183
180
|
self.layer_config = layer_config
|
|
184
|
-
self.dataset = None if self.output_path else Dataset(self.path)
|
|
185
181
|
else:
|
|
186
|
-
|
|
187
|
-
if self.output_layer not in
|
|
182
|
+
dataset = Dataset(self.path)
|
|
183
|
+
if self.output_layer not in dataset.layers:
|
|
188
184
|
raise KeyError(
|
|
189
185
|
f"Output layer '{self.output_layer}' not found in dataset layers."
|
|
190
186
|
)
|
|
191
|
-
|
|
192
|
-
# Type narrowing to ensure compatibility
|
|
193
|
-
if isinstance(raw_layer_config, (RasterLayerConfig | VectorLayerConfig)):
|
|
194
|
-
self.layer_config = raw_layer_config
|
|
195
|
-
else:
|
|
196
|
-
raise ValueError(
|
|
197
|
-
f"Layer config must be RasterLayerConfig or VectorLayerConfig, got {type(raw_layer_config)}"
|
|
198
|
-
)
|
|
187
|
+
self.layer_config = dataset.layers[self.output_layer]
|
|
199
188
|
|
|
200
189
|
self.format: RasterFormat | VectorFormat
|
|
201
|
-
if self.layer_config.
|
|
202
|
-
assert isinstance(self.layer_config, RasterLayerConfig)
|
|
190
|
+
if self.layer_config.type == LayerType.RASTER:
|
|
203
191
|
band_cfg = self.layer_config.band_sets[0]
|
|
204
|
-
self.format =
|
|
205
|
-
|
|
206
|
-
)
|
|
207
|
-
elif self.layer_config.layer_type == LayerType.VECTOR:
|
|
208
|
-
assert isinstance(self.layer_config, VectorLayerConfig)
|
|
209
|
-
self.format = load_vector_format(self.layer_config.format)
|
|
192
|
+
self.format = band_cfg.instantiate_raster_format()
|
|
193
|
+
elif self.layer_config.type == LayerType.VECTOR:
|
|
194
|
+
self.format = self.layer_config.instantiate_vector_format()
|
|
210
195
|
else:
|
|
211
|
-
raise ValueError(f"invalid layer type {self.layer_config.
|
|
196
|
+
raise ValueError(f"invalid layer type {self.layer_config.type}")
|
|
212
197
|
|
|
213
198
|
if merger is not None:
|
|
214
199
|
self.merger = merger
|
|
215
|
-
elif self.layer_config.
|
|
200
|
+
elif self.layer_config.type == LayerType.RASTER:
|
|
216
201
|
self.merger = RasterMerger()
|
|
217
|
-
elif self.layer_config.
|
|
202
|
+
elif self.layer_config.type == LayerType.VECTOR:
|
|
218
203
|
self.merger = VectorMerger()
|
|
219
204
|
|
|
220
205
|
# Map from window name to pending data to write.
|
|
@@ -337,8 +322,7 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
337
322
|
logger.debug(f"Merging and writing for window {window.name}")
|
|
338
323
|
merged_output = self.merger.merge(window, pending_output)
|
|
339
324
|
|
|
340
|
-
if self.layer_config.
|
|
341
|
-
assert isinstance(self.layer_config, RasterLayerConfig)
|
|
325
|
+
if self.layer_config.type == LayerType.RASTER:
|
|
342
326
|
raster_dir = window.get_raster_dir(
|
|
343
327
|
self.output_layer, self.layer_config.band_sets[0].bands
|
|
344
328
|
)
|
|
@@ -351,7 +335,7 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
351
335
|
)
|
|
352
336
|
self.format.encode_raster(raster_dir, projection, bounds, merged_output)
|
|
353
337
|
|
|
354
|
-
elif self.layer_config.
|
|
338
|
+
elif self.layer_config.type == LayerType.VECTOR:
|
|
355
339
|
layer_dir = window.get_layer_dir(self.output_layer)
|
|
356
340
|
assert isinstance(self.format, VectorFormat)
|
|
357
341
|
self.format.encode_vector(layer_dir, merged_output)
|
|
@@ -26,7 +26,7 @@ class ClassificationTask(BasicTask):
|
|
|
26
26
|
def __init__(
|
|
27
27
|
self,
|
|
28
28
|
property_name: str,
|
|
29
|
-
classes: list,
|
|
29
|
+
classes: list[str],
|
|
30
30
|
filters: list[tuple[str, str]] = [],
|
|
31
31
|
read_class_id: bool = False,
|
|
32
32
|
allow_invalid: bool = False,
|
|
@@ -176,6 +176,7 @@ class ClassificationTask(BasicTask):
|
|
|
176
176
|
# For multiclass classification or when using the default threshold
|
|
177
177
|
class_idx = probs.argmax().item()
|
|
178
178
|
|
|
179
|
+
value: str | int
|
|
179
180
|
if not self.read_class_id:
|
|
180
181
|
value = self.classes[class_idx] # type: ignore
|
|
181
182
|
else:
|
rslearn/utils/fsspec.py
CHANGED
|
@@ -156,3 +156,23 @@ def open_rasterio_upath_writer(
|
|
|
156
156
|
with path.open("wb") as f:
|
|
157
157
|
with rasterio.open(f, "w", **kwargs) as raster:
|
|
158
158
|
yield raster
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def get_relative_suffix(base_dir: UPath, fname: UPath) -> str:
|
|
162
|
+
"""Get the suffix of fname relative to base_dir.
|
|
163
|
+
|
|
164
|
+
Args:
|
|
165
|
+
base_dir: the base directory.
|
|
166
|
+
fname: a filename within the base directory.
|
|
167
|
+
|
|
168
|
+
Returns:
|
|
169
|
+
the suffix on base_dir that would yield the given filename.
|
|
170
|
+
"""
|
|
171
|
+
if not fname.path.startswith(base_dir.path):
|
|
172
|
+
raise ValueError(
|
|
173
|
+
f"filename {fname.path} must start with base directory {base_dir.path}"
|
|
174
|
+
)
|
|
175
|
+
suffix = fname.path[len(base_dir.path) :]
|
|
176
|
+
if suffix.startswith("/"):
|
|
177
|
+
suffix = suffix[1:]
|
|
178
|
+
return suffix
|
rslearn/utils/jsonargparse.py
CHANGED
|
@@ -1,7 +1,18 @@
|
|
|
1
1
|
"""Custom serialization for jsonargparse."""
|
|
2
2
|
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
3
6
|
import jsonargparse
|
|
4
7
|
from rasterio.crs import CRS
|
|
8
|
+
from upath import UPath
|
|
9
|
+
|
|
10
|
+
from rslearn.config.dataset import LayerConfig
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from rslearn.data_sources.data_source import DataSourceContext
|
|
14
|
+
|
|
15
|
+
INITIALIZED = False
|
|
5
16
|
|
|
6
17
|
|
|
7
18
|
def crs_serializer(v: CRS) -> str:
|
|
@@ -28,6 +39,74 @@ def crs_deserializer(v: str) -> CRS:
|
|
|
28
39
|
return CRS.from_string(v)
|
|
29
40
|
|
|
30
41
|
|
|
42
|
+
def datetime_serializer(v: datetime) -> str:
|
|
43
|
+
"""Serialize datetime for jsonargparse.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
v: the datetime object.
|
|
47
|
+
|
|
48
|
+
Returns:
|
|
49
|
+
the datetime encoded to string
|
|
50
|
+
"""
|
|
51
|
+
return v.isoformat()
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def datetime_deserializer(v: str) -> datetime:
|
|
55
|
+
"""Deserialize datetime for jsonargparse.
|
|
56
|
+
|
|
57
|
+
Args:
|
|
58
|
+
v: the encoded datetime.
|
|
59
|
+
|
|
60
|
+
Returns:
|
|
61
|
+
the decoded datetime object
|
|
62
|
+
"""
|
|
63
|
+
return datetime.fromisoformat(v)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def data_source_context_serializer(v: "DataSourceContext") -> dict[str, Any]:
|
|
67
|
+
"""Serialize DataSourceContext for jsonargparse."""
|
|
68
|
+
x = {
|
|
69
|
+
"ds_path": (str(v.ds_path) if v.ds_path is not None else None),
|
|
70
|
+
"layer_config": (
|
|
71
|
+
v.layer_config.model_dump(mode="json")
|
|
72
|
+
if v.layer_config is not None
|
|
73
|
+
else None
|
|
74
|
+
),
|
|
75
|
+
}
|
|
76
|
+
return x
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def data_source_context_deserializer(v: dict[str, Any]) -> "DataSourceContext":
|
|
80
|
+
"""Deserialize DataSourceContext for jsonargparse."""
|
|
81
|
+
# We lazily import these to avoid cyclic dependency.
|
|
82
|
+
from rslearn.data_sources.data_source import DataSourceContext
|
|
83
|
+
|
|
84
|
+
return DataSourceContext(
|
|
85
|
+
ds_path=(UPath(v["ds_path"]) if v["ds_path"] is not None else None),
|
|
86
|
+
layer_config=(
|
|
87
|
+
LayerConfig.model_validate(v["layer_config"])
|
|
88
|
+
if v["layer_config"] is not None
|
|
89
|
+
else None
|
|
90
|
+
),
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
|
|
31
94
|
def init_jsonargparse() -> None:
|
|
32
95
|
"""Initialize custom jsonargparse serializers."""
|
|
96
|
+
global INITIALIZED
|
|
97
|
+
if INITIALIZED:
|
|
98
|
+
return
|
|
33
99
|
jsonargparse.typing.register_type(CRS, crs_serializer, crs_deserializer)
|
|
100
|
+
jsonargparse.typing.register_type(
|
|
101
|
+
datetime, datetime_serializer, datetime_deserializer
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
from rslearn.data_sources.data_source import DataSourceContext
|
|
105
|
+
|
|
106
|
+
jsonargparse.typing.register_type(
|
|
107
|
+
DataSourceContext,
|
|
108
|
+
data_source_context_serializer,
|
|
109
|
+
data_source_context_deserializer,
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
INITIALIZED = True
|
rslearn/utils/raster_format.py
CHANGED
|
@@ -2,8 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import hashlib
|
|
4
4
|
import json
|
|
5
|
-
from
|
|
6
|
-
from typing import Any, BinaryIO, TypeVar
|
|
5
|
+
from typing import Any, BinaryIO
|
|
7
6
|
|
|
8
7
|
import affine
|
|
9
8
|
import numpy as np
|
|
@@ -14,34 +13,12 @@ from rasterio.crs import CRS
|
|
|
14
13
|
from rasterio.enums import Resampling
|
|
15
14
|
from upath import UPath
|
|
16
15
|
|
|
17
|
-
from rslearn.config import RasterFormatConfig
|
|
18
16
|
from rslearn.const import TILE_SIZE
|
|
19
17
|
from rslearn.log_utils import get_logger
|
|
20
18
|
from rslearn.utils.fsspec import open_rasterio_upath_reader, open_rasterio_upath_writer
|
|
21
19
|
|
|
22
20
|
from .geometry import PixelBounds, Projection
|
|
23
21
|
|
|
24
|
-
_RasterFormatT = TypeVar("_RasterFormatT", bound="RasterFormat")
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
class _RasterFormatRegistry(dict[str, type["RasterFormat"]]):
|
|
28
|
-
"""Registry for RasterFormat classes."""
|
|
29
|
-
|
|
30
|
-
def register(
|
|
31
|
-
self, name: str
|
|
32
|
-
) -> Callable[[type[_RasterFormatT]], type[_RasterFormatT]]:
|
|
33
|
-
"""Decorator to register a raster format class."""
|
|
34
|
-
|
|
35
|
-
def decorator(cls: type[_RasterFormatT]) -> type[_RasterFormatT]:
|
|
36
|
-
self[name] = cls
|
|
37
|
-
return cls
|
|
38
|
-
|
|
39
|
-
return decorator
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
RasterFormats = _RasterFormatRegistry()
|
|
43
|
-
|
|
44
|
-
|
|
45
22
|
logger = get_logger(__name__)
|
|
46
23
|
|
|
47
24
|
|
|
@@ -219,7 +196,6 @@ class RasterFormat:
|
|
|
219
196
|
raise NotImplementedError
|
|
220
197
|
|
|
221
198
|
|
|
222
|
-
@RasterFormats.register("image_tile")
|
|
223
199
|
class ImageTileRasterFormat(RasterFormat):
|
|
224
200
|
"""A RasterFormat that stores data in image tiles corresponding to grid cells.
|
|
225
201
|
|
|
@@ -468,7 +444,6 @@ class ImageTileRasterFormat(RasterFormat):
|
|
|
468
444
|
)
|
|
469
445
|
|
|
470
446
|
|
|
471
|
-
@RasterFormats.register("geotiff")
|
|
472
447
|
class GeotiffRasterFormat(RasterFormat):
|
|
473
448
|
"""A raster format that uses one big, tiled GeoTIFF with small block size."""
|
|
474
449
|
|
|
@@ -623,7 +598,6 @@ class GeotiffRasterFormat(RasterFormat):
|
|
|
623
598
|
return GeotiffRasterFormat(**kwargs)
|
|
624
599
|
|
|
625
600
|
|
|
626
|
-
@RasterFormats.register("single_image")
|
|
627
601
|
class SingleImageRasterFormat(RasterFormat):
|
|
628
602
|
"""A raster format that produces a single image called image.png/jpg.
|
|
629
603
|
|
|
@@ -775,17 +749,3 @@ class SingleImageRasterFormat(RasterFormat):
|
|
|
775
749
|
if "format" in config:
|
|
776
750
|
kwargs["format"] = config["format"]
|
|
777
751
|
return SingleImageRasterFormat(**kwargs)
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
def load_raster_format(config: RasterFormatConfig) -> RasterFormat:
|
|
781
|
-
"""Loads a RasterFormat from a RasterFormatConfig.
|
|
782
|
-
|
|
783
|
-
Args:
|
|
784
|
-
config: the RasterFormatConfig configuration object specifying the
|
|
785
|
-
RasterFormat.
|
|
786
|
-
|
|
787
|
-
Returns:
|
|
788
|
-
the loaded RasterFormat implementation
|
|
789
|
-
"""
|
|
790
|
-
cls = RasterFormats[config.name]
|
|
791
|
-
return cls.from_config(config.name, config.config_dict)
|
rslearn/utils/vector_format.py
CHANGED
|
@@ -1,15 +1,13 @@
|
|
|
1
1
|
"""Classes for writing vector data to a UPath."""
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
-
from collections.abc import Callable
|
|
5
4
|
from enum import Enum
|
|
6
|
-
from typing import Any
|
|
5
|
+
from typing import Any
|
|
7
6
|
|
|
8
7
|
import shapely
|
|
9
8
|
from rasterio.crs import CRS
|
|
10
9
|
from upath import UPath
|
|
11
10
|
|
|
12
|
-
from rslearn.config import VectorFormatConfig
|
|
13
11
|
from rslearn.const import WGS84_PROJECTION
|
|
14
12
|
from rslearn.log_utils import get_logger
|
|
15
13
|
from rslearn.utils.fsspec import open_atomic
|
|
@@ -18,25 +16,6 @@ from .feature import Feature
|
|
|
18
16
|
from .geometry import PixelBounds, Projection, STGeometry, safely_reproject_and_clip
|
|
19
17
|
|
|
20
18
|
logger = get_logger(__name__)
|
|
21
|
-
_VectorFormatT = TypeVar("_VectorFormatT", bound="VectorFormat")
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
class _VectorFormatRegistry(dict[str, type["VectorFormat"]]):
|
|
25
|
-
"""Registry for VectorFormat classes."""
|
|
26
|
-
|
|
27
|
-
def register(
|
|
28
|
-
self, name: str
|
|
29
|
-
) -> Callable[[type[_VectorFormatT]], type[_VectorFormatT]]:
|
|
30
|
-
"""Decorator to register a vector format class."""
|
|
31
|
-
|
|
32
|
-
def decorator(cls: type[_VectorFormatT]) -> type[_VectorFormatT]:
|
|
33
|
-
self[name] = cls
|
|
34
|
-
return cls
|
|
35
|
-
|
|
36
|
-
return decorator
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
VectorFormats = _VectorFormatRegistry()
|
|
40
19
|
|
|
41
20
|
|
|
42
21
|
class VectorFormat:
|
|
@@ -85,7 +64,6 @@ class VectorFormat:
|
|
|
85
64
|
raise NotImplementedError
|
|
86
65
|
|
|
87
66
|
|
|
88
|
-
@VectorFormats.register("tile")
|
|
89
67
|
class TileVectorFormat(VectorFormat):
|
|
90
68
|
"""TileVectorFormat stores data in GeoJSON files corresponding to grid cells.
|
|
91
69
|
|
|
@@ -275,7 +253,6 @@ class GeojsonCoordinateMode(Enum):
|
|
|
275
253
|
WGS84 = "wgs84"
|
|
276
254
|
|
|
277
255
|
|
|
278
|
-
@VectorFormats.register("geojson")
|
|
279
256
|
class GeojsonVectorFormat(VectorFormat):
|
|
280
257
|
"""A vector format that uses one big GeoJSON."""
|
|
281
258
|
|
|
@@ -429,17 +406,3 @@ class GeojsonVectorFormat(VectorFormat):
|
|
|
429
406
|
if "coordinate_mode" in config:
|
|
430
407
|
kwargs["coordinate_mode"] = GeojsonCoordinateMode(config["coordinate_mode"])
|
|
431
408
|
return GeojsonVectorFormat(**kwargs)
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
def load_vector_format(config: VectorFormatConfig) -> VectorFormat:
|
|
435
|
-
"""Loads a VectorFormat from a VectorFormatConfig.
|
|
436
|
-
|
|
437
|
-
Args:
|
|
438
|
-
config: the VectorFormatConfig configuration object specifying the
|
|
439
|
-
VectorFormat.
|
|
440
|
-
|
|
441
|
-
Returns:
|
|
442
|
-
the loaded VectorFormat implementation
|
|
443
|
-
"""
|
|
444
|
-
cls = VectorFormats[config.name]
|
|
445
|
-
return cls.from_config(config.name, config.config_dict)
|