rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +2 -9
- rslearn/config/__init__.py +2 -0
- rslearn/config/dataset.py +64 -20
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +34 -84
- rslearn/dataset/materialize.py +5 -5
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +26 -80
- rslearn/lightning_cli.py +22 -11
- rslearn/main.py +12 -37
- rslearn/models/anysat.py +11 -9
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +8 -9
- rslearn/models/clip.py +18 -15
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +21 -11
- rslearn/models/conv.py +15 -8
- rslearn/models/croma.py +13 -8
- rslearn/models/detr/detr.py +25 -14
- rslearn/models/dinov3.py +11 -6
- rslearn/models/faster_rcnn.py +19 -9
- rslearn/models/feature_center_crop.py +12 -9
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/galileo.py +23 -18
- rslearn/models/module_wrapper.py +26 -57
- rslearn/models/molmo.py +16 -14
- rslearn/models/multitask.py +102 -73
- rslearn/models/olmoearth_pretrain/model.py +135 -38
- rslearn/models/panopticon.py +8 -7
- rslearn/models/pick_features.py +18 -24
- rslearn/models/pooling_decoder.py +22 -14
- rslearn/models/presto/presto.py +16 -10
- rslearn/models/presto/single_file_presto.py +4 -10
- rslearn/models/prithvi.py +12 -8
- rslearn/models/resize_features.py +21 -7
- rslearn/models/sam2_enc.py +11 -9
- rslearn/models/satlaspretrain.py +15 -9
- rslearn/models/simple_time_series.py +37 -17
- rslearn/models/singletask.py +24 -17
- rslearn/models/ssl4eo_s12.py +15 -10
- rslearn/models/swin.py +22 -13
- rslearn/models/terramind.py +24 -7
- rslearn/models/trunk.py +6 -3
- rslearn/models/unet.py +18 -9
- rslearn/models/upsample.py +22 -9
- rslearn/train/all_patches_dataset.py +89 -37
- rslearn/train/dataset.py +105 -97
- rslearn/train/lightning_module.py +51 -32
- rslearn/train/model_context.py +54 -0
- rslearn/train/prediction_writer.py +111 -41
- rslearn/train/scheduler.py +15 -0
- rslearn/train/tasks/classification.py +34 -15
- rslearn/train/tasks/detection.py +24 -31
- rslearn/train/tasks/embedding.py +33 -29
- rslearn/train/tasks/multi_task.py +7 -7
- rslearn/train/tasks/per_pixel_regression.py +41 -19
- rslearn/train/tasks/regression.py +38 -21
- rslearn/train/tasks/segmentation.py +33 -15
- rslearn/train/tasks/task.py +3 -2
- rslearn/train/transforms/resize.py +74 -0
- rslearn/utils/geometry.py +73 -0
- rslearn/utils/jsonargparse.py +66 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
- rslearn/dataset/index.py +0 -173
- rslearn/models/registry.py +0 -22
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
rslearn/arg_parser.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
"""Custom Lightning ArgumentParser with environment variable substitution support."""
|
|
2
2
|
|
|
3
|
-
import os
|
|
4
3
|
from typing import Any
|
|
5
4
|
|
|
6
5
|
from jsonargparse import Namespace
|
|
@@ -21,11 +20,7 @@ class RslearnArgumentParser(LightningArgumentParser):
|
|
|
21
20
|
def parse_string(
|
|
22
21
|
self,
|
|
23
22
|
cfg_str: str,
|
|
24
|
-
|
|
25
|
-
ext_vars: dict | None = None,
|
|
26
|
-
env: bool | None = None,
|
|
27
|
-
defaults: bool = True,
|
|
28
|
-
with_meta: bool | None = None,
|
|
23
|
+
*args: Any,
|
|
29
24
|
**kwargs: Any,
|
|
30
25
|
) -> Namespace:
|
|
31
26
|
"""Pre-processes string for environment variable substitution before parsing."""
|
|
@@ -33,6 +28,4 @@ class RslearnArgumentParser(LightningArgumentParser):
|
|
|
33
28
|
substituted_cfg_str = substitute_env_vars_in_string(cfg_str)
|
|
34
29
|
|
|
35
30
|
# Call the parent method with the substituted config
|
|
36
|
-
return super().parse_string(
|
|
37
|
-
substituted_cfg_str, cfg_path, ext_vars, env, defaults, with_meta, **kwargs
|
|
38
|
-
)
|
|
31
|
+
return super().parse_string(substituted_cfg_str, *args, **kwargs)
|
rslearn/config/__init__.py
CHANGED
rslearn/config/dataset.py
CHANGED
|
@@ -25,12 +25,13 @@ from rasterio.enums import Resampling
|
|
|
25
25
|
from upath import UPath
|
|
26
26
|
|
|
27
27
|
from rslearn.log_utils import get_logger
|
|
28
|
-
from rslearn.utils import PixelBounds, Projection
|
|
28
|
+
from rslearn.utils.geometry import PixelBounds, Projection, ResolutionFactor
|
|
29
29
|
from rslearn.utils.raster_format import RasterFormat
|
|
30
30
|
from rslearn.utils.vector_format import VectorFormat
|
|
31
31
|
|
|
32
32
|
if TYPE_CHECKING:
|
|
33
33
|
from rslearn.data_sources.data_source import DataSource
|
|
34
|
+
from rslearn.dataset.storage.storage import WindowStorageFactory
|
|
34
35
|
|
|
35
36
|
logger = get_logger("__name__")
|
|
36
37
|
|
|
@@ -132,7 +133,11 @@ class BandSetConfig(BaseModel):
|
|
|
132
133
|
bands.
|
|
133
134
|
"""
|
|
134
135
|
|
|
135
|
-
|
|
136
|
+
model_config = ConfigDict(extra="forbid")
|
|
137
|
+
|
|
138
|
+
dtype: DType = Field(
|
|
139
|
+
description="Pixel value type to store the data under. This is used during dataset materialize and model predict."
|
|
140
|
+
)
|
|
136
141
|
bands: list[str] = Field(
|
|
137
142
|
default_factory=lambda: [],
|
|
138
143
|
description="List of band names in this BandSetConfig. One of bands or num_bands must be set.",
|
|
@@ -210,22 +215,12 @@ class BandSetConfig(BaseModel):
|
|
|
210
215
|
Returns:
|
|
211
216
|
tuple of updated projection and bounds with zoom offset applied
|
|
212
217
|
"""
|
|
213
|
-
if self.zoom_offset
|
|
214
|
-
|
|
215
|
-
projection = Projection(
|
|
216
|
-
projection.crs,
|
|
217
|
-
projection.x_resolution / (2**self.zoom_offset),
|
|
218
|
-
projection.y_resolution / (2**self.zoom_offset),
|
|
219
|
-
)
|
|
220
|
-
if self.zoom_offset > 0:
|
|
221
|
-
zoom_factor = 2**self.zoom_offset
|
|
222
|
-
bounds = tuple(x * zoom_factor for x in bounds) # type: ignore
|
|
218
|
+
if self.zoom_offset >= 0:
|
|
219
|
+
factor = ResolutionFactor(numerator=2**self.zoom_offset)
|
|
223
220
|
else:
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
)
|
|
228
|
-
return projection, bounds
|
|
221
|
+
factor = ResolutionFactor(denominator=2 ** (-self.zoom_offset))
|
|
222
|
+
|
|
223
|
+
return (factor.multiply_projection(projection), factor.multiply_bounds(bounds))
|
|
229
224
|
|
|
230
225
|
@field_validator("format", mode="before")
|
|
231
226
|
@classmethod
|
|
@@ -329,7 +324,7 @@ class TimeMode(StrEnum):
|
|
|
329
324
|
class QueryConfig(BaseModel):
|
|
330
325
|
"""A configuration for querying items in a data source."""
|
|
331
326
|
|
|
332
|
-
model_config = ConfigDict(frozen=True)
|
|
327
|
+
model_config = ConfigDict(frozen=True, extra="forbid")
|
|
333
328
|
|
|
334
329
|
space_mode: SpaceMode = Field(
|
|
335
330
|
default=SpaceMode.MOSAIC,
|
|
@@ -363,7 +358,7 @@ class QueryConfig(BaseModel):
|
|
|
363
358
|
class DataSourceConfig(BaseModel):
|
|
364
359
|
"""Configuration for a DataSource in a dataset layer."""
|
|
365
360
|
|
|
366
|
-
model_config = ConfigDict(frozen=True)
|
|
361
|
+
model_config = ConfigDict(frozen=True, extra="forbid")
|
|
367
362
|
|
|
368
363
|
class_path: str = Field(description="Class path for the data source.")
|
|
369
364
|
init_args: dict[str, Any] = Field(
|
|
@@ -469,7 +464,7 @@ class CompositingMethod(StrEnum):
|
|
|
469
464
|
class LayerConfig(BaseModel):
|
|
470
465
|
"""Configuration of a layer in a dataset."""
|
|
471
466
|
|
|
472
|
-
model_config = ConfigDict(frozen=True)
|
|
467
|
+
model_config = ConfigDict(frozen=True, extra="forbid")
|
|
473
468
|
|
|
474
469
|
type: LayerType = Field(description="The LayerType (raster or vector).")
|
|
475
470
|
data_source: DataSourceConfig | None = Field(
|
|
@@ -592,11 +587,60 @@ class LayerConfig(BaseModel):
|
|
|
592
587
|
return vector_format
|
|
593
588
|
|
|
594
589
|
|
|
590
|
+
class StorageConfig(BaseModel):
|
|
591
|
+
"""Configuration for the WindowStorageFactory (window metadata storage backend)."""
|
|
592
|
+
|
|
593
|
+
model_config = ConfigDict(frozen=True, extra="forbid")
|
|
594
|
+
|
|
595
|
+
class_path: str = Field(
|
|
596
|
+
default="rslearn.dataset.storage.file.FileWindowStorageFactory",
|
|
597
|
+
description="Class path for the WindowStorageFactory.",
|
|
598
|
+
)
|
|
599
|
+
init_args: dict[str, Any] = Field(
|
|
600
|
+
default_factory=lambda: {},
|
|
601
|
+
description="jsonargparse init args for the WindowStorageFactory.",
|
|
602
|
+
)
|
|
603
|
+
|
|
604
|
+
def instantiate_window_storage_factory(self) -> "WindowStorageFactory":
|
|
605
|
+
"""Instantiate the WindowStorageFactory specified by this config."""
|
|
606
|
+
from rslearn.dataset.storage.storage import WindowStorageFactory
|
|
607
|
+
from rslearn.utils.jsonargparse import init_jsonargparse
|
|
608
|
+
|
|
609
|
+
init_jsonargparse()
|
|
610
|
+
parser = jsonargparse.ArgumentParser()
|
|
611
|
+
parser.add_argument("--wsf", type=WindowStorageFactory)
|
|
612
|
+
cfg = parser.parse_object(
|
|
613
|
+
{
|
|
614
|
+
"wsf": dict(
|
|
615
|
+
class_path=self.class_path,
|
|
616
|
+
init_args=self.init_args,
|
|
617
|
+
)
|
|
618
|
+
}
|
|
619
|
+
)
|
|
620
|
+
wsf = parser.instantiate_classes(cfg).wsf
|
|
621
|
+
return wsf
|
|
622
|
+
|
|
623
|
+
|
|
595
624
|
class DatasetConfig(BaseModel):
|
|
596
625
|
"""Overall dataset configuration."""
|
|
597
626
|
|
|
627
|
+
model_config = ConfigDict(extra="forbid")
|
|
628
|
+
|
|
598
629
|
layers: dict[str, LayerConfig] = Field(description="Layers in the dataset.")
|
|
599
630
|
tile_store: dict[str, Any] = Field(
|
|
600
631
|
default={"class_path": "rslearn.tile_stores.default.DefaultTileStore"},
|
|
601
632
|
description="jsonargparse configuration for the TileStore.",
|
|
602
633
|
)
|
|
634
|
+
storage: StorageConfig = Field(
|
|
635
|
+
default_factory=lambda: StorageConfig(),
|
|
636
|
+
description="jsonargparse configuration for the WindowStorageFactory.",
|
|
637
|
+
)
|
|
638
|
+
|
|
639
|
+
@field_validator("layers", mode="after")
|
|
640
|
+
@classmethod
|
|
641
|
+
def layer_names_validator(cls, v: dict[str, LayerConfig]) -> dict[str, LayerConfig]:
|
|
642
|
+
"""Ensure layer names don't contain periods, since we use periods to distinguish different materialized groups within a layer."""
|
|
643
|
+
for layer_name in v.keys():
|
|
644
|
+
if "." in layer_name:
|
|
645
|
+
raise ValueError(f"layer names must not contain periods: {layer_name}")
|
|
646
|
+
return v
|
rslearn/dataset/add_windows.py
CHANGED
|
@@ -131,7 +131,7 @@ def add_windows_from_geometries(
|
|
|
131
131
|
f"_{time_range[0].isoformat()}_{time_range[1].isoformat()}"
|
|
132
132
|
)
|
|
133
133
|
window = Window(
|
|
134
|
-
|
|
134
|
+
storage=dataset.storage,
|
|
135
135
|
group=group,
|
|
136
136
|
name=cur_window_name,
|
|
137
137
|
projection=cur_projection,
|
rslearn/dataset/dataset.py
CHANGED
|
@@ -1,9 +1,8 @@
|
|
|
1
1
|
"""rslearn dataset class."""
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
-
import
|
|
4
|
+
from typing import Any
|
|
5
5
|
|
|
6
|
-
import tqdm
|
|
7
6
|
from upath import UPath
|
|
8
7
|
|
|
9
8
|
from rslearn.config import DatasetConfig
|
|
@@ -11,7 +10,6 @@ from rslearn.log_utils import get_logger
|
|
|
11
10
|
from rslearn.template_params import substitute_env_vars_in_string
|
|
12
11
|
from rslearn.tile_stores import TileStore, load_tile_store
|
|
13
12
|
|
|
14
|
-
from .index import DatasetIndex
|
|
15
13
|
from .window import Window
|
|
16
14
|
|
|
17
15
|
logger = get_logger(__name__)
|
|
@@ -25,7 +23,7 @@ class Dataset:
|
|
|
25
23
|
.. code-block:: none
|
|
26
24
|
|
|
27
25
|
dataset/
|
|
28
|
-
config.json
|
|
26
|
+
config.json # optional, if config provided as runtime object
|
|
29
27
|
windows/
|
|
30
28
|
group1/
|
|
31
29
|
epsg:3857_10_623565_1528020/
|
|
@@ -42,106 +40,58 @@ class Dataset:
|
|
|
42
40
|
materialize.
|
|
43
41
|
"""
|
|
44
42
|
|
|
45
|
-
def __init__(
|
|
43
|
+
def __init__(
|
|
44
|
+
self,
|
|
45
|
+
path: UPath,
|
|
46
|
+
disabled_layers: list[str] = [],
|
|
47
|
+
dataset_config: DatasetConfig | None = None,
|
|
48
|
+
) -> None:
|
|
46
49
|
"""Initializes a new Dataset.
|
|
47
50
|
|
|
48
51
|
Args:
|
|
49
52
|
path: the root directory of the dataset
|
|
50
53
|
disabled_layers: list of layers to disable
|
|
54
|
+
dataset_config: optional dataset configuration to use instead of loading from the dataset directory
|
|
51
55
|
"""
|
|
52
56
|
self.path = path
|
|
53
57
|
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
return DatasetIndex.load_index(self.path)
|
|
58
|
+
if dataset_config is None:
|
|
59
|
+
# Load dataset configuration from the dataset directory.
|
|
60
|
+
with (self.path / "config.json").open("r") as f:
|
|
61
|
+
config_content = f.read()
|
|
62
|
+
config_content = substitute_env_vars_in_string(config_content)
|
|
63
|
+
dataset_config = DatasetConfig.model_validate(
|
|
64
|
+
json.loads(config_content)
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
self.layers = {}
|
|
68
|
+
for layer_name, layer_config in dataset_config.layers.items():
|
|
69
|
+
if layer_name in disabled_layers:
|
|
70
|
+
logger.warning(f"Layer {layer_name} is disabled")
|
|
71
|
+
continue
|
|
72
|
+
self.layers[layer_name] = layer_config
|
|
73
|
+
|
|
74
|
+
self.tile_store_config = dataset_config.tile_store
|
|
75
|
+
self.storage = (
|
|
76
|
+
dataset_config.storage.instantiate_window_storage_factory().get_storage(
|
|
77
|
+
self.path
|
|
78
|
+
)
|
|
79
|
+
)
|
|
77
80
|
|
|
78
81
|
def load_windows(
|
|
79
82
|
self,
|
|
80
83
|
groups: list[str] | None = None,
|
|
81
84
|
names: list[str] | None = None,
|
|
82
|
-
|
|
83
|
-
workers: int = 0,
|
|
84
|
-
no_index: bool = False,
|
|
85
|
+
**kwargs: Any,
|
|
85
86
|
) -> list[Window]:
|
|
86
87
|
"""Load the windows in the dataset.
|
|
87
88
|
|
|
88
89
|
Args:
|
|
89
90
|
groups: an optional list of groups to filter loading
|
|
90
91
|
names: an optional list of window names to filter loading
|
|
91
|
-
|
|
92
|
-
workers: number of parallel workers, default 0 (use main thread only to load windows)
|
|
93
|
-
no_index: don't use the dataset index even if it exists.
|
|
92
|
+
kwargs: optional keyword arguments to pass to WindowStorage.get_windows.
|
|
94
93
|
"""
|
|
95
|
-
|
|
96
|
-
# We never use the index if names is set since loading the index will likely be
|
|
97
|
-
# slower than loading a few windows.
|
|
98
|
-
if not no_index and names is None:
|
|
99
|
-
dataset_index = self._get_index()
|
|
100
|
-
if dataset_index is not None:
|
|
101
|
-
return dataset_index.get_windows(groups=groups, names=names)
|
|
102
|
-
|
|
103
|
-
# Avoid directory does not exist errors later.
|
|
104
|
-
if not (self.path / "windows").exists():
|
|
105
|
-
return []
|
|
106
|
-
|
|
107
|
-
window_dirs = []
|
|
108
|
-
if not groups:
|
|
109
|
-
groups = []
|
|
110
|
-
for p in (self.path / "windows").iterdir():
|
|
111
|
-
groups.append(p.name)
|
|
112
|
-
for group in groups:
|
|
113
|
-
group_dir = self.path / "windows" / group
|
|
114
|
-
if not group_dir.exists():
|
|
115
|
-
logger.warning(
|
|
116
|
-
f"Skipping group directory {group_dir} since it does not exist"
|
|
117
|
-
)
|
|
118
|
-
continue
|
|
119
|
-
if names:
|
|
120
|
-
cur_names = names
|
|
121
|
-
else:
|
|
122
|
-
cur_names = []
|
|
123
|
-
for p in group_dir.iterdir():
|
|
124
|
-
cur_names.append(p.name)
|
|
125
|
-
|
|
126
|
-
for window_name in cur_names:
|
|
127
|
-
window_dir = group_dir / window_name
|
|
128
|
-
window_dirs.append(window_dir)
|
|
129
|
-
|
|
130
|
-
if workers == 0:
|
|
131
|
-
windows = [Window.load(window_dir) for window_dir in window_dirs]
|
|
132
|
-
else:
|
|
133
|
-
p = multiprocessing.Pool(workers)
|
|
134
|
-
outputs = p.imap_unordered(Window.load, window_dirs)
|
|
135
|
-
if show_progress:
|
|
136
|
-
outputs = tqdm.tqdm(
|
|
137
|
-
outputs, total=len(window_dirs), desc="Loading windows"
|
|
138
|
-
)
|
|
139
|
-
windows = []
|
|
140
|
-
for window in outputs:
|
|
141
|
-
windows.append(window)
|
|
142
|
-
p.close()
|
|
143
|
-
|
|
144
|
-
return windows
|
|
94
|
+
return self.storage.get_windows(groups=groups, names=names, **kwargs)
|
|
145
95
|
|
|
146
96
|
def get_tile_store(self) -> TileStore:
|
|
147
97
|
"""Get the tile store associated with this dataset.
|
rslearn/dataset/materialize.py
CHANGED
|
@@ -161,7 +161,7 @@ def build_first_valid_composite(
|
|
|
161
161
|
nodata_vals: list[Any],
|
|
162
162
|
bands: list[str],
|
|
163
163
|
bounds: PixelBounds,
|
|
164
|
-
band_dtype:
|
|
164
|
+
band_dtype: npt.DTypeLike,
|
|
165
165
|
tile_store: TileStoreWithLayer,
|
|
166
166
|
projection: Projection,
|
|
167
167
|
remapper: Remapper | None,
|
|
@@ -233,7 +233,7 @@ def read_and_stack_raster_windows(
|
|
|
233
233
|
projection: Projection,
|
|
234
234
|
nodata_vals: list[Any],
|
|
235
235
|
remapper: Remapper | None,
|
|
236
|
-
band_dtype:
|
|
236
|
+
band_dtype: npt.DTypeLike,
|
|
237
237
|
resampling_method: Resampling = Resampling.bilinear,
|
|
238
238
|
) -> npt.NDArray[np.generic]:
|
|
239
239
|
"""Create a stack of extent aligned raster windows.
|
|
@@ -326,7 +326,7 @@ def build_mean_composite(
|
|
|
326
326
|
nodata_vals: list[Any],
|
|
327
327
|
bands: list[str],
|
|
328
328
|
bounds: PixelBounds,
|
|
329
|
-
band_dtype:
|
|
329
|
+
band_dtype: npt.DTypeLike,
|
|
330
330
|
tile_store: TileStoreWithLayer,
|
|
331
331
|
projection: Projection,
|
|
332
332
|
remapper: Remapper | None,
|
|
@@ -383,7 +383,7 @@ def build_median_composite(
|
|
|
383
383
|
nodata_vals: list[Any],
|
|
384
384
|
bands: list[str],
|
|
385
385
|
bounds: PixelBounds,
|
|
386
|
-
band_dtype:
|
|
386
|
+
band_dtype: npt.DTypeLike,
|
|
387
387
|
tile_store: TileStoreWithLayer,
|
|
388
388
|
projection: Projection,
|
|
389
389
|
remapper: Remapper | None,
|
|
@@ -471,7 +471,7 @@ def build_composite(
|
|
|
471
471
|
nodata_vals=nodata_vals,
|
|
472
472
|
bands=band_cfg.bands,
|
|
473
473
|
bounds=bounds,
|
|
474
|
-
band_dtype=band_cfg.dtype.
|
|
474
|
+
band_dtype=band_cfg.dtype.get_numpy_dtype(),
|
|
475
475
|
tile_store=tile_store,
|
|
476
476
|
projection=projection,
|
|
477
477
|
resampling_method=layer_cfg.resampling_method.get_rasterio_resampling(),
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Storage backends for rslearn window metadata."""
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
"""The default file-based window storage backend."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import multiprocessing
|
|
5
|
+
|
|
6
|
+
import tqdm
|
|
7
|
+
from typing_extensions import override
|
|
8
|
+
from upath import UPath
|
|
9
|
+
|
|
10
|
+
from rslearn.dataset.window import (
|
|
11
|
+
LAYERS_DIRECTORY_NAME,
|
|
12
|
+
Window,
|
|
13
|
+
WindowLayerData,
|
|
14
|
+
get_layer_and_group_from_dir_name,
|
|
15
|
+
get_window_layer_dir,
|
|
16
|
+
)
|
|
17
|
+
from rslearn.log_utils import get_logger
|
|
18
|
+
from rslearn.utils.fsspec import open_atomic
|
|
19
|
+
from rslearn.utils.mp import star_imap_unordered
|
|
20
|
+
|
|
21
|
+
from .storage import WindowStorage, WindowStorageFactory
|
|
22
|
+
|
|
23
|
+
logger = get_logger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def load_window(storage: "FileWindowStorage", window_dir: UPath) -> Window:
|
|
27
|
+
"""Load the window from its directory by reading metadata.json.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
storage: the underlying FileWindowStorage.
|
|
31
|
+
window_dir: the path where the window is stored.
|
|
32
|
+
|
|
33
|
+
Returns:
|
|
34
|
+
the window object.
|
|
35
|
+
"""
|
|
36
|
+
metadata_fname = window_dir / "metadata.json"
|
|
37
|
+
with metadata_fname.open() as f:
|
|
38
|
+
metadata = json.load(f)
|
|
39
|
+
return Window.from_metadata(storage, metadata)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class FileWindowStorage(WindowStorage):
|
|
43
|
+
"""The default file-backed window storage."""
|
|
44
|
+
|
|
45
|
+
def __init__(self, path: UPath):
|
|
46
|
+
"""Create a new FileWindowStorage.
|
|
47
|
+
|
|
48
|
+
Args:
|
|
49
|
+
path: the path to the dataset.
|
|
50
|
+
"""
|
|
51
|
+
self.path = path
|
|
52
|
+
|
|
53
|
+
@override
|
|
54
|
+
def get_window_root(self, group: str, name: str) -> UPath:
|
|
55
|
+
return Window.get_window_root(self.path, group, name)
|
|
56
|
+
|
|
57
|
+
@override
|
|
58
|
+
def get_windows(
|
|
59
|
+
self,
|
|
60
|
+
groups: list[str] | None = None,
|
|
61
|
+
names: list[str] | None = None,
|
|
62
|
+
show_progress: bool = False,
|
|
63
|
+
workers: int = 0,
|
|
64
|
+
) -> list["Window"]:
|
|
65
|
+
"""Load the windows in the dataset.
|
|
66
|
+
|
|
67
|
+
Args:
|
|
68
|
+
groups: an optional list of groups to filter loading
|
|
69
|
+
names: an optional list of window names to filter loading
|
|
70
|
+
show_progress: whether to show tqdm progress bar
|
|
71
|
+
workers: number of parallel workers, default 0 (use main thread only to load windows)
|
|
72
|
+
"""
|
|
73
|
+
# Avoid directory does not exist errors later.
|
|
74
|
+
if not (self.path / "windows").exists():
|
|
75
|
+
return []
|
|
76
|
+
|
|
77
|
+
window_dirs = []
|
|
78
|
+
if not groups:
|
|
79
|
+
groups = []
|
|
80
|
+
for p in (self.path / "windows").iterdir():
|
|
81
|
+
groups.append(p.name)
|
|
82
|
+
for group in groups:
|
|
83
|
+
group_dir = self.path / "windows" / group
|
|
84
|
+
if not group_dir.exists():
|
|
85
|
+
logger.warning(
|
|
86
|
+
f"Skipping group directory {group_dir} since it does not exist"
|
|
87
|
+
)
|
|
88
|
+
continue
|
|
89
|
+
if names:
|
|
90
|
+
cur_names = names
|
|
91
|
+
else:
|
|
92
|
+
cur_names = []
|
|
93
|
+
for p in group_dir.iterdir():
|
|
94
|
+
cur_names.append(p.name)
|
|
95
|
+
|
|
96
|
+
for window_name in cur_names:
|
|
97
|
+
window_dir = group_dir / window_name
|
|
98
|
+
window_dirs.append(window_dir)
|
|
99
|
+
|
|
100
|
+
if workers == 0:
|
|
101
|
+
windows = [load_window(self, window_dir) for window_dir in window_dirs]
|
|
102
|
+
else:
|
|
103
|
+
p = multiprocessing.Pool(workers)
|
|
104
|
+
outputs = star_imap_unordered(
|
|
105
|
+
p,
|
|
106
|
+
load_window,
|
|
107
|
+
[
|
|
108
|
+
dict(storage=self, window_dir=window_dir)
|
|
109
|
+
for window_dir in window_dirs
|
|
110
|
+
],
|
|
111
|
+
)
|
|
112
|
+
if show_progress:
|
|
113
|
+
outputs = tqdm.tqdm(
|
|
114
|
+
outputs, total=len(window_dirs), desc="Loading windows"
|
|
115
|
+
)
|
|
116
|
+
windows = []
|
|
117
|
+
for window in outputs:
|
|
118
|
+
windows.append(window)
|
|
119
|
+
p.close()
|
|
120
|
+
|
|
121
|
+
return windows
|
|
122
|
+
|
|
123
|
+
@override
|
|
124
|
+
def create_or_update_window(self, window: Window) -> None:
|
|
125
|
+
window_path = self.get_window_root(window.group, window.name)
|
|
126
|
+
window_path.mkdir(parents=True, exist_ok=True)
|
|
127
|
+
metadata_path = window_path / "metadata.json"
|
|
128
|
+
logger.debug(f"Saving window metadata to {metadata_path}")
|
|
129
|
+
with open_atomic(metadata_path, "w") as f:
|
|
130
|
+
json.dump(window.get_metadata(), f)
|
|
131
|
+
|
|
132
|
+
@override
|
|
133
|
+
def get_layer_datas(self, group: str, name: str) -> dict[str, "WindowLayerData"]:
|
|
134
|
+
window_path = self.get_window_root(group, name)
|
|
135
|
+
items_fname = window_path / "items.json"
|
|
136
|
+
if not items_fname.exists():
|
|
137
|
+
return {}
|
|
138
|
+
|
|
139
|
+
with items_fname.open() as f:
|
|
140
|
+
layer_datas = [
|
|
141
|
+
WindowLayerData.deserialize(layer_data) for layer_data in json.load(f)
|
|
142
|
+
]
|
|
143
|
+
|
|
144
|
+
return {layer_data.layer_name: layer_data for layer_data in layer_datas}
|
|
145
|
+
|
|
146
|
+
@override
|
|
147
|
+
def save_layer_datas(
|
|
148
|
+
self, group: str, name: str, layer_datas: dict[str, "WindowLayerData"]
|
|
149
|
+
) -> None:
|
|
150
|
+
window_path = self.get_window_root(group, name)
|
|
151
|
+
json_data = [layer_data.serialize() for layer_data in layer_datas.values()]
|
|
152
|
+
items_fname = window_path / "items.json"
|
|
153
|
+
logger.info(f"Saving window items to {items_fname}")
|
|
154
|
+
with open_atomic(items_fname, "w") as f:
|
|
155
|
+
json.dump(json_data, f)
|
|
156
|
+
|
|
157
|
+
@override
|
|
158
|
+
def list_completed_layers(self, group: str, name: str) -> list[tuple[str, int]]:
|
|
159
|
+
window_path = self.get_window_root(group, name)
|
|
160
|
+
layers_directory = window_path / LAYERS_DIRECTORY_NAME
|
|
161
|
+
if not layers_directory.exists():
|
|
162
|
+
return []
|
|
163
|
+
|
|
164
|
+
completed_layers = []
|
|
165
|
+
for layer_dir in layers_directory.iterdir():
|
|
166
|
+
layer_name, group_idx = get_layer_and_group_from_dir_name(layer_dir.name)
|
|
167
|
+
if not self.is_layer_completed(group, name, layer_name, group_idx):
|
|
168
|
+
continue
|
|
169
|
+
completed_layers.append((layer_name, group_idx))
|
|
170
|
+
|
|
171
|
+
return completed_layers
|
|
172
|
+
|
|
173
|
+
@override
|
|
174
|
+
def is_layer_completed(
|
|
175
|
+
self, group: str, name: str, layer_name: str, group_idx: int = 0
|
|
176
|
+
) -> bool:
|
|
177
|
+
window_path = self.get_window_root(group, name)
|
|
178
|
+
layer_dir = get_window_layer_dir(
|
|
179
|
+
window_path,
|
|
180
|
+
layer_name,
|
|
181
|
+
group_idx,
|
|
182
|
+
)
|
|
183
|
+
return (layer_dir / "completed").exists()
|
|
184
|
+
|
|
185
|
+
@override
|
|
186
|
+
def mark_layer_completed(
|
|
187
|
+
self, group: str, name: str, layer_name: str, group_idx: int = 0
|
|
188
|
+
) -> None:
|
|
189
|
+
window_path = self.get_window_root(group, name)
|
|
190
|
+
layer_dir = get_window_layer_dir(window_path, layer_name, group_idx)
|
|
191
|
+
# We assume the directory exists because the layer should be materialized before
|
|
192
|
+
# being marked completed.
|
|
193
|
+
(layer_dir / "completed").touch()
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class FileWindowStorageFactory(WindowStorageFactory):
|
|
197
|
+
"""Factory class for FileWindowStorage."""
|
|
198
|
+
|
|
199
|
+
@override
|
|
200
|
+
def get_storage(self, ds_path: UPath) -> FileWindowStorage:
|
|
201
|
+
"""Get a FileWindowStorage for the given dataset path."""
|
|
202
|
+
return FileWindowStorage(ds_path)
|