rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/dataset/remap.py
CHANGED
|
@@ -1,18 +1,42 @@
|
|
|
1
1
|
"""Classes to remap raster values."""
|
|
2
2
|
|
|
3
|
-
from
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from typing import Any, TypeVar
|
|
4
5
|
|
|
5
6
|
import numpy as np
|
|
6
7
|
import numpy.typing as npt
|
|
7
|
-
from class_registry import ClassRegistry
|
|
8
8
|
|
|
9
|
-
|
|
9
|
+
_RemapperT = TypeVar("_RemapperT", bound="Remapper")
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class _RemapperRegistry(dict[str, type["Remapper"]]):
|
|
13
|
+
"""Registry for Remapper classes."""
|
|
14
|
+
|
|
15
|
+
def register(self, name: str) -> Callable[[type[_RemapperT]], type[_RemapperT]]:
|
|
16
|
+
"""Decorator to register a remapper class."""
|
|
17
|
+
|
|
18
|
+
def decorator(cls: type[_RemapperT]) -> type[_RemapperT]:
|
|
19
|
+
self[name] = cls
|
|
20
|
+
return cls
|
|
21
|
+
|
|
22
|
+
return decorator
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
Remappers = _RemapperRegistry()
|
|
10
26
|
"""Registry of Remapper implementations."""
|
|
11
27
|
|
|
12
28
|
|
|
13
29
|
class Remapper:
|
|
14
30
|
"""An abstract class that remaps pixel values based on layer configuration."""
|
|
15
31
|
|
|
32
|
+
def __init__(self, config: dict[str, Any]) -> None:
|
|
33
|
+
"""Initialize a Remapper.
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
config: the config dict for this remapper.
|
|
37
|
+
"""
|
|
38
|
+
pass
|
|
39
|
+
|
|
16
40
|
def __call__(
|
|
17
41
|
self, array: npt.NDArray[Any], dtype: npt.DTypeLike
|
|
18
42
|
) -> npt.NDArray[Any]:
|
|
@@ -67,4 +91,5 @@ class LinearRemapper(Remapper):
|
|
|
67
91
|
|
|
68
92
|
def load_remapper(config: dict[str, Any]) -> Remapper:
    """Instantiate the remapper described by a configuration dictionary.

    Args:
        config: the remapper config dict. Its "name" entry selects the class from
            the Remappers registry, and the whole dict is passed to that class's
            constructor.

    Returns:
        the constructed remapper.
    """
    remapper_name = config["name"]
    return Remappers[remapper_name](config)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Storage backends for rslearn window metadata."""
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
"""The default file-based window storage backend."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import multiprocessing
|
|
5
|
+
|
|
6
|
+
import tqdm
|
|
7
|
+
from typing_extensions import override
|
|
8
|
+
from upath import UPath
|
|
9
|
+
|
|
10
|
+
from rslearn.dataset.window import (
|
|
11
|
+
LAYERS_DIRECTORY_NAME,
|
|
12
|
+
Window,
|
|
13
|
+
WindowLayerData,
|
|
14
|
+
get_layer_and_group_from_dir_name,
|
|
15
|
+
get_window_layer_dir,
|
|
16
|
+
)
|
|
17
|
+
from rslearn.log_utils import get_logger
|
|
18
|
+
from rslearn.utils.fsspec import open_atomic
|
|
19
|
+
from rslearn.utils.mp import star_imap_unordered
|
|
20
|
+
|
|
21
|
+
from .storage import WindowStorage, WindowStorageFactory
|
|
22
|
+
|
|
23
|
+
logger = get_logger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def load_window(storage: "FileWindowStorage", window_dir: UPath) -> Window:
    """Build a Window from the metadata.json file inside its directory.

    This is a module-level function (rather than a method) and is the callable that
    FileWindowStorage.get_windows hands to its worker pool.

    Args:
        storage: the FileWindowStorage that the resulting window is attached to.
        window_dir: the directory of the window; metadata.json is read from it.

    Returns:
        the window created by Window.from_metadata from the parsed JSON.
    """
    with (window_dir / "metadata.json").open() as metadata_file:
        window_metadata = json.load(metadata_file)
    return Window.from_metadata(storage, window_metadata)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class FileWindowStorage(WindowStorage):
    """The default file-backed window storage.

    Window metadata lives in files under the dataset path. Each window has a root
    directory (see get_window_root) that holds metadata.json, items.json, and a
    layers directory in which a "completed" marker file flags each completed layer.
    """

    def __init__(self, path: UPath) -> None:
        """Create a new FileWindowStorage.

        Args:
            path: the path to the dataset.
        """
        self.path = path

    @override
    def get_window_root(self, group: str, name: str) -> UPath:
        """Get the directory of the window, as computed by Window.get_window_root."""
        return Window.get_window_root(self.path, group, name)

    @override
    def get_windows(
        self,
        groups: list[str] | None = None,
        names: list[str] | None = None,
        show_progress: bool = False,
        workers: int = 0,
    ) -> list["Window"]:
        """Load the windows in the dataset.

        Args:
            groups: an optional list of groups to filter loading. When unset or
                empty, every entry of the "windows" directory is used as a group.
            names: an optional list of window names to filter loading. When set,
                each name is looked up in every selected group; the window
                directories are not checked for existence before loading.
            show_progress: whether to show tqdm progress bar
            workers: number of parallel workers, default 0 (use main thread only to
                load windows)

        Returns:
            the loaded windows, or an empty list if the dataset has no "windows"
                directory.
        """
        windows_root = self.path / "windows"
        # Avoid directory does not exist errors later.
        if not windows_root.exists():
            return []

        if not groups:
            groups = [entry.name for entry in windows_root.iterdir()]

        window_dirs = []
        for group in groups:
            group_dir = windows_root / group
            if not group_dir.exists():
                logger.warning(
                    f"Skipping group directory {group_dir} since it does not exist"
                )
                continue
            if names:
                cur_names = names
            else:
                cur_names = [entry.name for entry in group_dir.iterdir()]
            window_dirs.extend(group_dir / window_name for window_name in cur_names)

        if workers == 0:
            # Honor show_progress on the single-process path too.
            pending = window_dirs
            if show_progress:
                pending = tqdm.tqdm(
                    window_dirs, total=len(window_dirs), desc="Loading windows"
                )
            return [load_window(self, window_dir) for window_dir in pending]

        # The context manager releases the pool even if loading a window raises.
        # All results are consumed inside the block, before the pool is torn down.
        with multiprocessing.Pool(workers) as pool:
            outputs = star_imap_unordered(
                pool,
                load_window,
                [
                    dict(storage=self, window_dir=window_dir)
                    for window_dir in window_dirs
                ],
            )
            if show_progress:
                outputs = tqdm.tqdm(
                    outputs, total=len(window_dirs), desc="Loading windows"
                )
            return list(outputs)

    @override
    def create_or_update_window(self, window: Window) -> None:
        """Write the window's metadata to metadata.json in its root directory.

        The root directory is created if needed, and the file is written through
        open_atomic.
        """
        window_path = self.get_window_root(window.group, window.name)
        window_path.mkdir(parents=True, exist_ok=True)
        metadata_path = window_path / "metadata.json"
        logger.debug(f"Saving window metadata to {metadata_path}")
        with open_atomic(metadata_path, "w") as f:
            json.dump(window.get_metadata(), f)

    @override
    def get_layer_datas(self, group: str, name: str) -> dict[str, "WindowLayerData"]:
        """Read the window layer datas from items.json.

        Returns:
            a dict from layer name to its WindowLayerData, or an empty dict when
                the window has no items.json.
        """
        items_fname = self.get_window_root(group, name) / "items.json"
        if not items_fname.exists():
            return {}

        with items_fname.open() as f:
            layer_datas = [
                WindowLayerData.deserialize(layer_data) for layer_data in json.load(f)
            ]

        return {layer_data.layer_name: layer_data for layer_data in layer_datas}

    @override
    def save_layer_datas(
        self, group: str, name: str, layer_datas: dict[str, "WindowLayerData"]
    ) -> None:
        """Serialize the given layer datas into items.json of the window.

        Only the dict values are written; the file is a JSON list of the serialized
        layer datas, written through open_atomic.
        """
        window_path = self.get_window_root(group, name)
        json_data = [layer_data.serialize() for layer_data in layer_datas.values()]
        items_fname = window_path / "items.json"
        logger.info(f"Saving window items to {items_fname}")
        with open_atomic(items_fname, "w") as f:
            json.dump(json_data, f)

    @override
    def list_completed_layers(self, group: str, name: str) -> list[tuple[str, int]]:
        """List the (layer_name, group_idx) pairs that are marked completed.

        Every entry of the window's layers directory is parsed with
        get_layer_and_group_from_dir_name and kept only if is_layer_completed
        reports it as completed. Returns an empty list when the layers directory
        does not exist.
        """
        layers_directory = self.get_window_root(group, name) / LAYERS_DIRECTORY_NAME
        if not layers_directory.exists():
            return []

        completed_layers = []
        for layer_dir in layers_directory.iterdir():
            layer_name, group_idx = get_layer_and_group_from_dir_name(layer_dir.name)
            if self.is_layer_completed(group, name, layer_name, group_idx):
                completed_layers.append((layer_name, group_idx))

        return completed_layers

    @override
    def is_layer_completed(
        self, group: str, name: str, layer_name: str, group_idx: int = 0
    ) -> bool:
        """Check whether the layer directory contains a "completed" marker file."""
        window_path = self.get_window_root(group, name)
        layer_dir = get_window_layer_dir(window_path, layer_name, group_idx)
        return (layer_dir / "completed").exists()

    @override
    def mark_layer_completed(
        self, group: str, name: str, layer_name: str, group_idx: int = 0
    ) -> None:
        """Create the "completed" marker file in the layer directory."""
        window_path = self.get_window_root(group, name)
        layer_dir = get_window_layer_dir(window_path, layer_name, group_idx)
        # We assume the directory exists because the layer should be materialized before
        # being marked completed.
        (layer_dir / "completed").touch()
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class FileWindowStorageFactory(WindowStorageFactory):
    """WindowStorageFactory that produces FileWindowStorage instances."""

    @override
    def get_storage(self, ds_path: UPath) -> FileWindowStorage:
        """Create the file-backed window storage rooted at the dataset path.

        Args:
            ds_path: the path to the dataset.

        Returns:
            a new FileWindowStorage for that path.
        """
        storage = FileWindowStorage(ds_path)
        return storage
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
"""Abstract classes for window metadata storage."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
from upath import UPath
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from rslearn.dataset.window import Window, WindowLayerData
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class WindowStorage(abc.ABC):
    """Abstract storage backend for the window metadata of one rslearn dataset.

    Instances are created by a WindowStorageFactory for a specific dataset.

    Window metadata covers three things: the location and time range of each window
    (metadata.json), the window layer datas (items.json), and which layers are
    completed (materialized). The materialized data itself is not part of it.

    Every operation on window metadata goes through this interface: enumerating
    windows, creating new windows, updating the window layer datas during
    `rslearn dataset prepare`, and updating the completed layers during
    `rslearn dataset materialize`.
    """

    @abc.abstractmethod
    def get_window_root(self, group: str, name: str) -> UPath:
        """Get the path under which the window should be stored.

        Args:
            group: the window group.
            name: the window name.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_windows(
        self,
        groups: list[str] | None = None,
        names: list[str] | None = None,
    ) -> list["Window"]:
        """Load the windows of the dataset.

        Args:
            groups: if given, restricts loading to these groups.
            names: if given, restricts loading to these window names.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def create_or_update_window(self, window: "Window") -> None:
        """Create the window, or update it if it already exists.

        An existing window is updated only when both its name and its group match.

        What happens when a window with the same name exists in a different group
        is undefined.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_layer_datas(self, group: str, name: str) -> dict[str, "WindowLayerData"]:
        """Get the layer datas saved for the specified window.

        Args:
            group: the window group.
            name: the window name.

        Returns:
            a dict keyed by layer name, holding the layer data of each layer for
                which one was previously saved.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def save_layer_datas(
        self, group: str, name: str, layer_datas: dict[str, "WindowLayerData"]
    ) -> None:
        """Store the layer datas of the specified window."""
        raise NotImplementedError

    @abc.abstractmethod
    def list_completed_layers(self, group: str, name: str) -> list[tuple[str, int]]:
        """List the completed layers of the specified window.

        Args:
            group: the window group.
            name: the window name.

        Returns:
            the completed layers as a list of (layer_name, group_idx) tuples.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def is_layer_completed(
        self, group: str, name: str, layer_name: str, group_idx: int = 0
    ) -> bool:
        """Tell whether the given layer of the given window is completed.

        A layer is completed when it has data and that data has been written
        (materialized).

        Args:
            group: the window group.
            name: the window name.
            layer_name: the layer name.
            group_idx: the index of the group within the layer.

        Returns:
            True if the layer is completed, False otherwise.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def mark_layer_completed(
        self, group: str, name: str, layer_name: str, group_idx: int = 0
    ) -> None:
        """Record that the given layer of the given window is completed.

        Call this only after the layer contents have been written. For a layer with
        several groups, write the contents of all groups first and mark them
        completed afterwards: materialization of a window is skipped as soon as the
        first group (group_idx=0) is marked completed.

        Args:
            group: the window group.
            name: the window name.
            layer_name: the layer name.
            group_idx: the index of the group within the layer.
        """
        raise NotImplementedError
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class WindowStorageFactory(abc.ABC):
    """Abstract, configurable factory for window metadata storage backends.

    A dataset config carries a StorageConfig that configures a WindowStorageFactory;
    the factory then builds the WindowStorage for a given dataset path.
    """

    @abc.abstractmethod
    def get_storage(self, ds_path: UPath) -> WindowStorage:
        """Build the WindowStorage for the dataset located at ``ds_path``."""
        raise NotImplementedError
|
rslearn/dataset/window.py
CHANGED
|
@@ -1,14 +1,79 @@
|
|
|
1
1
|
"""rslearn windows."""
|
|
2
2
|
|
|
3
|
-
import json
|
|
4
3
|
from datetime import datetime
|
|
5
4
|
from typing import Any
|
|
6
5
|
|
|
7
6
|
import shapely
|
|
8
7
|
from upath import UPath
|
|
9
8
|
|
|
9
|
+
from rslearn.dataset.storage.storage import WindowStorage
|
|
10
|
+
from rslearn.log_utils import get_logger
|
|
10
11
|
from rslearn.utils import Projection, STGeometry
|
|
11
|
-
from rslearn.utils.
|
|
12
|
+
from rslearn.utils.raster_format import get_bandset_dirname
|
|
13
|
+
|
|
14
|
+
logger = get_logger(__name__)
|
|
15
|
+
|
|
16
|
+
LAYERS_DIRECTORY_NAME = "layers"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_window_layer_dir(
    window_path: UPath, layer_name: str, group_idx: int = 0
) -> UPath:
    """Locate the directory holding the materialized data of a layer.

    Args:
        window_path: the window directory.
        layer_name: the layer name.
        group_idx: which group within the layer the directory is for (default 0).

    Returns:
        the path where the data is, or should be, materialized.
    """
    # Group 0 uses the bare layer name; any other group gets a ".<group_idx>" suffix.
    folder_name = layer_name if group_idx == 0 else f"{layer_name}.{group_idx}"
    return window_path / LAYERS_DIRECTORY_NAME / folder_name
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def get_layer_and_group_from_dir_name(layer_dir_name: str) -> tuple[str, int]:
    """Split a layer directory name into its layer name and group index.

    A name without a '.' is a layer name with group index 0. A name with exactly
    one '.' is read as "<layer_name>.<group_idx>".

    Args:
        layer_dir_name: the name of the layer folder.

    Returns:
        a tuple (layer_name, group_idx)

    Raises:
        ValueError: if the name contains more than one '.', or if the part after
            the '.' cannot be parsed as an integer.
    """
    pieces = layer_dir_name.split(".")
    if len(pieces) == 1:
        # No separator at all: the whole name is the layer, in its first group.
        return (layer_dir_name, 0)
    if len(pieces) > 2:
        raise ValueError(
            f"expected layer directory name {layer_dir_name} to only contain one '.'"
        )
    layer_name, group_text = pieces
    return (layer_name, int(group_text))
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def get_window_raster_dir(
    window_path: UPath, layer_name: str, bands: list[str], group_idx: int = 0
) -> UPath:
    """Locate the directory in which a raster of a layer is materialized.

    The result is the layer directory (see get_window_layer_dir) extended by the
    directory name that get_bandset_dirname derives from the bands.

    Args:
        window_path: the window directory.
        layer_name: the layer name
        bands: the bands in the raster; they should match a band set defined for
            this layer.
        group_idx: the index of the group within the layer.

    Returns:
        the directory containing the raster.
    """
    bandset_dirname = get_bandset_dirname(bands)
    layer_dir = get_window_layer_dir(window_path, layer_name, group_idx)
    return layer_dir / bandset_dirname
|
|
12
77
|
|
|
13
78
|
|
|
14
79
|
class WindowLayerData:
|
|
@@ -69,7 +134,7 @@ class Window:
|
|
|
69
134
|
|
|
70
135
|
def __init__(
|
|
71
136
|
self,
|
|
72
|
-
|
|
137
|
+
storage: WindowStorage,
|
|
73
138
|
group: str,
|
|
74
139
|
name: str,
|
|
75
140
|
projection: Projection,
|
|
@@ -83,7 +148,7 @@ class Window:
|
|
|
83
148
|
stored in metadata.json.
|
|
84
149
|
|
|
85
150
|
Args:
|
|
86
|
-
|
|
151
|
+
storage: the dataset storage for the underlying rslearn dataset.
|
|
87
152
|
group: the group the window belongs to
|
|
88
153
|
name: the unique name for this window
|
|
89
154
|
projection: the projection of the window
|
|
@@ -91,7 +156,7 @@ class Window:
|
|
|
91
156
|
time_range: optional time range of the window
|
|
92
157
|
options: additional options (?)
|
|
93
158
|
"""
|
|
94
|
-
self.
|
|
159
|
+
self.storage = storage
|
|
95
160
|
self.group = group
|
|
96
161
|
self.name = name
|
|
97
162
|
self.projection = projection
|
|
@@ -99,25 +164,6 @@ class Window:
|
|
|
99
164
|
self.time_range = time_range
|
|
100
165
|
self.options = options
|
|
101
166
|
|
|
102
|
-
def save(self) -> None:
|
|
103
|
-
"""Save the window metadata to its root directory."""
|
|
104
|
-
self.path.mkdir(parents=True, exist_ok=True)
|
|
105
|
-
metadata = {
|
|
106
|
-
"group": self.group,
|
|
107
|
-
"name": self.name,
|
|
108
|
-
"projection": self.projection.serialize(),
|
|
109
|
-
"bounds": self.bounds,
|
|
110
|
-
"time_range": (
|
|
111
|
-
[self.time_range[0].isoformat(), self.time_range[1].isoformat()]
|
|
112
|
-
if self.time_range
|
|
113
|
-
else None
|
|
114
|
-
),
|
|
115
|
-
"options": self.options,
|
|
116
|
-
}
|
|
117
|
-
metadata_path = self.path / "metadata.json"
|
|
118
|
-
with open_atomic(metadata_path, "w") as f:
|
|
119
|
-
json.dump(metadata, f)
|
|
120
|
-
|
|
121
167
|
def get_geometry(self) -> STGeometry:
|
|
122
168
|
"""Computes the STGeometry corresponding to this window."""
|
|
123
169
|
return STGeometry(
|
|
@@ -128,41 +174,132 @@ class Window:
|
|
|
128
174
|
|
|
129
175
|
def load_layer_datas(self) -> dict[str, WindowLayerData]:
|
|
130
176
|
"""Load layer datas describing items in retrieved layers from items.json."""
|
|
131
|
-
|
|
132
|
-
if not items_fname.exists():
|
|
133
|
-
return {}
|
|
134
|
-
with items_fname.open("r") as f:
|
|
135
|
-
layer_datas = [
|
|
136
|
-
WindowLayerData.deserialize(layer_data) for layer_data in json.load(f)
|
|
137
|
-
]
|
|
138
|
-
return {layer_data.layer_name: layer_data for layer_data in layer_datas}
|
|
177
|
+
return self.storage.get_layer_datas(self.group, self.name)
|
|
139
178
|
|
|
140
179
|
def save_layer_datas(self, layer_datas: dict[str, WindowLayerData]) -> None:
|
|
141
180
|
"""Save layer datas to items.json."""
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
181
|
+
self.storage.save_layer_datas(self.group, self.name, layer_datas)
|
|
182
|
+
|
|
183
|
+
def list_completed_layers(self) -> list[tuple[str, int]]:
|
|
184
|
+
"""List the layers available for this window that are completed.
|
|
185
|
+
|
|
186
|
+
Returns:
|
|
187
|
+
a list of (layer_name, group_idx) completed layers.
|
|
188
|
+
"""
|
|
189
|
+
return self.storage.list_completed_layers(self.group, self.name)
|
|
190
|
+
|
|
191
|
+
def get_layer_dir(self, layer_name: str, group_idx: int = 0) -> UPath:
|
|
192
|
+
"""Get the directory containing materialized data for the specified layer.
|
|
193
|
+
|
|
194
|
+
Args:
|
|
195
|
+
layer_name: the layer name.
|
|
196
|
+
group_idx: the index of the group within the layer to get the directory
|
|
197
|
+
for (default 0).
|
|
198
|
+
|
|
199
|
+
Returns:
|
|
200
|
+
the path where data is or should be materialized.
|
|
201
|
+
"""
|
|
202
|
+
return get_window_layer_dir(
|
|
203
|
+
self.storage.get_window_root(self.group, self.name), layer_name, group_idx
|
|
204
|
+
)
|
|
205
|
+
|
|
206
|
+
def is_layer_completed(self, layer_name: str, group_idx: int = 0) -> bool:
|
|
207
|
+
"""Check whether the specified layer is completed.
|
|
208
|
+
|
|
209
|
+
Completed means there is data in the layer and the data has been written
|
|
210
|
+
(materialized).
|
|
211
|
+
|
|
212
|
+
Args:
|
|
213
|
+
layer_name: the layer name.
|
|
214
|
+
group_idx: the index of the group within the layer.
|
|
215
|
+
|
|
216
|
+
Returns:
|
|
217
|
+
whether the layer is completed
|
|
218
|
+
"""
|
|
219
|
+
return self.storage.is_layer_completed(
|
|
220
|
+
self.group, self.name, layer_name, group_idx
|
|
221
|
+
)
|
|
222
|
+
|
|
223
|
+
def mark_layer_completed(self, layer_name: str, group_idx: int = 0) -> None:
|
|
224
|
+
"""Mark the specified layer completed.
|
|
225
|
+
|
|
226
|
+
This must be done after the contents of the layer have been written. If a layer
|
|
227
|
+
has multiple groups, the caller should wait until the contents of all groups
|
|
228
|
+
have been written before marking them completed; this is because, when
|
|
229
|
+
materializing a window, we skip materialization if the first group
|
|
230
|
+
(group_idx=0) is marked completed.
|
|
231
|
+
|
|
232
|
+
Args:
|
|
233
|
+
layer_name: the layer name.
|
|
234
|
+
group_idx: the index of the group within the layer.
|
|
235
|
+
"""
|
|
236
|
+
self.storage.mark_layer_completed(self.group, self.name, layer_name, group_idx)
|
|
237
|
+
|
|
238
|
+
def get_raster_dir(
|
|
239
|
+
self, layer_name: str, bands: list[str], group_idx: int = 0
|
|
240
|
+
) -> UPath:
|
|
241
|
+
"""Get the directory where the raster is materialized.
|
|
242
|
+
|
|
243
|
+
Args:
|
|
244
|
+
layer_name: the layer name
|
|
245
|
+
bands: the bands in the raster. It should match a band set defined for this
|
|
246
|
+
layer.
|
|
247
|
+
group_idx: the index of the group within the layer.
|
|
248
|
+
|
|
249
|
+
Returns:
|
|
250
|
+
the directory containing the raster.
|
|
251
|
+
"""
|
|
252
|
+
return get_window_raster_dir(
|
|
253
|
+
self.storage.get_window_root(self.group, self.name),
|
|
254
|
+
layer_name,
|
|
255
|
+
bands,
|
|
256
|
+
group_idx,
|
|
257
|
+
)
|
|
258
|
+
|
|
259
|
+
def get_metadata(self) -> dict[str, Any]:
|
|
260
|
+
"""Returns the window metadata dictionary."""
|
|
261
|
+
return {
|
|
262
|
+
"group": self.group,
|
|
263
|
+
"name": self.name,
|
|
264
|
+
"projection": self.projection.serialize(),
|
|
265
|
+
"bounds": self.bounds,
|
|
266
|
+
"time_range": (
|
|
267
|
+
[self.time_range[0].isoformat(), self.time_range[1].isoformat()]
|
|
268
|
+
if self.time_range
|
|
269
|
+
else None
|
|
270
|
+
),
|
|
271
|
+
"options": self.options,
|
|
272
|
+
}
|
|
273
|
+
|
|
274
|
+
def save(self) -> None:
|
|
275
|
+
"""Save the window metadata to its root directory."""
|
|
276
|
+
self.storage.create_or_update_window(self)
|
|
146
277
|
|
|
147
278
|
@staticmethod
|
|
148
|
-
def
|
|
149
|
-
"""
|
|
279
|
+
def from_metadata(storage: WindowStorage, metadata: dict[str, Any]) -> "Window":
|
|
280
|
+
"""Create a Window from the WindowStorage and the window's metadata dictionary.
|
|
150
281
|
|
|
151
282
|
Args:
|
|
152
|
-
|
|
283
|
+
storage: the WindowStorage for the underlying dataset.
|
|
284
|
+
metadata: the window metadata.
|
|
153
285
|
|
|
154
286
|
Returns:
|
|
155
287
|
the Window
|
|
156
288
|
"""
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
metadata
|
|
289
|
+
# Ensure bounds is converted from list to tuple.
|
|
290
|
+
bounds = (
|
|
291
|
+
metadata["bounds"][0],
|
|
292
|
+
metadata["bounds"][1],
|
|
293
|
+
metadata["bounds"][2],
|
|
294
|
+
metadata["bounds"][3],
|
|
295
|
+
)
|
|
296
|
+
|
|
160
297
|
return Window(
|
|
161
|
-
|
|
298
|
+
storage=storage,
|
|
162
299
|
group=metadata["group"],
|
|
163
300
|
name=metadata["name"],
|
|
164
301
|
projection=Projection.deserialize(metadata["projection"]),
|
|
165
|
-
bounds=
|
|
302
|
+
bounds=bounds,
|
|
166
303
|
time_range=(
|
|
167
304
|
(
|
|
168
305
|
datetime.fromisoformat(metadata["time_range"][0]),
|