rslearn 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -10
- rslearn/config/dataset.py +420 -420
- rslearn/data_sources/__init__.py +8 -31
- rslearn/data_sources/aws_landsat.py +13 -24
- rslearn/data_sources/aws_open_data.py +21 -46
- rslearn/data_sources/aws_sentinel1.py +3 -14
- rslearn/data_sources/climate_data_store.py +21 -40
- rslearn/data_sources/copernicus.py +30 -91
- rslearn/data_sources/data_source.py +26 -0
- rslearn/data_sources/earthdaily.py +13 -38
- rslearn/data_sources/earthdata_srtm.py +14 -32
- rslearn/data_sources/eurocrops.py +5 -9
- rslearn/data_sources/gcp_public_data.py +46 -43
- rslearn/data_sources/google_earth_engine.py +31 -44
- rslearn/data_sources/local_files.py +91 -100
- rslearn/data_sources/openstreetmap.py +21 -51
- rslearn/data_sources/planet.py +12 -30
- rslearn/data_sources/planet_basemap.py +4 -25
- rslearn/data_sources/planetary_computer.py +58 -141
- rslearn/data_sources/usda_cdl.py +15 -26
- rslearn/data_sources/usgs_landsat.py +4 -29
- rslearn/data_sources/utils.py +9 -0
- rslearn/data_sources/worldcereal.py +47 -54
- rslearn/data_sources/worldcover.py +16 -14
- rslearn/data_sources/worldpop.py +15 -18
- rslearn/data_sources/xyz_tiles.py +11 -30
- rslearn/dataset/dataset.py +6 -6
- rslearn/dataset/manage.py +14 -20
- rslearn/dataset/materialize.py +9 -45
- rslearn/lightning_cli.py +377 -1
- rslearn/main.py +3 -3
- rslearn/models/concatenate_features.py +93 -0
- rslearn/models/olmoearth_pretrain/model.py +2 -5
- rslearn/tile_stores/__init__.py +0 -11
- rslearn/train/dataset.py +4 -12
- rslearn/train/prediction_writer.py +16 -32
- rslearn/train/tasks/classification.py +2 -1
- rslearn/utils/fsspec.py +20 -0
- rslearn/utils/jsonargparse.py +79 -0
- rslearn/utils/raster_format.py +1 -41
- rslearn/utils/vector_format.py +1 -38
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/METADATA +58 -25
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/RECORD +48 -49
- rslearn/data_sources/geotiff.py +0 -1
- rslearn/data_sources/raster_source.py +0 -23
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/WHEEL +0 -0
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/top_level.txt +0 -0
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
|
|
3
3
|
import functools
|
|
4
4
|
import json
|
|
5
|
-
from collections.abc import Callable
|
|
6
5
|
from typing import Any, Generic, TypeVar
|
|
7
6
|
|
|
8
7
|
import fiona
|
|
@@ -12,58 +11,43 @@ from rasterio.crs import CRS
|
|
|
12
11
|
from upath import UPath
|
|
13
12
|
|
|
14
13
|
import rslearn.data_sources.utils
|
|
15
|
-
from rslearn.config import
|
|
14
|
+
from rslearn.config import LayerType
|
|
16
15
|
from rslearn.const import SHAPEFILE_AUX_EXTENSIONS
|
|
17
16
|
from rslearn.log_utils import get_logger
|
|
18
17
|
from rslearn.tile_stores import TileStoreWithLayer
|
|
19
18
|
from rslearn.utils.feature import Feature
|
|
20
|
-
from rslearn.utils.fsspec import
|
|
19
|
+
from rslearn.utils.fsspec import (
|
|
20
|
+
get_relative_suffix,
|
|
21
|
+
get_upath_local,
|
|
22
|
+
join_upath,
|
|
23
|
+
open_rasterio_upath_reader,
|
|
24
|
+
)
|
|
21
25
|
from rslearn.utils.geometry import Projection, STGeometry, get_global_geometry
|
|
22
26
|
|
|
23
|
-
from .data_source import DataSource, Item, QueryConfig
|
|
27
|
+
from .data_source import DataSource, DataSourceContext, Item, QueryConfig
|
|
24
28
|
|
|
25
29
|
logger = get_logger("__name__")
|
|
26
|
-
_ImporterT = TypeVar("_ImporterT", bound="Importer")
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
class _ImporterRegistry(dict[str, type["Importer"]]):
|
|
30
|
-
"""Registry for Importer classes."""
|
|
31
|
-
|
|
32
|
-
def register(self, name: str) -> Callable[[type[_ImporterT]], type[_ImporterT]]:
|
|
33
|
-
"""Decorator to register an importer class."""
|
|
34
|
-
|
|
35
|
-
def decorator(cls: type[_ImporterT]) -> type[_ImporterT]:
|
|
36
|
-
self[name] = cls
|
|
37
|
-
return cls
|
|
38
|
-
|
|
39
|
-
return decorator
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
Importers = _ImporterRegistry()
|
|
43
30
|
|
|
44
31
|
|
|
45
32
|
ItemType = TypeVar("ItemType", bound=Item)
|
|
46
|
-
LayerConfigType = TypeVar("LayerConfigType", bound=LayerConfig)
|
|
47
33
|
ImporterType = TypeVar("ImporterType", bound="Importer")
|
|
48
34
|
|
|
49
35
|
SOURCE_NAME = "rslearn.data_sources.local_files.LocalFiles"
|
|
50
36
|
|
|
51
37
|
|
|
52
|
-
class Importer(Generic[ItemType, LayerConfigType]):
|
|
38
|
+
class Importer(Generic[ItemType]):
|
|
53
39
|
"""An abstract base class for importing data from local files."""
|
|
54
40
|
|
|
55
|
-
def list_items(self,
|
|
41
|
+
def list_items(self, src_dir: UPath) -> list[ItemType]:
|
|
56
42
|
"""Extract a list of Items from the source directory.
|
|
57
43
|
|
|
58
44
|
Args:
|
|
59
|
-
config: the configuration of the layer.
|
|
60
45
|
src_dir: the source directory.
|
|
61
46
|
"""
|
|
62
47
|
raise NotImplementedError
|
|
63
48
|
|
|
64
49
|
def ingest_item(
|
|
65
50
|
self,
|
|
66
|
-
config: LayerConfigType,
|
|
67
51
|
tile_store: TileStoreWithLayer,
|
|
68
52
|
item: ItemType,
|
|
69
53
|
cur_geometries: list[STGeometry],
|
|
@@ -71,7 +55,6 @@ class Importer(Generic[ItemType, LayerConfigType]):
|
|
|
71
55
|
"""Ingest the specified local file item.
|
|
72
56
|
|
|
73
57
|
Args:
|
|
74
|
-
config: the configuration of the layer.
|
|
75
58
|
tile_store: the tile store to ingest the data into.
|
|
76
59
|
item: the Item to ingest
|
|
77
60
|
cur_geometries: the geometries where the item is needed.
|
|
@@ -84,7 +67,7 @@ class RasterItemSpec:
|
|
|
84
67
|
|
|
85
68
|
def __init__(
|
|
86
69
|
self,
|
|
87
|
-
fnames: list[UPath],
|
|
70
|
+
fnames: list[str],
|
|
88
71
|
bands: list[list[str]] | None = None,
|
|
89
72
|
name: str | None = None,
|
|
90
73
|
):
|
|
@@ -99,25 +82,6 @@ class RasterItemSpec:
|
|
|
99
82
|
self.bands = bands
|
|
100
83
|
self.name = name
|
|
101
84
|
|
|
102
|
-
@staticmethod
|
|
103
|
-
def from_config(src_dir: UPath, d: dict[str, Any]) -> "RasterItemSpec":
|
|
104
|
-
"""Decode a dict into a RasterItemSpec.
|
|
105
|
-
|
|
106
|
-
Args:
|
|
107
|
-
src_dir: the source directory.
|
|
108
|
-
d: the configuration dict.
|
|
109
|
-
|
|
110
|
-
Returns:
|
|
111
|
-
the RasterItemSpec.
|
|
112
|
-
"""
|
|
113
|
-
kwargs = dict(
|
|
114
|
-
fnames=[join_upath(src_dir, suffix) for suffix in d["fnames"]],
|
|
115
|
-
bands=d["bands"],
|
|
116
|
-
)
|
|
117
|
-
if "name" in d:
|
|
118
|
-
kwargs["name"] = d["name"]
|
|
119
|
-
return RasterItemSpec(**kwargs)
|
|
120
|
-
|
|
121
85
|
def serialize(self) -> dict[str, Any]:
|
|
122
86
|
"""Serializes the RasterItemSpec to a JSON-encodable dictionary."""
|
|
123
87
|
return {
|
|
@@ -130,7 +94,7 @@ class RasterItemSpec:
|
|
|
130
94
|
def deserialize(d: dict[str, Any]) -> "RasterItemSpec":
|
|
131
95
|
"""Deserializes a RasterItemSpec from a JSON-decoded dictionary."""
|
|
132
96
|
return RasterItemSpec(
|
|
133
|
-
fnames=[
|
|
97
|
+
fnames=[s for s in d["fnames"]],
|
|
134
98
|
bands=d["bands"],
|
|
135
99
|
name=d["name"],
|
|
136
100
|
)
|
|
@@ -139,20 +103,25 @@ class RasterItemSpec:
|
|
|
139
103
|
class RasterItem(Item):
|
|
140
104
|
"""An item corresponding to a local file."""
|
|
141
105
|
|
|
142
|
-
def __init__(
|
|
143
|
-
|
|
106
|
+
def __init__(
|
|
107
|
+
self, name: str, geometry: STGeometry, src_dir: str, spec: RasterItemSpec
|
|
108
|
+
):
|
|
109
|
+
"""Creates a new RasterItem.
|
|
144
110
|
|
|
145
111
|
Args:
|
|
146
112
|
name: unique name of the item
|
|
147
113
|
geometry: the spatial and temporal extent of the item
|
|
114
|
+
src_dir: the source directory.
|
|
148
115
|
spec: the RasterItemSpec that specifies the filename(s) and bands.
|
|
149
116
|
"""
|
|
150
117
|
super().__init__(name, geometry)
|
|
118
|
+
self.src_dir = src_dir
|
|
151
119
|
self.spec = spec
|
|
152
120
|
|
|
153
121
|
def serialize(self) -> dict:
|
|
154
122
|
"""Serializes the item to a JSON-encodable dictionary."""
|
|
155
123
|
d = super().serialize()
|
|
124
|
+
d["src_dir"] = str(self.src_dir)
|
|
156
125
|
d["spec"] = self.spec.serialize()
|
|
157
126
|
return d
|
|
158
127
|
|
|
@@ -160,8 +129,11 @@ class RasterItem(Item):
|
|
|
160
129
|
def deserialize(d: dict) -> "RasterItem":
|
|
161
130
|
"""Deserializes an item from a JSON-decoded dictionary."""
|
|
162
131
|
item = super(RasterItem, RasterItem).deserialize(d)
|
|
132
|
+
src_dir = d["src_dir"]
|
|
163
133
|
spec = RasterItemSpec.deserialize(d["spec"])
|
|
164
|
-
return RasterItem(
|
|
134
|
+
return RasterItem(
|
|
135
|
+
name=item.name, geometry=item.geometry, src_dir=src_dir, spec=spec
|
|
136
|
+
)
|
|
165
137
|
|
|
166
138
|
|
|
167
139
|
class VectorItem(Item):
|
|
@@ -193,29 +165,34 @@ class VectorItem(Item):
|
|
|
193
165
|
)
|
|
194
166
|
|
|
195
167
|
|
|
196
|
-
@Importers.register("raster")
|
|
197
168
|
class RasterImporter(Importer):
|
|
198
169
|
"""An Importer for raster data."""
|
|
199
170
|
|
|
200
|
-
def
|
|
171
|
+
def __init__(self, item_specs: list[RasterItemSpec] | None = None):
|
|
172
|
+
"""Create a new RasterImporter.
|
|
173
|
+
|
|
174
|
+
Args:
|
|
175
|
+
item_specs: the specs to specify the raster items directly. If None, the
|
|
176
|
+
raster items are automatically detected from the files in the source
|
|
177
|
+
directory.
|
|
178
|
+
"""
|
|
179
|
+
self.item_specs = item_specs
|
|
180
|
+
|
|
181
|
+
def list_items(self, src_dir: UPath) -> list[Item]:
|
|
201
182
|
"""Extract a list of Items from the source directory.
|
|
202
183
|
|
|
203
184
|
Args:
|
|
204
|
-
config: the configuration of the layer.
|
|
205
185
|
src_dir: the source directory.
|
|
206
186
|
"""
|
|
207
|
-
item_specs: list[RasterItemSpec]
|
|
187
|
+
item_specs: list[RasterItemSpec]
|
|
188
|
+
|
|
208
189
|
# See if user has provided the item specs directly.
|
|
209
|
-
if
|
|
210
|
-
|
|
211
|
-
and "item_specs" in config.data_source.config_dict
|
|
212
|
-
):
|
|
213
|
-
for spec_dict in config.data_source.config_dict["item_specs"]:
|
|
214
|
-
spec = RasterItemSpec.from_config(src_dir, spec_dict)
|
|
215
|
-
item_specs.append(spec)
|
|
190
|
+
if self.item_specs is not None:
|
|
191
|
+
item_specs = self.item_specs
|
|
216
192
|
else:
|
|
217
193
|
# Otherwise we need to list files and assume each one is separate.
|
|
218
194
|
# And we'll need to autodetect the bands later.
|
|
195
|
+
item_specs = []
|
|
219
196
|
file_paths = src_dir.glob("**/*.*")
|
|
220
197
|
for path in file_paths:
|
|
221
198
|
# Ignore JSON files.
|
|
@@ -228,14 +205,17 @@ class RasterImporter(Importer):
|
|
|
228
205
|
if len(parts) >= 4 and parts[-2] == "tmp" and parts[-1].isdigit():
|
|
229
206
|
continue
|
|
230
207
|
|
|
231
|
-
spec = RasterItemSpec(
|
|
208
|
+
spec = RasterItemSpec(
|
|
209
|
+
fnames=[get_relative_suffix(src_dir, path)], bands=None
|
|
210
|
+
)
|
|
232
211
|
item_specs.append(spec)
|
|
233
212
|
|
|
234
|
-
items = []
|
|
213
|
+
items: list[Item] = []
|
|
235
214
|
for spec in item_specs:
|
|
236
215
|
# Get geometry from the first raster file.
|
|
237
216
|
# We assume files are readable with rasterio.
|
|
238
|
-
|
|
217
|
+
fname = join_upath(src_dir, spec.fnames[0])
|
|
218
|
+
with open_rasterio_upath_reader(fname) as src:
|
|
239
219
|
crs = src.crs
|
|
240
220
|
left = src.transform.c
|
|
241
221
|
top = src.transform.f
|
|
@@ -263,33 +243,33 @@ class RasterImporter(Importer):
|
|
|
263
243
|
if spec.name:
|
|
264
244
|
item_name = spec.name
|
|
265
245
|
else:
|
|
266
|
-
item_name =
|
|
246
|
+
item_name = fname.name.split(".")[0]
|
|
267
247
|
|
|
268
248
|
logger.debug(
|
|
269
249
|
"RasterImporter.list_items: got bounds of %s: %s", item_name, geometry
|
|
270
250
|
)
|
|
271
|
-
items.append(RasterItem(item_name, geometry, spec))
|
|
251
|
+
items.append(RasterItem(item_name, geometry, str(src_dir), spec))
|
|
272
252
|
|
|
273
253
|
logger.debug("RasterImporter.list_items: discovered %d items", len(items))
|
|
274
254
|
return items
|
|
275
255
|
|
|
276
256
|
def ingest_item(
|
|
277
257
|
self,
|
|
278
|
-
config: RasterLayerConfig,
|
|
279
258
|
tile_store: TileStoreWithLayer,
|
|
280
|
-
item:
|
|
259
|
+
item: Item,
|
|
281
260
|
cur_geometries: list[STGeometry],
|
|
282
261
|
) -> None:
|
|
283
262
|
"""Ingest the specified local file item.
|
|
284
263
|
|
|
285
264
|
Args:
|
|
286
|
-
config: the configuration of the layer.
|
|
287
265
|
tile_store: the tile store to ingest the data into.
|
|
288
266
|
item: the RasterItem to ingest
|
|
289
267
|
cur_geometries: the geometries where the item is needed.
|
|
290
268
|
"""
|
|
269
|
+
assert isinstance(item, RasterItem)
|
|
291
270
|
for file_idx, fname in enumerate(item.spec.fnames):
|
|
292
|
-
|
|
271
|
+
fname_upath = join_upath(UPath(item.src_dir), fname)
|
|
272
|
+
with open_rasterio_upath_reader(fname_upath) as src:
|
|
293
273
|
if item.spec.bands:
|
|
294
274
|
bands = item.spec.bands[file_idx]
|
|
295
275
|
else:
|
|
@@ -297,25 +277,23 @@ class RasterImporter(Importer):
|
|
|
297
277
|
|
|
298
278
|
if tile_store.is_raster_ready(item.name, bands):
|
|
299
279
|
continue
|
|
300
|
-
tile_store.write_raster_file(item.name, bands,
|
|
280
|
+
tile_store.write_raster_file(item.name, bands, fname_upath)
|
|
301
281
|
|
|
302
282
|
|
|
303
|
-
@Importers.register("vector")
|
|
304
283
|
class VectorImporter(Importer):
|
|
305
284
|
"""An Importer for vector data."""
|
|
306
285
|
|
|
307
286
|
# We need some buffer around GeoJSON bounds in case it just contains one point.
|
|
308
287
|
item_buffer_epsilon = 1e-4
|
|
309
288
|
|
|
310
|
-
def list_items(self,
|
|
289
|
+
def list_items(self, src_dir: UPath) -> list[Item]:
|
|
311
290
|
"""Extract a list of Items from the source directory.
|
|
312
291
|
|
|
313
292
|
Args:
|
|
314
|
-
config: the configuration of the layer.
|
|
315
293
|
src_dir: the source directory.
|
|
316
294
|
"""
|
|
317
295
|
file_paths = src_dir.glob("**/*.*")
|
|
318
|
-
items: list[
|
|
296
|
+
items: list[Item] = []
|
|
319
297
|
|
|
320
298
|
for path in file_paths:
|
|
321
299
|
# Ignore JSON files.
|
|
@@ -375,25 +353,21 @@ class VectorImporter(Importer):
|
|
|
375
353
|
|
|
376
354
|
def ingest_item(
|
|
377
355
|
self,
|
|
378
|
-
config: VectorLayerConfig,
|
|
379
356
|
tile_store: TileStoreWithLayer,
|
|
380
|
-
item:
|
|
357
|
+
item: Item,
|
|
381
358
|
cur_geometries: list[STGeometry],
|
|
382
359
|
) -> None:
|
|
383
360
|
"""Ingest the specified local file item.
|
|
384
361
|
|
|
385
362
|
Args:
|
|
386
|
-
config: the configuration of the layer.
|
|
387
363
|
tile_store: the TileStore to ingest the data into.
|
|
388
364
|
item: the Item to ingest
|
|
389
365
|
cur_geometries: the geometries where the item is needed.
|
|
390
366
|
"""
|
|
391
|
-
if not isinstance(config, VectorLayerConfig):
|
|
392
|
-
raise ValueError("VectorImporter requires a VectorLayerConfig")
|
|
393
|
-
|
|
394
367
|
if tile_store.is_vector_ready(item.name):
|
|
395
368
|
return
|
|
396
369
|
|
|
370
|
+
assert isinstance(item, VectorItem)
|
|
397
371
|
path = UPath(item.path_uri)
|
|
398
372
|
|
|
399
373
|
aux_files: list[UPath] = []
|
|
@@ -431,27 +405,44 @@ class LocalFiles(DataSource):
|
|
|
431
405
|
|
|
432
406
|
def __init__(
|
|
433
407
|
self,
|
|
434
|
-
|
|
435
|
-
|
|
408
|
+
src_dir: str,
|
|
409
|
+
raster_item_specs: list[RasterItemSpec] | None = None,
|
|
410
|
+
layer_type: LayerType | None = None,
|
|
411
|
+
context: DataSourceContext = DataSourceContext(),
|
|
436
412
|
) -> None:
|
|
437
413
|
"""Initialize a new LocalFiles instance.
|
|
438
414
|
|
|
439
415
|
Args:
|
|
440
|
-
config: configuration for this layer.
|
|
441
416
|
src_dir: source directory to ingest
|
|
417
|
+
raster_item_specs: the specs to specify the raster items directly. If None,
|
|
418
|
+
the raster items are automatically detected from the files in the
|
|
419
|
+
source directory.
|
|
420
|
+
layer_type: the layer type. It only needs to be set if the layer_config is
|
|
421
|
+
missing from the context.
|
|
422
|
+
context: the data source context. The layer config must be in the context.
|
|
442
423
|
"""
|
|
443
|
-
|
|
424
|
+
if context.ds_path is not None:
|
|
425
|
+
self.src_dir = join_upath(context.ds_path, src_dir)
|
|
426
|
+
else:
|
|
427
|
+
self.src_dir = UPath(src_dir)
|
|
444
428
|
|
|
445
|
-
|
|
446
|
-
|
|
429
|
+
# Determine layer type.
|
|
430
|
+
if context.layer_config is not None:
|
|
431
|
+
self.layer_type = context.layer_config.type
|
|
432
|
+
elif layer_type is not None:
|
|
433
|
+
self.layer_type = layer_type
|
|
434
|
+
else:
|
|
435
|
+
raise ValueError(
|
|
436
|
+
"layer type must be specified if the layer config is not in the context"
|
|
437
|
+
)
|
|
447
438
|
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
439
|
+
self.importer: Importer
|
|
440
|
+
if self.layer_type == LayerType.RASTER:
|
|
441
|
+
self.importer = RasterImporter(item_specs=raster_item_specs)
|
|
442
|
+
elif self.layer_type == LayerType.VECTOR:
|
|
443
|
+
self.importer = VectorImporter()
|
|
444
|
+
else:
|
|
445
|
+
raise ValueError(f"unknown layer type {self.layer_type}")
|
|
455
446
|
|
|
456
447
|
@functools.cache
|
|
457
448
|
def list_items(self) -> list[Item]:
|
|
@@ -459,7 +450,7 @@ class LocalFiles(DataSource):
|
|
|
459
450
|
cache_fname = self.src_dir / "summary.json"
|
|
460
451
|
if not cache_fname.exists():
|
|
461
452
|
logger.debug("cache at %s does not exist, listing items", cache_fname)
|
|
462
|
-
items = self.importer.list_items(self.
|
|
453
|
+
items = self.importer.list_items(self.src_dir)
|
|
463
454
|
serialized_items = [item.serialize() for item in items]
|
|
464
455
|
with cache_fname.open("w") as f:
|
|
465
456
|
json.dump(serialized_items, f)
|
|
@@ -501,12 +492,12 @@ class LocalFiles(DataSource):
|
|
|
501
492
|
|
|
502
493
|
def deserialize_item(self, serialized_item: Any) -> RasterItem | VectorItem:
|
|
503
494
|
"""Deserializes an item from JSON-decoded data."""
|
|
504
|
-
if self.
|
|
495
|
+
if self.layer_type == LayerType.RASTER:
|
|
505
496
|
return RasterItem.deserialize(serialized_item)
|
|
506
|
-
elif self.
|
|
497
|
+
elif self.layer_type == LayerType.VECTOR:
|
|
507
498
|
return VectorItem.deserialize(serialized_item)
|
|
508
499
|
else:
|
|
509
|
-
raise ValueError(f"Unknown layer type: {self.
|
|
500
|
+
raise ValueError(f"Unknown layer type: {self.layer_type}")
|
|
510
501
|
|
|
511
502
|
def ingest(
|
|
512
503
|
self,
|
|
@@ -522,4 +513,4 @@ class LocalFiles(DataSource):
|
|
|
522
513
|
geometries: a list of geometries needed for each item
|
|
523
514
|
"""
|
|
524
515
|
for item, cur_geometries in zip(items, geometries):
|
|
525
|
-
self.importer.ingest_item(
|
|
516
|
+
self.importer.ingest_item(tile_store, item, cur_geometries)
|
|
@@ -11,14 +11,17 @@ import osmium.osm.types
|
|
|
11
11
|
import shapely
|
|
12
12
|
from upath import UPath
|
|
13
13
|
|
|
14
|
-
from rslearn.config import QueryConfig, VectorLayerConfig
|
|
14
|
+
from rslearn.config import QueryConfig
|
|
15
15
|
from rslearn.const import WGS84_PROJECTION
|
|
16
|
-
from rslearn.data_sources import DataSource, Item
|
|
16
|
+
from rslearn.data_sources import DataSource, DataSourceContext, Item
|
|
17
17
|
from rslearn.data_sources.utils import match_candidate_items_to_window
|
|
18
|
+
from rslearn.log_utils import get_logger
|
|
18
19
|
from rslearn.tile_stores import TileStoreWithLayer
|
|
19
20
|
from rslearn.utils import Feature, GridIndex, STGeometry
|
|
20
21
|
from rslearn.utils.fsspec import get_upath_local, join_upath
|
|
21
22
|
|
|
23
|
+
logger = get_logger(__name__)
|
|
24
|
+
|
|
22
25
|
|
|
23
26
|
class FeatureType(Enum):
|
|
24
27
|
"""OpenStreetMap feature type."""
|
|
@@ -55,27 +58,6 @@ class Filter:
|
|
|
55
58
|
self.tag_properties = tag_properties
|
|
56
59
|
self.to_geometry = to_geometry
|
|
57
60
|
|
|
58
|
-
@staticmethod
|
|
59
|
-
def from_config(d: dict[str, Any]) -> "Filter":
|
|
60
|
-
"""Creates a Filter from a config dict.
|
|
61
|
-
|
|
62
|
-
Args:
|
|
63
|
-
d: the config dict
|
|
64
|
-
|
|
65
|
-
Returns:
|
|
66
|
-
the Filter object
|
|
67
|
-
"""
|
|
68
|
-
kwargs: dict[str, Any] = {}
|
|
69
|
-
if "feature_types" in d:
|
|
70
|
-
kwargs["feature_types"] = [FeatureType(el) for el in d["feature_types"]]
|
|
71
|
-
if "tag_conditions" in d:
|
|
72
|
-
kwargs["tag_conditions"] = d["tag_conditions"]
|
|
73
|
-
if "tag_properties" in d:
|
|
74
|
-
kwargs["tag_properties"] = d["tag_properties"]
|
|
75
|
-
if "to_geometry" in d:
|
|
76
|
-
kwargs["to_geometry"] = d["to_geometry"]
|
|
77
|
-
return Filter(**kwargs)
|
|
78
|
-
|
|
79
61
|
def match_tags(self, tags: dict[str, str]) -> bool:
|
|
80
62
|
"""Returns whether this filter matches based on the tags."""
|
|
81
63
|
if not self.tag_conditions:
|
|
@@ -387,10 +369,10 @@ class OpenStreetMap(DataSource[OsmItem]):
|
|
|
387
369
|
|
|
388
370
|
def __init__(
|
|
389
371
|
self,
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
bounds_fname: UPath,
|
|
372
|
+
pbf_fnames: list[str],
|
|
373
|
+
bounds_fname: str,
|
|
393
374
|
categories: dict[str, Filter],
|
|
375
|
+
context: DataSourceContext = DataSourceContext(),
|
|
394
376
|
):
|
|
395
377
|
"""Initialize a new OpenStreetMap instance.
|
|
396
378
|
|
|
@@ -402,14 +384,21 @@ class OpenStreetMap(DataSource[OsmItem]):
|
|
|
402
384
|
bounds_fname: filename where the bounds of the PBF are cached.
|
|
403
385
|
categories: dictionary of (category name, filter). Features that match the
|
|
404
386
|
filter will be emitted under the corresponding category.
|
|
387
|
+
context: the data source context.
|
|
405
388
|
"""
|
|
406
|
-
self.config = config
|
|
407
|
-
self.pbf_fnames = pbf_fnames
|
|
408
|
-
self.bounds_fname = bounds_fname
|
|
409
389
|
self.categories = categories
|
|
410
390
|
|
|
391
|
+
if context.ds_path is not None:
|
|
392
|
+
self.pbf_fnames = [
|
|
393
|
+
join_upath(context.ds_path, pbf_fname) for pbf_fname in pbf_fnames
|
|
394
|
+
]
|
|
395
|
+
self.bounds_fname = join_upath(context.ds_path, bounds_fname)
|
|
396
|
+
else:
|
|
397
|
+
self.pbf_fnames = [UPath(pbf_fname) for pbf_fname in pbf_fnames]
|
|
398
|
+
self.bounds_fname = UPath(bounds_fname)
|
|
399
|
+
|
|
411
400
|
if len(self.pbf_fnames) == 1 and not self.pbf_fnames[0].exists():
|
|
412
|
-
|
|
401
|
+
logger.info(
|
|
413
402
|
"Downloading planet.osm.pbf from "
|
|
414
403
|
+ f"{self.planet_pbf_url} to {self.pbf_fnames[0]}"
|
|
415
404
|
)
|
|
@@ -420,32 +409,13 @@ class OpenStreetMap(DataSource[OsmItem]):
|
|
|
420
409
|
# Detect bounds of each pbf file if needed.
|
|
421
410
|
self.pbf_bounds = self._get_pbf_bounds()
|
|
422
411
|
|
|
423
|
-
@staticmethod
|
|
424
|
-
def from_config(config: VectorLayerConfig, ds_path: UPath) -> "OpenStreetMap":
|
|
425
|
-
"""Creates a new OpenStreetMap instance from a configuration dictionary."""
|
|
426
|
-
if config.data_source is None:
|
|
427
|
-
raise ValueError("data_source is required")
|
|
428
|
-
d = config.data_source.config_dict
|
|
429
|
-
categories = {
|
|
430
|
-
category_name: Filter.from_config(filter_config_dict)
|
|
431
|
-
for category_name, filter_config_dict in d["categories"].items()
|
|
432
|
-
}
|
|
433
|
-
pbf_fnames = [join_upath(ds_path, pbf_fname) for pbf_fname in d["pbf_fnames"]]
|
|
434
|
-
bounds_fname = join_upath(ds_path, d["bounds_fname"])
|
|
435
|
-
return OpenStreetMap(
|
|
436
|
-
config=config,
|
|
437
|
-
pbf_fnames=pbf_fnames,
|
|
438
|
-
bounds_fname=bounds_fname,
|
|
439
|
-
categories=categories,
|
|
440
|
-
)
|
|
441
|
-
|
|
442
412
|
def _get_pbf_bounds(self) -> list[tuple[float, float, float, float]]:
|
|
443
413
|
# Determine WGS84 bounds of each PBF file by processing them through
|
|
444
414
|
# BoundsHandler.
|
|
445
415
|
if not self.bounds_fname.exists():
|
|
446
416
|
pbf_bounds = []
|
|
447
417
|
for pbf_fname in self.pbf_fnames:
|
|
448
|
-
|
|
418
|
+
logger.info(f"detecting bounds of {pbf_fname}")
|
|
449
419
|
handler = BoundsHandler()
|
|
450
420
|
with get_upath_local(pbf_fname) as local_fname:
|
|
451
421
|
handler.apply_file(local_fname)
|
|
@@ -512,7 +482,7 @@ class OpenStreetMap(DataSource[OsmItem]):
|
|
|
512
482
|
if tile_store.is_vector_ready(cur_item.name):
|
|
513
483
|
continue
|
|
514
484
|
|
|
515
|
-
|
|
485
|
+
logger.info(
|
|
516
486
|
f"ingesting osm item {cur_item.name} "
|
|
517
487
|
+ f"with {len(cur_geometries)} geometries"
|
|
518
488
|
)
|
rslearn/data_sources/planet.py
CHANGED
|
@@ -14,9 +14,9 @@ import shapely
|
|
|
14
14
|
from fsspec.implementations.local import LocalFileSystem
|
|
15
15
|
from upath import UPath
|
|
16
16
|
|
|
17
|
-
from rslearn.config import QueryConfig, RasterLayerConfig
|
|
17
|
+
from rslearn.config import QueryConfig
|
|
18
18
|
from rslearn.const import WGS84_PROJECTION
|
|
19
|
-
from rslearn.data_sources import DataSource, Item
|
|
19
|
+
from rslearn.data_sources import DataSource, DataSourceContext, Item
|
|
20
20
|
from rslearn.data_sources.utils import match_candidate_items_to_window
|
|
21
21
|
from rslearn.tile_stores import TileStoreWithLayer
|
|
22
22
|
from rslearn.utils import STGeometry
|
|
@@ -31,19 +31,18 @@ class Planet(DataSource):
|
|
|
31
31
|
|
|
32
32
|
def __init__(
|
|
33
33
|
self,
|
|
34
|
-
config: RasterLayerConfig,
|
|
35
34
|
item_type_id: str,
|
|
36
|
-
cache_dir:
|
|
35
|
+
cache_dir: str | None = None,
|
|
37
36
|
asset_type_id: str = "ortho_analytic_sr",
|
|
38
37
|
range_filters: dict[str, dict[str, Any]] = {},
|
|
39
38
|
use_permission_filter: bool = True,
|
|
40
39
|
sort_by: str | None = None,
|
|
41
40
|
bands: list[str] = ["b01", "b02", "b03", "b04"],
|
|
41
|
+
context: DataSourceContext = DataSourceContext(),
|
|
42
42
|
):
|
|
43
43
|
"""Initialize a new Planet instance.
|
|
44
44
|
|
|
45
45
|
Args:
|
|
46
|
-
config: the LayerConfig of the layer containing this data source
|
|
47
46
|
item_type_id: the item type ID, like "PSScene" or "SkySatCollect".
|
|
48
47
|
cache_dir: where to store downloaded assets, or None to just store it in
|
|
49
48
|
temporary directory before putting into tile store.
|
|
@@ -60,39 +59,22 @@ class Planet(DataSource):
|
|
|
60
59
|
"-clear_percent" or "cloud_cover" (if it starts with minus sign then we
|
|
61
60
|
sort descending.)
|
|
62
61
|
bands: what to call the bands in the asset.
|
|
62
|
+
context: the data source context.
|
|
63
63
|
"""
|
|
64
|
-
self.config = config
|
|
65
64
|
self.item_type_id = item_type_id
|
|
66
|
-
self.cache_dir = cache_dir
|
|
67
65
|
self.asset_type_id = asset_type_id
|
|
68
66
|
self.range_filters = range_filters
|
|
69
67
|
self.use_permission_filter = use_permission_filter
|
|
70
68
|
self.sort_by = sort_by
|
|
71
69
|
self.bands = bands
|
|
72
70
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
config=config,
|
|
81
|
-
item_type_id=d["item_type_id"],
|
|
82
|
-
)
|
|
83
|
-
optional_keys = [
|
|
84
|
-
"asset_type_id",
|
|
85
|
-
"range_filters",
|
|
86
|
-
"use_permission_filter",
|
|
87
|
-
"sort_by",
|
|
88
|
-
"bands",
|
|
89
|
-
]
|
|
90
|
-
for optional_key in optional_keys:
|
|
91
|
-
if optional_key in d:
|
|
92
|
-
kwargs[optional_key] = d[optional_key]
|
|
93
|
-
if "cache_dir" in d:
|
|
94
|
-
kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
|
|
95
|
-
return Planet(**kwargs)
|
|
71
|
+
if cache_dir is None:
|
|
72
|
+
self.cache_dir = None
|
|
73
|
+
else:
|
|
74
|
+
if context.ds_path is not None:
|
|
75
|
+
self.cache_dir = join_upath(context.ds_path, cache_dir)
|
|
76
|
+
else:
|
|
77
|
+
self.cache_dir = UPath(cache_dir)
|
|
96
78
|
|
|
97
79
|
async def _search_items(self, geometry: STGeometry) -> list[dict[str, Any]]:
|
|
98
80
|
wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
|