rslearn 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -10
- rslearn/config/dataset.py +414 -420
- rslearn/data_sources/__init__.py +8 -31
- rslearn/data_sources/aws_landsat.py +13 -24
- rslearn/data_sources/aws_open_data.py +21 -46
- rslearn/data_sources/aws_sentinel1.py +3 -14
- rslearn/data_sources/climate_data_store.py +21 -40
- rslearn/data_sources/copernicus.py +30 -91
- rslearn/data_sources/data_source.py +26 -0
- rslearn/data_sources/earthdaily.py +13 -38
- rslearn/data_sources/earthdata_srtm.py +14 -32
- rslearn/data_sources/eurocrops.py +5 -9
- rslearn/data_sources/gcp_public_data.py +46 -43
- rslearn/data_sources/google_earth_engine.py +31 -44
- rslearn/data_sources/local_files.py +91 -100
- rslearn/data_sources/openstreetmap.py +21 -51
- rslearn/data_sources/planet.py +12 -30
- rslearn/data_sources/planet_basemap.py +4 -25
- rslearn/data_sources/planetary_computer.py +58 -141
- rslearn/data_sources/usda_cdl.py +15 -26
- rslearn/data_sources/usgs_landsat.py +4 -29
- rslearn/data_sources/utils.py +9 -0
- rslearn/data_sources/worldcereal.py +47 -54
- rslearn/data_sources/worldcover.py +16 -14
- rslearn/data_sources/worldpop.py +15 -18
- rslearn/data_sources/xyz_tiles.py +11 -30
- rslearn/dataset/dataset.py +6 -6
- rslearn/dataset/manage.py +28 -26
- rslearn/dataset/materialize.py +9 -45
- rslearn/lightning_cli.py +370 -1
- rslearn/main.py +3 -3
- rslearn/models/clay/clay.py +14 -1
- rslearn/models/concatenate_features.py +93 -0
- rslearn/models/croma.py +26 -3
- rslearn/models/satlaspretrain.py +18 -4
- rslearn/models/terramind.py +19 -0
- rslearn/tile_stores/__init__.py +0 -11
- rslearn/train/dataset.py +4 -12
- rslearn/train/prediction_writer.py +16 -32
- rslearn/train/tasks/classification.py +2 -1
- rslearn/utils/fsspec.py +20 -0
- rslearn/utils/jsonargparse.py +79 -0
- rslearn/utils/raster_format.py +1 -41
- rslearn/utils/vector_format.py +1 -38
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/METADATA +1 -1
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/RECORD +51 -52
- rslearn/data_sources/geotiff.py +0 -1
- rslearn/data_sources/raster_source.py +0 -23
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/WHEEL +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/top_level.txt +0 -0
|
@@ -9,7 +9,8 @@ import requests
|
|
|
9
9
|
from fsspec.implementations.local import LocalFileSystem
|
|
10
10
|
from upath import UPath
|
|
11
11
|
|
|
12
|
-
from rslearn.config import
|
|
12
|
+
from rslearn.config import LayerType
|
|
13
|
+
from rslearn.data_sources import DataSourceContext
|
|
13
14
|
from rslearn.data_sources.local_files import LocalFiles
|
|
14
15
|
from rslearn.log_utils import get_logger
|
|
15
16
|
from rslearn.utils.fsspec import get_upath_local, join_upath, open_atomic
|
|
@@ -52,8 +53,8 @@ class WorldCover(LocalFiles):
|
|
|
52
53
|
|
|
53
54
|
def __init__(
|
|
54
55
|
self,
|
|
55
|
-
|
|
56
|
-
|
|
56
|
+
worldcover_dir: str,
|
|
57
|
+
context: DataSourceContext = DataSourceContext(),
|
|
57
58
|
) -> None:
|
|
58
59
|
"""Create a new WorldCover.
|
|
59
60
|
|
|
@@ -64,18 +65,19 @@ class WorldCover(LocalFiles):
|
|
|
64
65
|
high performance, this should be a local directory; if the dataset is
|
|
65
66
|
remote, prefix with a protocol ("file://") to use a local directory
|
|
66
67
|
instead of a path relative to the dataset path.
|
|
68
|
+
context: the data source context.
|
|
67
69
|
"""
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
70
|
+
if context.ds_path is not None:
|
|
71
|
+
worldcover_upath = join_upath(context.ds_path, worldcover_dir)
|
|
72
|
+
else:
|
|
73
|
+
worldcover_upath = UPath(worldcover_dir)
|
|
74
|
+
|
|
75
|
+
tif_dir = self.download_worldcover_data(worldcover_upath)
|
|
76
|
+
|
|
77
|
+
super().__init__(
|
|
78
|
+
src_dir=tif_dir,
|
|
79
|
+
layer_type=LayerType.RASTER,
|
|
80
|
+
context=context,
|
|
79
81
|
)
|
|
80
82
|
|
|
81
83
|
def download_worldcover_data(self, worldcover_dir: UPath) -> UPath:
|
rslearn/data_sources/worldpop.py
CHANGED
|
@@ -6,10 +6,10 @@ from html.parser import HTMLParser
|
|
|
6
6
|
from urllib.parse import urljoin
|
|
7
7
|
|
|
8
8
|
import requests
|
|
9
|
-
import requests.auth
|
|
10
9
|
from upath import UPath
|
|
11
10
|
|
|
12
|
-
from rslearn.config import
|
|
11
|
+
from rslearn.config import LayerType
|
|
12
|
+
from rslearn.data_sources import DataSourceContext
|
|
13
13
|
from rslearn.data_sources.local_files import LocalFiles
|
|
14
14
|
from rslearn.log_utils import get_logger
|
|
15
15
|
from rslearn.utils.fsspec import join_upath, open_atomic
|
|
@@ -59,33 +59,30 @@ class WorldPop(LocalFiles):
|
|
|
59
59
|
|
|
60
60
|
def __init__(
|
|
61
61
|
self,
|
|
62
|
-
|
|
63
|
-
worldpop_dir: UPath,
|
|
62
|
+
worldpop_dir: str,
|
|
64
63
|
timeout: timedelta = timedelta(seconds=30),
|
|
64
|
+
context: DataSourceContext = DataSourceContext(),
|
|
65
65
|
):
|
|
66
66
|
"""Create a new WorldPop.
|
|
67
67
|
|
|
68
68
|
Args:
|
|
69
|
-
config: configuration for this layer. It should specify a single band
|
|
70
|
-
called B1 which will contain the population counts.
|
|
71
69
|
worldpop_dir: the directory to extract the WorldPop GeoTIFF files. For
|
|
72
70
|
high performance, this should be a local directory; if the dataset is
|
|
73
71
|
remote, prefix with a protocol ("file://") to use a local directory
|
|
74
72
|
instead of a path relative to the dataset path.
|
|
75
73
|
timeout: timeout for HTTP requests.
|
|
74
|
+
context: the data source context.
|
|
76
75
|
"""
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
return WorldPop(
|
|
88
|
-
config=config, worldpop_dir=join_upath(ds_path, d["worldpop_dir"])
|
|
76
|
+
if context.ds_path is not None:
|
|
77
|
+
worldpop_upath = join_upath(context.ds_path, worldpop_dir)
|
|
78
|
+
else:
|
|
79
|
+
worldpop_upath = UPath(worldpop_dir)
|
|
80
|
+
worldpop_upath.mkdir(parents=True, exist_ok=True)
|
|
81
|
+
self.download_worldpop_data(worldpop_upath, timeout)
|
|
82
|
+
super().__init__(
|
|
83
|
+
src_dir=worldpop_upath,
|
|
84
|
+
layer_type=LayerType.RASTER,
|
|
85
|
+
context=context,
|
|
89
86
|
)
|
|
90
87
|
|
|
91
88
|
def download_worldpop_data(self, worldpop_dir: UPath, timeout: timedelta) -> None:
|
|
@@ -14,9 +14,8 @@ import shapely
|
|
|
14
14
|
from PIL import Image
|
|
15
15
|
from rasterio.crs import CRS
|
|
16
16
|
from rasterio.enums import Resampling
|
|
17
|
-
from upath import UPath
|
|
18
17
|
|
|
19
|
-
from rslearn.config import LayerConfig, QueryConfig
|
|
18
|
+
from rslearn.config import LayerConfig, QueryConfig
|
|
20
19
|
from rslearn.dataset import Window
|
|
21
20
|
from rslearn.dataset.materialize import RasterMaterializer
|
|
22
21
|
from rslearn.tile_stores import TileStore, TileStoreWithLayer
|
|
@@ -24,7 +23,7 @@ from rslearn.utils import PixelBounds, Projection, STGeometry
|
|
|
24
23
|
from rslearn.utils.array import copy_spatial_array
|
|
25
24
|
from rslearn.utils.raster_format import get_transform_from_projection_and_bounds
|
|
26
25
|
|
|
27
|
-
from .data_source import DataSource, Item
|
|
26
|
+
from .data_source import DataSource, DataSourceContext, Item
|
|
28
27
|
from .utils import match_candidate_items_to_window
|
|
29
28
|
|
|
30
29
|
WEB_MERCATOR_EPSG = 3857
|
|
@@ -96,11 +95,12 @@ class XyzTiles(DataSource, TileStore):
|
|
|
96
95
|
url_templates: list[str],
|
|
97
96
|
time_ranges: list[tuple[datetime, datetime]],
|
|
98
97
|
zoom: int,
|
|
99
|
-
crs: CRS = CRS.from_epsg(WEB_MERCATOR_EPSG),
|
|
98
|
+
crs: str | CRS = CRS.from_epsg(WEB_MERCATOR_EPSG),
|
|
100
99
|
total_units: float = WEB_MERCATOR_UNITS,
|
|
101
100
|
offset: float = WEB_MERCATOR_UNITS / 2,
|
|
102
101
|
tile_size: int = 256,
|
|
103
102
|
band_names: list[str] = ["R", "G", "B"],
|
|
103
|
+
context: DataSourceContext = DataSourceContext(),
|
|
104
104
|
):
|
|
105
105
|
"""Initialize an XyzTiles instance.
|
|
106
106
|
|
|
@@ -121,16 +121,22 @@ class XyzTiles(DataSource, TileStore):
|
|
|
121
121
|
offset: offset added to projection units when converting to tile positions.
|
|
122
122
|
tile_size: size in pixels of each tile. Tiles must be square.
|
|
123
123
|
band_names: what to name the bands that we read.
|
|
124
|
+
context: the data source context.
|
|
124
125
|
"""
|
|
125
126
|
self.url_templates = url_templates
|
|
126
127
|
self.time_ranges = time_ranges
|
|
127
128
|
self.zoom = zoom
|
|
128
|
-
self.crs = crs
|
|
129
129
|
self.total_units = total_units
|
|
130
130
|
self.offset = offset
|
|
131
131
|
self.tile_size = tile_size
|
|
132
132
|
self.band_names = band_names
|
|
133
133
|
|
|
134
|
+
# Convert to CRS if needed.
|
|
135
|
+
if isinstance(crs, str):
|
|
136
|
+
self.crs = CRS.from_string(crs)
|
|
137
|
+
else:
|
|
138
|
+
self.crs = crs
|
|
139
|
+
|
|
134
140
|
# Compute total number of pixels (a function of the zoom level and tile size).
|
|
135
141
|
self.total_pixels = tile_size * (2**zoom)
|
|
136
142
|
# Compute pixel size (resolution).
|
|
@@ -153,30 +159,6 @@ class XyzTiles(DataSource, TileStore):
|
|
|
153
159
|
item = Item(url_template, geometry)
|
|
154
160
|
self.items.append(item)
|
|
155
161
|
|
|
156
|
-
@staticmethod
|
|
157
|
-
def from_config(config: LayerConfig, ds_path: UPath) -> "XyzTiles":
|
|
158
|
-
"""Creates a new XyzTiles instance from a configuration dictionary."""
|
|
159
|
-
if config.data_source is None:
|
|
160
|
-
raise ValueError("data_source is required")
|
|
161
|
-
d = config.data_source.config_dict
|
|
162
|
-
time_ranges = []
|
|
163
|
-
for str1, str2 in d["time_ranges"]:
|
|
164
|
-
time1 = datetime.fromisoformat(str1)
|
|
165
|
-
time2 = datetime.fromisoformat(str2)
|
|
166
|
-
time_ranges.append((time1, time2))
|
|
167
|
-
kwargs = dict(
|
|
168
|
-
url_templates=d["url_templates"], zoom=d["zoom"], time_ranges=time_ranges
|
|
169
|
-
)
|
|
170
|
-
if "crs" in d:
|
|
171
|
-
kwargs["crs"] = CRS.from_string(d["crs"])
|
|
172
|
-
if "total_units" in d:
|
|
173
|
-
kwargs["total_units"] = d["total_units"]
|
|
174
|
-
if "offset" in d:
|
|
175
|
-
kwargs["offset"] = d["offset"]
|
|
176
|
-
if "tile_size" in d:
|
|
177
|
-
kwargs["tile_size"] = d["tile_size"]
|
|
178
|
-
return XyzTiles(**kwargs)
|
|
179
|
-
|
|
180
162
|
def get_items(
|
|
181
163
|
self, geometries: list[STGeometry], query_config: QueryConfig
|
|
182
164
|
) -> list[list[list[Item]]]:
|
|
@@ -381,7 +363,6 @@ class XyzTiles(DataSource, TileStore):
|
|
|
381
363
|
layer_name: the name of this layer
|
|
382
364
|
layer_cfg: the config of this layer
|
|
383
365
|
"""
|
|
384
|
-
assert isinstance(layer_cfg, RasterLayerConfig)
|
|
385
366
|
RasterMaterializer().materialize(
|
|
386
367
|
TileStoreWithLayer(self, layer_name),
|
|
387
368
|
window,
|
rslearn/dataset/dataset.py
CHANGED
|
@@ -6,7 +6,7 @@ import multiprocessing
|
|
|
6
6
|
import tqdm
|
|
7
7
|
from upath import UPath
|
|
8
8
|
|
|
9
|
-
from rslearn.config import
|
|
9
|
+
from rslearn.config import DatasetConfig
|
|
10
10
|
from rslearn.log_utils import get_logger
|
|
11
11
|
from rslearn.template_params import substitute_env_vars_in_string
|
|
12
12
|
from rslearn.tile_stores import TileStore, load_tile_store
|
|
@@ -55,19 +55,19 @@ class Dataset:
|
|
|
55
55
|
with (self.path / "config.json").open("r") as f:
|
|
56
56
|
config_content = f.read()
|
|
57
57
|
config_content = substitute_env_vars_in_string(config_content)
|
|
58
|
-
config = json.loads(config_content)
|
|
58
|
+
config = DatasetConfig.model_validate(json.loads(config_content))
|
|
59
|
+
|
|
59
60
|
self.layers = {}
|
|
60
|
-
for layer_name,
|
|
61
|
+
for layer_name, layer_config in config.layers.items():
|
|
61
62
|
# Layer names must not contain period, since we use period to
|
|
62
63
|
# distinguish different materialized groups within a layer.
|
|
63
64
|
assert "." not in layer_name, "layer names must not contain periods"
|
|
64
65
|
if layer_name in disabled_layers:
|
|
65
66
|
logger.warning(f"Layer {layer_name} is disabled")
|
|
66
67
|
continue
|
|
67
|
-
self.layers[layer_name] =
|
|
68
|
+
self.layers[layer_name] = layer_config
|
|
68
69
|
|
|
69
|
-
self.tile_store_config = config.
|
|
70
|
-
self.materializer_name = config.get("materialize")
|
|
70
|
+
self.tile_store_config = config.tile_store
|
|
71
71
|
|
|
72
72
|
def _get_index(self) -> DatasetIndex | None:
|
|
73
73
|
index_fname = self.path / DatasetIndex.FNAME
|
rslearn/dataset/manage.py
CHANGED
|
@@ -6,11 +6,9 @@ from collections.abc import Callable
|
|
|
6
6
|
from datetime import timedelta
|
|
7
7
|
from typing import Any
|
|
8
8
|
|
|
9
|
-
import rslearn.data_sources
|
|
10
9
|
from rslearn.config import (
|
|
11
10
|
LayerConfig,
|
|
12
11
|
LayerType,
|
|
13
|
-
RasterLayerConfig,
|
|
14
12
|
)
|
|
15
13
|
from rslearn.data_sources import DataSource, Item
|
|
16
14
|
from rslearn.dataset.handler_summaries import (
|
|
@@ -24,7 +22,7 @@ from rslearn.log_utils import get_logger
|
|
|
24
22
|
from rslearn.tile_stores import TileStore, get_tile_store_with_layer
|
|
25
23
|
|
|
26
24
|
from .dataset import Dataset
|
|
27
|
-
from .materialize import
|
|
25
|
+
from .materialize import Materializer, RasterMaterializer, VectorMaterializer
|
|
28
26
|
from .window import Window, WindowLayerData
|
|
29
27
|
|
|
30
28
|
logger = get_logger(__name__)
|
|
@@ -124,12 +122,24 @@ def prepare_dataset_windows(
|
|
|
124
122
|
)
|
|
125
123
|
continue
|
|
126
124
|
data_source_cfg = layer_cfg.data_source
|
|
125
|
+
min_matches = data_source_cfg.query_config.min_matches
|
|
127
126
|
|
|
128
127
|
# Get windows that need to be prepared for this layer.
|
|
128
|
+
# Also track which windows are skipped vs previously rejected.
|
|
129
129
|
needed_windows = []
|
|
130
|
+
windows_skipped = 0
|
|
131
|
+
windows_rejected = 0
|
|
130
132
|
for window in windows:
|
|
131
133
|
layer_datas = window.load_layer_datas()
|
|
132
134
|
if layer_name in layer_datas and not force:
|
|
135
|
+
# Window already has layer data - check if it was previously rejected
|
|
136
|
+
layer_data = layer_datas[layer_name]
|
|
137
|
+
if len(layer_data.serialized_item_groups) == 0 and min_matches > 0:
|
|
138
|
+
# Previously rejected due to min_matches
|
|
139
|
+
windows_rejected += 1
|
|
140
|
+
else:
|
|
141
|
+
# Successfully prepared previously
|
|
142
|
+
windows_skipped += 1
|
|
133
143
|
continue
|
|
134
144
|
needed_windows.append(window)
|
|
135
145
|
logger.info(f"Preparing {len(needed_windows)} windows for layer {layer_name}")
|
|
@@ -138,11 +148,11 @@ def prepare_dataset_windows(
|
|
|
138
148
|
layer_summaries.append(
|
|
139
149
|
LayerPrepareSummary(
|
|
140
150
|
layer_name=layer_name,
|
|
141
|
-
data_source_name=data_source_cfg.
|
|
151
|
+
data_source_name=data_source_cfg.class_path,
|
|
142
152
|
duration_seconds=time.monotonic() - layer_start_time,
|
|
143
153
|
windows_prepared=0,
|
|
144
|
-
windows_skipped=
|
|
145
|
-
windows_rejected=
|
|
154
|
+
windows_skipped=windows_skipped,
|
|
155
|
+
windows_rejected=windows_rejected,
|
|
146
156
|
get_items_attempts=0,
|
|
147
157
|
)
|
|
148
158
|
)
|
|
@@ -150,9 +160,7 @@ def prepare_dataset_windows(
|
|
|
150
160
|
|
|
151
161
|
# Create data source after checking for at least one window so it can be fast
|
|
152
162
|
# if there are no windows to prepare.
|
|
153
|
-
data_source =
|
|
154
|
-
layer_cfg, dataset.path
|
|
155
|
-
)
|
|
163
|
+
data_source = layer_cfg.instantiate_data_source(dataset.path)
|
|
156
164
|
|
|
157
165
|
# Get STGeometry for each window.
|
|
158
166
|
geometries = []
|
|
@@ -184,8 +192,6 @@ def prepare_dataset_windows(
|
|
|
184
192
|
)
|
|
185
193
|
|
|
186
194
|
windows_prepared = 0
|
|
187
|
-
windows_rejected = 0
|
|
188
|
-
min_matches = data_source_cfg.query_config.min_matches
|
|
189
195
|
for window, result in zip(needed_windows, results):
|
|
190
196
|
layer_datas = window.load_layer_datas()
|
|
191
197
|
layer_datas[layer_name] = WindowLayerData(
|
|
@@ -202,12 +208,10 @@ def prepare_dataset_windows(
|
|
|
202
208
|
else:
|
|
203
209
|
windows_prepared += 1
|
|
204
210
|
|
|
205
|
-
windows_skipped = len(windows) - len(needed_windows)
|
|
206
|
-
|
|
207
211
|
layer_summaries.append(
|
|
208
212
|
LayerPrepareSummary(
|
|
209
213
|
layer_name=layer_name,
|
|
210
|
-
data_source_name=data_source_cfg.
|
|
214
|
+
data_source_name=data_source_cfg.class_path,
|
|
211
215
|
duration_seconds=time.monotonic() - layer_start_time,
|
|
212
216
|
windows_prepared=windows_prepared,
|
|
213
217
|
windows_skipped=windows_skipped,
|
|
@@ -250,9 +254,7 @@ def ingest_dataset_windows(
|
|
|
250
254
|
if not layer_cfg.data_source.ingest:
|
|
251
255
|
continue
|
|
252
256
|
|
|
253
|
-
data_source =
|
|
254
|
-
layer_cfg, dataset.path
|
|
255
|
-
)
|
|
257
|
+
data_source = layer_cfg.instantiate_data_source(dataset.path)
|
|
256
258
|
|
|
257
259
|
geometries_by_item: dict = {}
|
|
258
260
|
for window in windows:
|
|
@@ -314,8 +316,7 @@ def is_window_ingested(
|
|
|
314
316
|
for serialized_item in group:
|
|
315
317
|
item = Item.deserialize(serialized_item)
|
|
316
318
|
|
|
317
|
-
if layer_cfg.
|
|
318
|
-
assert isinstance(layer_cfg, RasterLayerConfig)
|
|
319
|
+
if layer_cfg.type == LayerType.RASTER:
|
|
319
320
|
for band_set in layer_cfg.band_sets:
|
|
320
321
|
# Make sure that layers exist containing each configured band.
|
|
321
322
|
# And that those layers are marked completed.
|
|
@@ -409,10 +410,13 @@ def materialize_window(
|
|
|
409
410
|
f"Materializing {len(item_groups)} item groups in layer {layer_name} from tile store"
|
|
410
411
|
)
|
|
411
412
|
|
|
412
|
-
|
|
413
|
-
|
|
413
|
+
materializer: Materializer
|
|
414
|
+
if layer_cfg.type == LayerType.RASTER:
|
|
415
|
+
materializer = RasterMaterializer()
|
|
416
|
+
elif layer_cfg.type == LayerType.VECTOR:
|
|
417
|
+
materializer = VectorMaterializer()
|
|
414
418
|
else:
|
|
415
|
-
|
|
419
|
+
raise ValueError(f"unknown layer type {layer_cfg.type}")
|
|
416
420
|
|
|
417
421
|
retry(
|
|
418
422
|
fn=lambda: materializer.materialize(
|
|
@@ -483,10 +487,8 @@ def materialize_dataset_windows(
|
|
|
483
487
|
if not layer_cfg.data_source:
|
|
484
488
|
total_skipped = len(windows)
|
|
485
489
|
else:
|
|
486
|
-
data_source_name = layer_cfg.data_source.
|
|
487
|
-
data_source =
|
|
488
|
-
layer_cfg, dataset.path
|
|
489
|
-
)
|
|
490
|
+
data_source_name = layer_cfg.data_source.class_path
|
|
491
|
+
data_source = layer_cfg.instantiate_data_source(dataset.path)
|
|
490
492
|
|
|
491
493
|
for window in windows:
|
|
492
494
|
window_summary = materialize_window(
|
rslearn/dataset/materialize.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
"""Classes to implement dataset materialization."""
|
|
2
2
|
|
|
3
|
-
from
|
|
4
|
-
from typing import Any, Generic, TypeVar
|
|
3
|
+
from typing import Any
|
|
5
4
|
|
|
6
5
|
import numpy as np
|
|
7
6
|
import numpy.typing as npt
|
|
@@ -11,45 +10,17 @@ from rslearn.config import (
|
|
|
11
10
|
BandSetConfig,
|
|
12
11
|
CompositingMethod,
|
|
13
12
|
LayerConfig,
|
|
14
|
-
RasterFormatConfig,
|
|
15
|
-
RasterLayerConfig,
|
|
16
|
-
VectorLayerConfig,
|
|
17
13
|
)
|
|
18
14
|
from rslearn.data_sources.data_source import ItemType
|
|
19
15
|
from rslearn.tile_stores import TileStoreWithLayer
|
|
20
16
|
from rslearn.utils.feature import Feature
|
|
21
17
|
from rslearn.utils.geometry import PixelBounds, Projection
|
|
22
|
-
from rslearn.utils.raster_format import load_raster_format
|
|
23
|
-
from rslearn.utils.vector_format import load_vector_format
|
|
24
18
|
|
|
25
19
|
from .remap import Remapper, load_remapper
|
|
26
20
|
from .window import Window
|
|
27
21
|
|
|
28
|
-
_MaterializerT = TypeVar("_MaterializerT", bound="Materializer")
|
|
29
22
|
|
|
30
|
-
|
|
31
|
-
class _MaterializerRegistry(dict[str, type["Materializer"]]):
|
|
32
|
-
"""Registry for Materializer classes."""
|
|
33
|
-
|
|
34
|
-
def register(
|
|
35
|
-
self, name: str
|
|
36
|
-
) -> Callable[[type[_MaterializerT]], type[_MaterializerT]]:
|
|
37
|
-
"""Decorator to register a materializer class."""
|
|
38
|
-
|
|
39
|
-
def decorator(cls: type[_MaterializerT]) -> type[_MaterializerT]:
|
|
40
|
-
self[name] = cls
|
|
41
|
-
return cls
|
|
42
|
-
|
|
43
|
-
return decorator
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
Materializers = _MaterializerRegistry()
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
LayerConfigType = TypeVar("LayerConfigType", bound=LayerConfig)
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
class Materializer(Generic[LayerConfigType]):
|
|
23
|
+
class Materializer:
|
|
53
24
|
"""An abstract class that materializes data from a tile store."""
|
|
54
25
|
|
|
55
26
|
def materialize(
|
|
@@ -57,7 +28,7 @@ class Materializer(Generic[LayerConfigType]):
|
|
|
57
28
|
tile_store: TileStoreWithLayer,
|
|
58
29
|
window: Window,
|
|
59
30
|
layer_name: str,
|
|
60
|
-
layer_cfg:
|
|
31
|
+
layer_cfg: LayerConfig,
|
|
61
32
|
item_groups: list[list[ItemType]],
|
|
62
33
|
) -> None:
|
|
63
34
|
"""Materialize portions of items corresponding to this window into the dataset.
|
|
@@ -473,7 +444,7 @@ def build_composite(
|
|
|
473
444
|
group: list[ItemType],
|
|
474
445
|
compositing_method: CompositingMethod,
|
|
475
446
|
tile_store: TileStoreWithLayer,
|
|
476
|
-
layer_cfg:
|
|
447
|
+
layer_cfg: LayerConfig,
|
|
477
448
|
band_cfg: BandSetConfig,
|
|
478
449
|
projection: Projection,
|
|
479
450
|
bounds: PixelBounds,
|
|
@@ -503,13 +474,12 @@ def build_composite(
|
|
|
503
474
|
band_dtype=band_cfg.dtype.value,
|
|
504
475
|
tile_store=tile_store,
|
|
505
476
|
projection=projection,
|
|
506
|
-
resampling_method=layer_cfg.resampling_method,
|
|
477
|
+
resampling_method=layer_cfg.resampling_method.get_rasterio_resampling(),
|
|
507
478
|
remapper=remapper,
|
|
508
479
|
)
|
|
509
480
|
|
|
510
481
|
|
|
511
|
-
|
|
512
|
-
class RasterMaterializer(Materializer[RasterLayerConfig]):
|
|
482
|
+
class RasterMaterializer(Materializer):
|
|
513
483
|
"""A Materializer for raster data."""
|
|
514
484
|
|
|
515
485
|
def materialize(
|
|
@@ -517,7 +487,7 @@ class RasterMaterializer(Materializer[RasterLayerConfig]):
|
|
|
517
487
|
tile_store: TileStoreWithLayer,
|
|
518
488
|
window: Window,
|
|
519
489
|
layer_name: str,
|
|
520
|
-
layer_cfg:
|
|
490
|
+
layer_cfg: LayerConfig,
|
|
521
491
|
item_groups: list[list[ItemType]],
|
|
522
492
|
) -> None:
|
|
523
493
|
"""Materialize portions of items corresponding to this window into the dataset.
|
|
@@ -529,8 +499,6 @@ class RasterMaterializer(Materializer[RasterLayerConfig]):
|
|
|
529
499
|
layer_cfg: the configuration of the layer to materialize
|
|
530
500
|
item_groups: the items associated with this window and layer
|
|
531
501
|
"""
|
|
532
|
-
assert isinstance(layer_cfg, RasterLayerConfig)
|
|
533
|
-
|
|
534
502
|
for band_cfg in layer_cfg.band_sets:
|
|
535
503
|
# band_cfg could specify zoom_offset and maybe other parameters that affect
|
|
536
504
|
# projection/bounds, so use the corrected projection/bounds.
|
|
@@ -543,9 +511,7 @@ class RasterMaterializer(Materializer[RasterLayerConfig]):
|
|
|
543
511
|
if band_cfg.remap:
|
|
544
512
|
remapper = load_remapper(band_cfg.remap)
|
|
545
513
|
|
|
546
|
-
raster_format =
|
|
547
|
-
RasterFormatConfig(band_cfg.format["name"], band_cfg.format)
|
|
548
|
-
)
|
|
514
|
+
raster_format = band_cfg.instantiate_raster_format()
|
|
549
515
|
|
|
550
516
|
for group_id, group in enumerate(item_groups):
|
|
551
517
|
composite = build_composite(
|
|
@@ -569,7 +535,6 @@ class RasterMaterializer(Materializer[RasterLayerConfig]):
|
|
|
569
535
|
window.mark_layer_completed(layer_name, group_id)
|
|
570
536
|
|
|
571
537
|
|
|
572
|
-
@Materializers.register("vector")
|
|
573
538
|
class VectorMaterializer(Materializer):
|
|
574
539
|
"""A Materializer for vector data."""
|
|
575
540
|
|
|
@@ -590,8 +555,7 @@ class VectorMaterializer(Materializer):
|
|
|
590
555
|
layer_cfg: the configuration of the layer to materialize
|
|
591
556
|
item_groups: the items associated with this window and layer
|
|
592
557
|
"""
|
|
593
|
-
|
|
594
|
-
vector_format = load_vector_format(layer_cfg.format)
|
|
558
|
+
vector_format = layer_cfg.instantiate_vector_format()
|
|
595
559
|
|
|
596
560
|
for group_id, group in enumerate(item_groups):
|
|
597
561
|
features: list[Feature] = []
|