rslearn 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +22 -13
- rslearn/data_sources/__init__.py +8 -0
- rslearn/data_sources/aws_landsat.py +27 -18
- rslearn/data_sources/aws_open_data.py +41 -42
- rslearn/data_sources/copernicus.py +148 -2
- rslearn/data_sources/data_source.py +17 -10
- rslearn/data_sources/gcp_public_data.py +177 -100
- rslearn/data_sources/geotiff.py +1 -0
- rslearn/data_sources/google_earth_engine.py +17 -15
- rslearn/data_sources/local_files.py +59 -32
- rslearn/data_sources/openstreetmap.py +27 -23
- rslearn/data_sources/planet.py +10 -9
- rslearn/data_sources/planet_basemap.py +303 -0
- rslearn/data_sources/raster_source.py +23 -13
- rslearn/data_sources/usgs_landsat.py +56 -27
- rslearn/data_sources/utils.py +13 -6
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/xyz_tiles.py +8 -9
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +16 -5
- rslearn/dataset/manage.py +9 -4
- rslearn/dataset/materialize.py +26 -5
- rslearn/dataset/window.py +5 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +123 -59
- rslearn/models/clip.py +62 -0
- rslearn/models/conv.py +56 -0
- rslearn/models/faster_rcnn.py +2 -19
- rslearn/models/fpn.py +1 -1
- rslearn/models/module_wrapper.py +43 -0
- rslearn/models/molmo.py +65 -0
- rslearn/models/multitask.py +1 -1
- rslearn/models/pooling_decoder.py +4 -2
- rslearn/models/satlaspretrain.py +4 -7
- rslearn/models/simple_time_series.py +61 -55
- rslearn/models/ssl4eo_s12.py +9 -9
- rslearn/models/swin.py +22 -21
- rslearn/models/unet.py +4 -2
- rslearn/models/upsample.py +35 -0
- rslearn/tile_stores/file.py +6 -3
- rslearn/tile_stores/tile_store.py +19 -7
- rslearn/train/callbacks/freeze_unfreeze.py +3 -3
- rslearn/train/data_module.py +5 -4
- rslearn/train/dataset.py +79 -36
- rslearn/train/lightning_module.py +15 -11
- rslearn/train/prediction_writer.py +22 -11
- rslearn/train/tasks/classification.py +9 -8
- rslearn/train/tasks/detection.py +94 -37
- rslearn/train/tasks/multi_task.py +1 -1
- rslearn/train/tasks/regression.py +8 -4
- rslearn/train/tasks/segmentation.py +23 -19
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +6 -2
- rslearn/train/transforms/crop.py +6 -2
- rslearn/train/transforms/flip.py +5 -1
- rslearn/train/transforms/normalize.py +9 -5
- rslearn/train/transforms/pad.py +1 -1
- rslearn/train/transforms/transform.py +3 -3
- rslearn/utils/__init__.py +4 -5
- rslearn/utils/array.py +2 -2
- rslearn/utils/feature.py +1 -1
- rslearn/utils/fsspec.py +70 -1
- rslearn/utils/geometry.py +155 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +81 -73
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/utils.py +11 -3
- rslearn/utils/vector_format.py +113 -17
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/METADATA +32 -27
- rslearn-0.0.2.dist-info/RECORD +94 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/WHEEL +1 -1
- rslearn/utils/mgrs.py +0 -24
- rslearn-0.0.1.dist-info/RECORD +0 -88
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Data source for raster or vector data in local files."""
|
|
2
2
|
|
|
3
|
-
from typing import Any
|
|
3
|
+
from typing import Any, Generic, TypeVar
|
|
4
4
|
|
|
5
5
|
import fiona
|
|
6
6
|
import rasterio
|
|
@@ -11,8 +11,8 @@ from rasterio.crs import CRS
|
|
|
11
11
|
from upath import UPath
|
|
12
12
|
|
|
13
13
|
import rslearn.data_sources.utils
|
|
14
|
-
from rslearn.config import LayerConfig, LayerType, VectorLayerConfig
|
|
15
|
-
from rslearn.const import SHAPEFILE_AUX_EXTENSIONS
|
|
14
|
+
from rslearn.config import LayerConfig, LayerType, RasterLayerConfig, VectorLayerConfig
|
|
15
|
+
from rslearn.const import SHAPEFILE_AUX_EXTENSIONS
|
|
16
16
|
from rslearn.tile_stores import LayerMetadata, PrefixedTileStore, TileStore
|
|
17
17
|
from rslearn.utils import Feature, Projection, STGeometry
|
|
18
18
|
from rslearn.utils.fsspec import get_upath_local, join_upath
|
|
@@ -22,11 +22,15 @@ from .raster_source import get_needed_projections, ingest_raster
|
|
|
22
22
|
|
|
23
23
|
Importers = ClassRegistry()
|
|
24
24
|
|
|
25
|
+
ItemType = TypeVar("ItemType", bound=Item)
|
|
26
|
+
LayerConfigType = TypeVar("LayerConfigType", bound=LayerConfig)
|
|
27
|
+
ImporterType = TypeVar("ImporterType", bound="Importer")
|
|
25
28
|
|
|
26
|
-
class Importer:
|
|
27
|
-
"""An abstract class for importing data from local files."""
|
|
28
29
|
|
|
29
|
-
|
|
30
|
+
class Importer(Generic[ItemType, LayerConfigType]):
|
|
31
|
+
"""An abstract base class for importing data from local files."""
|
|
32
|
+
|
|
33
|
+
def list_items(self, config: LayerConfigType, src_dir: UPath) -> list[ItemType]:
|
|
30
34
|
"""Extract a list of Items from the source directory.
|
|
31
35
|
|
|
32
36
|
Args:
|
|
@@ -37,11 +41,11 @@ class Importer:
|
|
|
37
41
|
|
|
38
42
|
def ingest_item(
|
|
39
43
|
self,
|
|
40
|
-
config:
|
|
44
|
+
config: LayerConfigType,
|
|
41
45
|
tile_store: TileStore,
|
|
42
|
-
item:
|
|
46
|
+
item: ItemType,
|
|
43
47
|
cur_geometries: list[STGeometry],
|
|
44
|
-
):
|
|
48
|
+
) -> None:
|
|
45
49
|
"""Ingest the specified local file item.
|
|
46
50
|
|
|
47
51
|
Args:
|
|
@@ -131,7 +135,7 @@ class RasterItem(Item):
|
|
|
131
135
|
return d
|
|
132
136
|
|
|
133
137
|
@staticmethod
|
|
134
|
-
def deserialize(d: dict) ->
|
|
138
|
+
def deserialize(d: dict) -> "RasterItem":
|
|
135
139
|
"""Deserializes an item from a JSON-decoded dictionary."""
|
|
136
140
|
item = super(RasterItem, RasterItem).deserialize(d)
|
|
137
141
|
spec = RasterItemSpec.deserialize(d["spec"])
|
|
@@ -159,7 +163,7 @@ class VectorItem(Item):
|
|
|
159
163
|
return d
|
|
160
164
|
|
|
161
165
|
@staticmethod
|
|
162
|
-
def deserialize(d: dict) ->
|
|
166
|
+
def deserialize(d: dict) -> "VectorItem":
|
|
163
167
|
"""Deserializes an item from a JSON-decoded dictionary."""
|
|
164
168
|
item = super(VectorItem, VectorItem).deserialize(d)
|
|
165
169
|
return VectorItem(
|
|
@@ -171,7 +175,7 @@ class VectorItem(Item):
|
|
|
171
175
|
class RasterImporter(Importer):
|
|
172
176
|
"""An Importer for raster data."""
|
|
173
177
|
|
|
174
|
-
def list_items(self, config: LayerConfig, src_dir: UPath) -> list[
|
|
178
|
+
def list_items(self, config: LayerConfig, src_dir: UPath) -> list[RasterItem]:
|
|
175
179
|
"""Extract a list of Items from the source directory.
|
|
176
180
|
|
|
177
181
|
Args:
|
|
@@ -179,6 +183,8 @@ class RasterImporter(Importer):
|
|
|
179
183
|
src_dir: the source directory.
|
|
180
184
|
"""
|
|
181
185
|
item_specs: list[RasterItemSpec] = []
|
|
186
|
+
if config.data_source is None:
|
|
187
|
+
raise ValueError("RasterImporter requires a data source config")
|
|
182
188
|
# See if user has provided the item specs directly.
|
|
183
189
|
if "item_specs" in config.data_source.config_dict:
|
|
184
190
|
for spec_dict in config.data_source.config_dict["item_specs"]:
|
|
@@ -192,7 +198,7 @@ class RasterImporter(Importer):
|
|
|
192
198
|
spec = RasterItemSpec(fnames=[path], bands=None)
|
|
193
199
|
item_specs.append(spec)
|
|
194
200
|
|
|
195
|
-
items
|
|
201
|
+
items = []
|
|
196
202
|
for spec in item_specs:
|
|
197
203
|
# Get geometry from the first raster file.
|
|
198
204
|
# We assume files are readable with rasterio.
|
|
@@ -222,20 +228,19 @@ class RasterImporter(Importer):
|
|
|
222
228
|
|
|
223
229
|
def ingest_item(
|
|
224
230
|
self,
|
|
225
|
-
config:
|
|
231
|
+
config: RasterLayerConfig,
|
|
226
232
|
tile_store: TileStore,
|
|
227
|
-
item:
|
|
233
|
+
item: RasterItem,
|
|
228
234
|
cur_geometries: list[STGeometry],
|
|
229
|
-
):
|
|
235
|
+
) -> None:
|
|
230
236
|
"""Ingest the specified local file item.
|
|
231
237
|
|
|
232
238
|
Args:
|
|
233
239
|
config: the configuration of the layer.
|
|
234
240
|
tile_store: the TileStore to ingest the data into.
|
|
235
|
-
item: the
|
|
241
|
+
item: the RasterItem to ingest
|
|
236
242
|
cur_geometries: the geometries where the item is needed.
|
|
237
243
|
"""
|
|
238
|
-
assert isinstance(item, RasterItem)
|
|
239
244
|
for file_idx, fname in enumerate(item.spec.fnames):
|
|
240
245
|
with fname.open("rb") as f:
|
|
241
246
|
with rasterio.open(f) as src:
|
|
@@ -264,7 +269,10 @@ class RasterImporter(Importer):
|
|
|
264
269
|
class VectorImporter(Importer):
|
|
265
270
|
"""An Importer for vector data."""
|
|
266
271
|
|
|
267
|
-
|
|
272
|
+
# We need some buffer around GeoJSON bounds in case it just contains one point.
|
|
273
|
+
item_buffer_epsilon = 1e-4
|
|
274
|
+
|
|
275
|
+
def list_items(self, config: LayerConfig, src_dir: UPath) -> list[VectorItem]:
|
|
268
276
|
"""Extract a list of Items from the source directory.
|
|
269
277
|
|
|
270
278
|
Args:
|
|
@@ -272,7 +280,7 @@ class VectorImporter(Importer):
|
|
|
272
280
|
src_dir: the source directory.
|
|
273
281
|
"""
|
|
274
282
|
file_paths = src_dir.glob("**/*.*")
|
|
275
|
-
items: list[
|
|
283
|
+
items: list[VectorItem] = []
|
|
276
284
|
|
|
277
285
|
for path in file_paths:
|
|
278
286
|
# Get the bounds of the features in the vector file, which we assume fiona can
|
|
@@ -299,8 +307,14 @@ class VectorImporter(Importer):
|
|
|
299
307
|
bounds[2] = max(bounds[2], cur_bounds[2])
|
|
300
308
|
bounds[3] = max(bounds[3], cur_bounds[3])
|
|
301
309
|
|
|
310
|
+
# Normal GeoJSON should have coordinates in CRS coordinates, i.e. it
|
|
311
|
+
# should be 1 projection unit/pixel.
|
|
302
312
|
projection = Projection(crs, 1, 1)
|
|
303
|
-
geometry = STGeometry(
|
|
313
|
+
geometry = STGeometry(
|
|
314
|
+
projection,
|
|
315
|
+
shapely.box(*bounds).buffer(self.item_buffer_epsilon),
|
|
316
|
+
None,
|
|
317
|
+
)
|
|
304
318
|
|
|
305
319
|
items.append(
|
|
306
320
|
VectorItem(path.name.split(".")[0], geometry, path.absolute().as_uri())
|
|
@@ -310,11 +324,11 @@ class VectorImporter(Importer):
|
|
|
310
324
|
|
|
311
325
|
def ingest_item(
|
|
312
326
|
self,
|
|
313
|
-
config:
|
|
327
|
+
config: VectorLayerConfig,
|
|
314
328
|
tile_store: TileStore,
|
|
315
|
-
item:
|
|
329
|
+
item: VectorItem,
|
|
316
330
|
cur_geometries: list[STGeometry],
|
|
317
|
-
):
|
|
331
|
+
) -> None:
|
|
318
332
|
"""Ingest the specified local file item.
|
|
319
333
|
|
|
320
334
|
Args:
|
|
@@ -323,7 +337,8 @@ class VectorImporter(Importer):
|
|
|
323
337
|
item: the Item to ingest
|
|
324
338
|
cur_geometries: the geometries where the item is needed.
|
|
325
339
|
"""
|
|
326
|
-
|
|
340
|
+
if not isinstance(config, VectorLayerConfig):
|
|
341
|
+
raise ValueError("VectorImporter requires a VectorLayerConfig")
|
|
327
342
|
|
|
328
343
|
needed_projections = set()
|
|
329
344
|
for geometry in cur_geometries:
|
|
@@ -347,14 +362,18 @@ class VectorImporter(Importer):
|
|
|
347
362
|
aux_files.append(path.parent / (prefix + ext))
|
|
348
363
|
|
|
349
364
|
# TODO: move converting fiona file to list[Feature] to utility function.
|
|
350
|
-
# TODO: don't assume WGS-84 projection here.
|
|
351
365
|
with get_upath_local(path, extra_paths=aux_files) as local_fname:
|
|
352
366
|
with fiona.open(local_fname) as src:
|
|
367
|
+
crs = CRS.from_wkt(src.crs.to_wkt())
|
|
368
|
+
# Normal GeoJSON should have coordinates in CRS coordinates, i.e. it
|
|
369
|
+
# should be 1 projection unit/pixel.
|
|
370
|
+
projection = Projection(crs, 1, 1)
|
|
371
|
+
|
|
353
372
|
features = []
|
|
354
373
|
for feat in src:
|
|
355
374
|
features.append(
|
|
356
375
|
Feature.from_geojson(
|
|
357
|
-
|
|
376
|
+
projection,
|
|
358
377
|
{
|
|
359
378
|
"type": "Feature",
|
|
360
379
|
"geometry": dict(feat.geometry),
|
|
@@ -376,7 +395,11 @@ class VectorImporter(Importer):
|
|
|
376
395
|
class LocalFiles(DataSource):
|
|
377
396
|
"""A data source for ingesting data from local files."""
|
|
378
397
|
|
|
379
|
-
def __init__(
|
|
398
|
+
def __init__(
|
|
399
|
+
self,
|
|
400
|
+
config: LayerConfig,
|
|
401
|
+
src_dir: UPath,
|
|
402
|
+
) -> None:
|
|
380
403
|
"""Initialize a new LocalFiles instance.
|
|
381
404
|
|
|
382
405
|
Args:
|
|
@@ -384,14 +407,17 @@ class LocalFiles(DataSource):
|
|
|
384
407
|
src_dir: source directory to ingest
|
|
385
408
|
"""
|
|
386
409
|
self.config = config
|
|
410
|
+
|
|
411
|
+
self.importer = Importers[config.layer_type.value]
|
|
387
412
|
self.src_dir = src_dir
|
|
388
413
|
|
|
389
|
-
self.
|
|
390
|
-
self.items = self.importer.list_items(config, src_dir)
|
|
414
|
+
self.items = self.importer.list_items(self.config, src_dir)
|
|
391
415
|
|
|
392
416
|
@staticmethod
|
|
393
417
|
def from_config(config: LayerConfig, ds_path: UPath) -> "LocalFiles":
|
|
394
418
|
"""Creates a new LocalFiles instance from a configuration dictionary."""
|
|
419
|
+
if config.data_source is None:
|
|
420
|
+
raise ValueError("LocalFiles data source requires a data source config")
|
|
395
421
|
d = config.data_source.config_dict
|
|
396
422
|
return LocalFiles(config=config, src_dir=join_upath(ds_path, d["src_dir"]))
|
|
397
423
|
|
|
@@ -421,13 +447,14 @@ class LocalFiles(DataSource):
|
|
|
421
447
|
groups.append(cur_groups)
|
|
422
448
|
return groups
|
|
423
449
|
|
|
424
|
-
def deserialize_item(self, serialized_item: Any) ->
|
|
450
|
+
def deserialize_item(self, serialized_item: Any) -> RasterItem | VectorItem:
|
|
425
451
|
"""Deserializes an item from JSON-decoded data."""
|
|
426
|
-
assert isinstance(serialized_item, dict)
|
|
427
452
|
if self.config.layer_type == LayerType.RASTER:
|
|
428
453
|
return RasterItem.deserialize(serialized_item)
|
|
429
454
|
elif self.config.layer_type == LayerType.VECTOR:
|
|
430
455
|
return VectorItem.deserialize(serialized_item)
|
|
456
|
+
else:
|
|
457
|
+
raise ValueError(f"Unknown layer type: {self.config.layer_type}")
|
|
431
458
|
|
|
432
459
|
def ingest(
|
|
433
460
|
self,
|
|
@@ -7,10 +7,11 @@ from enum import Enum
|
|
|
7
7
|
from typing import Any
|
|
8
8
|
|
|
9
9
|
import osmium
|
|
10
|
+
import osmium.osm.types
|
|
10
11
|
import shapely
|
|
11
12
|
from upath import UPath
|
|
12
13
|
|
|
13
|
-
from rslearn.config import
|
|
14
|
+
from rslearn.config import QueryConfig, VectorLayerConfig
|
|
14
15
|
from rslearn.const import WGS84_PROJECTION
|
|
15
16
|
from rslearn.data_sources import DataSource, Item
|
|
16
17
|
from rslearn.data_sources.utils import match_candidate_items_to_window
|
|
@@ -36,7 +37,7 @@ class Filter:
|
|
|
36
37
|
tag_conditions: dict[str, list[str]] | None = None,
|
|
37
38
|
tag_properties: dict[str, str] | None = None,
|
|
38
39
|
to_geometry: str | None = None,
|
|
39
|
-
):
|
|
40
|
+
) -> None:
|
|
40
41
|
"""Create a new Filter instance.
|
|
41
42
|
|
|
42
43
|
Args:
|
|
@@ -64,7 +65,7 @@ class Filter:
|
|
|
64
65
|
Returns:
|
|
65
66
|
the Filter object
|
|
66
67
|
"""
|
|
67
|
-
kwargs = {}
|
|
68
|
+
kwargs: dict[str, Any] = {}
|
|
68
69
|
if "feature_types" in d:
|
|
69
70
|
kwargs["feature_types"] = [FeatureType(el) for el in d["feature_types"]]
|
|
70
71
|
if "tag_conditions" in d:
|
|
@@ -104,12 +105,12 @@ class Filter:
|
|
|
104
105
|
class BoundsHandler(osmium.SimpleHandler):
|
|
105
106
|
"""An osmium handler for computing the bounds of an input file."""
|
|
106
107
|
|
|
107
|
-
def __init__(self):
|
|
108
|
+
def __init__(self) -> None:
|
|
108
109
|
"""Initialize a new BoundsHandler."""
|
|
109
110
|
osmium.SimpleHandler.__init__(self)
|
|
110
|
-
self.bounds = (180, 90, -180, -90)
|
|
111
|
+
self.bounds: tuple[float, float, float, float] = (180, 90, -180, -90)
|
|
111
112
|
|
|
112
|
-
def node(self, n):
|
|
113
|
+
def node(self, n: osmium.osm.types.Node) -> None:
|
|
113
114
|
"""Handle nodes and update the computed bounds."""
|
|
114
115
|
lon = n.location.lon
|
|
115
116
|
lat = n.location.lat
|
|
@@ -130,7 +131,7 @@ class OsmHandler(osmium.SimpleHandler):
|
|
|
130
131
|
geometries: list[STGeometry],
|
|
131
132
|
grid_size: float = 0.03,
|
|
132
133
|
padding: float = 0.03,
|
|
133
|
-
):
|
|
134
|
+
) -> None:
|
|
134
135
|
"""Initialize a new OsmHandler.
|
|
135
136
|
|
|
136
137
|
Args:
|
|
@@ -163,12 +164,12 @@ class OsmHandler(osmium.SimpleHandler):
|
|
|
163
164
|
)
|
|
164
165
|
self.grid_index.insert(bounds, 1)
|
|
165
166
|
|
|
166
|
-
self.cached_nodes = {}
|
|
167
|
-
self.cached_ways = {}
|
|
167
|
+
self.cached_nodes: dict = {}
|
|
168
|
+
self.cached_ways: dict = {}
|
|
168
169
|
|
|
169
|
-
self.features = []
|
|
170
|
+
self.features: list[Feature] = []
|
|
170
171
|
|
|
171
|
-
def node(self, n):
|
|
172
|
+
def node(self, n: osmium.osm.types.Node) -> None:
|
|
172
173
|
"""Handle nodes."""
|
|
173
174
|
# Check if node is relevant to our geometries.
|
|
174
175
|
lon = n.location.lon
|
|
@@ -193,7 +194,7 @@ class OsmHandler(osmium.SimpleHandler):
|
|
|
193
194
|
)
|
|
194
195
|
self.features.append(feat)
|
|
195
196
|
|
|
196
|
-
def _get_way_coords(self, node_ids):
|
|
197
|
+
def _get_way_coords(self, node_ids: list[int]) -> list:
|
|
197
198
|
coords = []
|
|
198
199
|
for id in node_ids:
|
|
199
200
|
if id not in self.cached_nodes:
|
|
@@ -201,7 +202,7 @@ class OsmHandler(osmium.SimpleHandler):
|
|
|
201
202
|
coords.append(self.cached_nodes[id])
|
|
202
203
|
return coords
|
|
203
204
|
|
|
204
|
-
def way(self, w):
|
|
205
|
+
def way(self, w: osmium.osm.types.Way) -> None:
|
|
205
206
|
"""Handle ways."""
|
|
206
207
|
# Collect nodes, skip if too few.
|
|
207
208
|
node_ids = [member.ref for member in w.nodes]
|
|
@@ -235,7 +236,7 @@ class OsmHandler(osmium.SimpleHandler):
|
|
|
235
236
|
)
|
|
236
237
|
self.features.append(feat)
|
|
237
238
|
|
|
238
|
-
def match_relation(self, r):
|
|
239
|
+
def match_relation(self, r: osmium.osm.types.Relation) -> None:
|
|
239
240
|
"""Handle relations."""
|
|
240
241
|
# Collect ways and distinguish exterior vs holes, skip if none found.
|
|
241
242
|
exterior_ways = []
|
|
@@ -267,7 +268,7 @@ class OsmHandler(osmium.SimpleHandler):
|
|
|
267
268
|
# Merge the ways in case some exterior/interior polygons are split into
|
|
268
269
|
# multiple ways.
|
|
269
270
|
# And convert them from node IDs to coordinates.
|
|
270
|
-
def get_polygons(ways):
|
|
271
|
+
def get_polygons(ways: list) -> list:
|
|
271
272
|
polygons: list[list[int]] = []
|
|
272
273
|
for way in ways:
|
|
273
274
|
# Attempt to match the way to an existing polygon.
|
|
@@ -366,13 +367,13 @@ class OsmItem(Item):
|
|
|
366
367
|
return d
|
|
367
368
|
|
|
368
369
|
@staticmethod
|
|
369
|
-
def deserialize(d: dict) ->
|
|
370
|
+
def deserialize(d: dict) -> "OsmItem":
|
|
370
371
|
"""Deserializes an item from a JSON-decoded dictionary."""
|
|
371
372
|
item = super(OsmItem, OsmItem).deserialize(d)
|
|
372
373
|
return OsmItem(name=item.name, geometry=item.geometry, path_uri=d["path_uri"])
|
|
373
374
|
|
|
374
375
|
|
|
375
|
-
class OpenStreetMap(DataSource):
|
|
376
|
+
class OpenStreetMap(DataSource[OsmItem]):
|
|
376
377
|
"""A data source for OpenStreetMap data from PBF file.
|
|
377
378
|
|
|
378
379
|
An existing local PBF file can be used, or if the provided path doesn't exist, then
|
|
@@ -420,9 +421,10 @@ class OpenStreetMap(DataSource):
|
|
|
420
421
|
self.pbf_bounds = self._get_pbf_bounds()
|
|
421
422
|
|
|
422
423
|
@staticmethod
|
|
423
|
-
def from_config(config:
|
|
424
|
+
def from_config(config: VectorLayerConfig, ds_path: UPath) -> "OpenStreetMap":
|
|
424
425
|
"""Creates a new OpenStreetMap instance from a configuration dictionary."""
|
|
425
|
-
|
|
426
|
+
if config.data_source is None:
|
|
427
|
+
raise ValueError("data_source is required")
|
|
426
428
|
d = config.data_source.config_dict
|
|
427
429
|
categories = {
|
|
428
430
|
category_name: Filter.from_config(filter_config_dict)
|
|
@@ -437,7 +439,9 @@ class OpenStreetMap(DataSource):
|
|
|
437
439
|
categories=categories,
|
|
438
440
|
)
|
|
439
441
|
|
|
440
|
-
def _get_pbf_bounds(self):
|
|
442
|
+
def _get_pbf_bounds(self) -> list[tuple[float, float, float, float]]:
|
|
443
|
+
# Determine WGS84 bounds of each PBF file by processing them through
|
|
444
|
+
# BoundsHandler.
|
|
441
445
|
if not self.bounds_fname.exists():
|
|
442
446
|
pbf_bounds = []
|
|
443
447
|
for pbf_fname in self.pbf_fnames:
|
|
@@ -458,7 +462,7 @@ class OpenStreetMap(DataSource):
|
|
|
458
462
|
|
|
459
463
|
def get_items(
|
|
460
464
|
self, geometries: list[STGeometry], query_config: QueryConfig
|
|
461
|
-
) -> list[list[list[
|
|
465
|
+
) -> list[list[list[OsmItem]]]:
|
|
462
466
|
"""Get a list of items in the data source intersecting the given geometries.
|
|
463
467
|
|
|
464
468
|
Args:
|
|
@@ -487,14 +491,14 @@ class OpenStreetMap(DataSource):
|
|
|
487
491
|
groups.append(cur_groups)
|
|
488
492
|
return groups
|
|
489
493
|
|
|
490
|
-
def deserialize_item(self, serialized_item: Any) ->
|
|
494
|
+
def deserialize_item(self, serialized_item: Any) -> OsmItem:
|
|
491
495
|
"""Deserializes an item from JSON-decoded data."""
|
|
492
496
|
return OsmItem.deserialize(serialized_item)
|
|
493
497
|
|
|
494
498
|
def ingest(
|
|
495
499
|
self,
|
|
496
500
|
tile_store: TileStore,
|
|
497
|
-
items: list[
|
|
501
|
+
items: list[OsmItem],
|
|
498
502
|
geometries: list[list[STGeometry]],
|
|
499
503
|
) -> None:
|
|
500
504
|
"""Ingest items into the given tile store.
|
rslearn/data_sources/planet.py
CHANGED
|
@@ -6,6 +6,7 @@ import pathlib
|
|
|
6
6
|
import shutil
|
|
7
7
|
import tempfile
|
|
8
8
|
from datetime import datetime
|
|
9
|
+
from pathlib import Path
|
|
9
10
|
from typing import Any
|
|
10
11
|
|
|
11
12
|
import planet
|
|
@@ -14,7 +15,7 @@ import shapely
|
|
|
14
15
|
from fsspec.implementations.local import LocalFileSystem
|
|
15
16
|
from upath import UPath
|
|
16
17
|
|
|
17
|
-
from rslearn.config import
|
|
18
|
+
from rslearn.config import QueryConfig, RasterLayerConfig
|
|
18
19
|
from rslearn.const import WGS84_PROJECTION
|
|
19
20
|
from rslearn.data_sources import DataSource, Item
|
|
20
21
|
from rslearn.data_sources.utils import match_candidate_items_to_window
|
|
@@ -33,7 +34,7 @@ class Planet(DataSource):
|
|
|
33
34
|
|
|
34
35
|
def __init__(
|
|
35
36
|
self,
|
|
36
|
-
config:
|
|
37
|
+
config: RasterLayerConfig,
|
|
37
38
|
item_type_id: str,
|
|
38
39
|
cache_dir: UPath | None = None,
|
|
39
40
|
asset_type_id: str = "ortho_analytic_sr",
|
|
@@ -73,9 +74,10 @@ class Planet(DataSource):
|
|
|
73
74
|
self.bands = bands
|
|
74
75
|
|
|
75
76
|
@staticmethod
|
|
76
|
-
def from_config(config:
|
|
77
|
+
def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Planet":
|
|
77
78
|
"""Creates a new Planet instance from a configuration dictionary."""
|
|
78
|
-
|
|
79
|
+
if config.data_source is None:
|
|
80
|
+
raise ValueError("data_source is required")
|
|
79
81
|
d = config.data_source.config_dict
|
|
80
82
|
kwargs = dict(
|
|
81
83
|
config=config,
|
|
@@ -101,11 +103,10 @@ class Planet(DataSource):
|
|
|
101
103
|
|
|
102
104
|
async with planet.Session() as session:
|
|
103
105
|
client = session.client("data")
|
|
104
|
-
|
|
106
|
+
gte = geometry.time_range[0] if geometry.time_range is not None else None
|
|
107
|
+
lte = geometry.time_range[1] if geometry.time_range is not None else None
|
|
105
108
|
filter_list = [
|
|
106
|
-
planet.data_filter.date_range_filter(
|
|
107
|
-
"acquired", gte=geometry.time_range[0], lte=geometry.time_range[1]
|
|
108
|
-
),
|
|
109
|
+
planet.data_filter.date_range_filter("acquired", gte=gte, lte=lte),
|
|
109
110
|
planet.data_filter.geometry_filter(geojson_data),
|
|
110
111
|
planet.data_filter.asset_filter([self.asset_type_id]),
|
|
111
112
|
]
|
|
@@ -265,7 +266,7 @@ class Planet(DataSource):
|
|
|
265
266
|
if not needed_projections:
|
|
266
267
|
continue
|
|
267
268
|
|
|
268
|
-
asset_path = asyncio.run(self._download_asset(item, tmp_dir))
|
|
269
|
+
asset_path = asyncio.run(self._download_asset(item, Path(tmp_dir)))
|
|
269
270
|
with asset_path.open("rb") as f:
|
|
270
271
|
with rasterio.open(f) as raster:
|
|
271
272
|
for projection in needed_projections:
|