rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/data_sources/openstreetmap.py
CHANGED

@@ -1,4 +1,4 @@
-"""Data source for
+"""Data source for OpenStreetMap vector features."""
 
 import json
 import shutil
@@ -7,17 +7,21 @@ from enum import Enum
 from typing import Any
 
 import osmium
+import osmium.osm.types
 import shapely
 from upath import UPath
 
-from rslearn.config import
+from rslearn.config import QueryConfig
 from rslearn.const import WGS84_PROJECTION
-from rslearn.data_sources import DataSource, Item
+from rslearn.data_sources import DataSource, DataSourceContext, Item
 from rslearn.data_sources.utils import match_candidate_items_to_window
-from rslearn.
+from rslearn.log_utils import get_logger
+from rslearn.tile_stores import TileStoreWithLayer
 from rslearn.utils import Feature, GridIndex, STGeometry
 from rslearn.utils.fsspec import get_upath_local, join_upath
 
+logger = get_logger(__name__)
+
 
 class FeatureType(Enum):
     """OpenStreetMap feature type."""
@@ -36,7 +40,7 @@ class Filter:
         tag_conditions: dict[str, list[str]] | None = None,
         tag_properties: dict[str, str] | None = None,
         to_geometry: str | None = None,
-    ):
+    ) -> None:
         """Create a new Filter instance.
 
         Args:
@@ -54,27 +58,6 @@ class Filter:
         self.tag_properties = tag_properties
         self.to_geometry = to_geometry
 
-    @staticmethod
-    def from_config(d: dict[str, Any]) -> "Filter":
-        """Creates a Filter from a config dict.
-
-        Args:
-            d: the config dict
-
-        Returns:
-            the Filter object
-        """
-        kwargs = {}
-        if "feature_types" in d:
-            kwargs["feature_types"] = [FeatureType(el) for el in d["feature_types"]]
-        if "tag_conditions" in d:
-            kwargs["tag_conditions"] = d["tag_conditions"]
-        if "tag_properties" in d:
-            kwargs["tag_properties"] = d["tag_properties"]
-        if "to_geometry" in d:
-            kwargs["to_geometry"] = d["to_geometry"]
-        return Filter(**kwargs)
-
     def match_tags(self, tags: dict[str, str]) -> bool:
         """Returns whether this filter matches based on the tags."""
         if not self.tag_conditions:
@@ -104,12 +87,12 @@ class Filter:
 class BoundsHandler(osmium.SimpleHandler):
     """An osmium handler for computing the bounds of an input file."""
 
-    def __init__(self):
+    def __init__(self) -> None:
         """Initialize a new BoundsHandler."""
         osmium.SimpleHandler.__init__(self)
-        self.bounds = (180, 90, -180, -90)
+        self.bounds: tuple[float, float, float, float] = (180, 90, -180, -90)
 
-    def node(self, n):
+    def node(self, n: osmium.osm.types.Node) -> None:
         """Handle nodes and update the computed bounds."""
         lon = n.location.lon
         lat = n.location.lat
@@ -130,7 +113,7 @@ class OsmHandler(osmium.SimpleHandler):
         geometries: list[STGeometry],
         grid_size: float = 0.03,
         padding: float = 0.03,
-    ):
+    ) -> None:
         """Initialize a new OsmHandler.
 
         Args:
@@ -163,12 +146,12 @@ class OsmHandler(osmium.SimpleHandler):
             )
             self.grid_index.insert(bounds, 1)
 
-        self.cached_nodes = {}
-        self.cached_ways = {}
+        self.cached_nodes: dict = {}
+        self.cached_ways: dict = {}
 
-        self.features = []
+        self.features: list[Feature] = []
 
-    def node(self, n):
+    def node(self, n: osmium.osm.types.Node) -> None:
         """Handle nodes."""
         # Check if node is relevant to our geometries.
         lon = n.location.lon
@@ -193,7 +176,7 @@
             )
             self.features.append(feat)
 
-    def _get_way_coords(self, node_ids):
+    def _get_way_coords(self, node_ids: list[int]) -> list:
         coords = []
         for id in node_ids:
             if id not in self.cached_nodes:
@@ -201,7 +184,7 @@
             coords.append(self.cached_nodes[id])
         return coords
 
-    def way(self, w):
+    def way(self, w: osmium.osm.types.Way) -> None:
         """Handle ways."""
         # Collect nodes, skip if too few.
         node_ids = [member.ref for member in w.nodes]
@@ -235,7 +218,7 @@
             )
             self.features.append(feat)
 
-    def match_relation(self, r):
+    def match_relation(self, r: osmium.osm.types.Relation) -> None:
         """Handle relations."""
         # Collect ways and distinguish exterior vs holes, skip if none found.
         exterior_ways = []
@@ -267,7 +250,7 @@
         # Merge the ways in case some exterior/interior polygons are split into
         # multiple ways.
         # And convert them from node IDs to coordinates.
-        def get_polygons(ways):
+        def get_polygons(ways: list) -> list:
             polygons: list[list[int]] = []
             for way in ways:
                 # Attempt to match the way to an existing polygon.
@@ -366,13 +349,13 @@ class OsmItem(Item):
         return d
 
     @staticmethod
-    def deserialize(d: dict) ->
+    def deserialize(d: dict) -> "OsmItem":
         """Deserializes an item from a JSON-decoded dictionary."""
         item = super(OsmItem, OsmItem).deserialize(d)
         return OsmItem(name=item.name, geometry=item.geometry, path_uri=d["path_uri"])
 
 
-class OpenStreetMap(DataSource):
+class OpenStreetMap(DataSource[OsmItem]):
     """A data source for OpenStreetMap data from PBF file.
 
     An existing local PBF file can be used, or if the provided path doesn't exist, then
@@ -386,12 +369,12 @@ class OpenStreetMap(DataSource):
 
     def __init__(
         self,
-
-
-        bounds_fname: UPath,
+        pbf_fnames: list[str],
+        bounds_fname: str,
         categories: dict[str, Filter],
+        context: DataSourceContext = DataSourceContext(),
     ):
-        """Initialize a new
+        """Initialize a new OpenStreetMap instance.
 
         Args:
             config: the configuration of this layer.
@@ -401,14 +384,21 @@
             bounds_fname: filename where the bounds of the PBF are cached.
             categories: dictionary of (category name, filter). Features that match the
                 filter will be emitted under the corresponding category.
+            context: the data source context.
         """
-        self.config = config
-        self.pbf_fnames = pbf_fnames
-        self.bounds_fname = bounds_fname
         self.categories = categories
 
+        if context.ds_path is not None:
+            self.pbf_fnames = [
+                join_upath(context.ds_path, pbf_fname) for pbf_fname in pbf_fnames
+            ]
+            self.bounds_fname = join_upath(context.ds_path, bounds_fname)
+        else:
+            self.pbf_fnames = [UPath(pbf_fname) for pbf_fname in pbf_fnames]
+            self.bounds_fname = UPath(bounds_fname)
+
         if len(self.pbf_fnames) == 1 and not self.pbf_fnames[0].exists():
-
+            logger.info(
                 "Downloading planet.osm.pbf from "
                 + f"{self.planet_pbf_url} to {self.pbf_fnames[0]}"
             )
@@ -419,29 +409,13 @@
         # Detect bounds of each pbf file if needed.
         self.pbf_bounds = self._get_pbf_bounds()
 
-
-
-
-        assert isinstance(config, VectorLayerConfig)
-        d = config.data_source.config_dict
-        categories = {
-            category_name: Filter.from_config(filter_config_dict)
-            for category_name, filter_config_dict in d["categories"].items()
-        }
-        pbf_fnames = [join_upath(ds_path, pbf_fname) for pbf_fname in d["pbf_fnames"]]
-        bounds_fname = join_upath(ds_path, d["bounds_fname"])
-        return OpenStreetMap(
-            config=config,
-            pbf_fnames=pbf_fnames,
-            bounds_fname=bounds_fname,
-            categories=categories,
-        )
-
-    def _get_pbf_bounds(self):
+    def _get_pbf_bounds(self) -> list[tuple[float, float, float, float]]:
+        # Determine WGS84 bounds of each PBF file by processing them through
+        # BoundsHandler.
         if not self.bounds_fname.exists():
             pbf_bounds = []
             for pbf_fname in self.pbf_fnames:
-
+                logger.info(f"detecting bounds of {pbf_fname}")
                 handler = BoundsHandler()
                 with get_upath_local(pbf_fname) as local_fname:
                     handler.apply_file(local_fname)
@@ -458,7 +432,7 @@
 
     def get_items(
         self, geometries: list[STGeometry], query_config: QueryConfig
-    ) -> list[list[list[
+    ) -> list[list[list[OsmItem]]]:
         """Get a list of items in the data source intersecting the given geometries.
 
         Args:
@@ -487,14 +461,14 @@
             groups.append(cur_groups)
         return groups
 
-    def deserialize_item(self, serialized_item: Any) ->
+    def deserialize_item(self, serialized_item: Any) -> OsmItem:
         """Deserializes an item from JSON-decoded data."""
         return OsmItem.deserialize(serialized_item)
 
     def ingest(
         self,
-        tile_store:
-        items: list[
+        tile_store: TileStoreWithLayer,
+        items: list[OsmItem],
         geometries: list[list[STGeometry]],
     ) -> None:
         """Ingest items into the given tile store.
@@ -504,10 +478,11 @@
             items: the items to ingest
             geometries: a list of geometries needed for each item
         """
-        item_names = [item.name for item in items]
-        item_names.sort()
         for cur_item, cur_geometries in zip(items, geometries):
-
+            if tile_store.is_vector_ready(cur_item.name):
+                continue
+
+            logger.info(
                 f"ingesting osm item {cur_item.name} "
                 + f"with {len(cur_geometries)} geometries"
             )
@@ -515,17 +490,4 @@
             with get_upath_local(UPath(cur_item.path_uri)) as local_fname:
                 handler.apply_file(local_fname)
 
-
-            for geometry in cur_geometries:
-                projection, _ = self.config.get_final_projection_and_bounds(
-                    geometry.projection, None
-                )
-                projections.add(projection)
-
-            for projection in projections:
-                features = [feat.to_projection(projection) for feat in handler.features]
-                layer = tile_store.create_layer(
-                    (cur_item.name, str(projection)),
-                    LayerMetadata(projection, None, {}),
-                )
-                layer.write_vector(features)
+            tile_store.write_vector(cur_item.name, handler.features)
rslearn/data_sources/planet.py
CHANGED

@@ -6,24 +6,22 @@ import pathlib
 import shutil
 import tempfile
 from datetime import datetime
+from pathlib import Path
 from typing import Any
 
 import planet
-import rasterio
 import shapely
 from fsspec.implementations.local import LocalFileSystem
 from upath import UPath
 
-from rslearn.config import
+from rslearn.config import QueryConfig
 from rslearn.const import WGS84_PROJECTION
-from rslearn.data_sources import DataSource, Item
+from rslearn.data_sources import DataSource, DataSourceContext, Item
 from rslearn.data_sources.utils import match_candidate_items_to_window
-from rslearn.tile_stores import
+from rslearn.tile_stores import TileStoreWithLayer
 from rslearn.utils import STGeometry
 from rslearn.utils.fsspec import join_upath
 
-from .raster_source import get_needed_projections, ingest_raster
-
 
 class Planet(DataSource):
     """A data source for Planet Labs API.
@@ -33,19 +31,18 @@ class Planet(DataSource):
 
     def __init__(
         self,
-        config: LayerConfig,
         item_type_id: str,
-        cache_dir:
+        cache_dir: str | None = None,
         asset_type_id: str = "ortho_analytic_sr",
         range_filters: dict[str, dict[str, Any]] = {},
         use_permission_filter: bool = True,
         sort_by: str | None = None,
         bands: list[str] = ["b01", "b02", "b03", "b04"],
+        context: DataSourceContext = DataSourceContext(),
     ):
         """Initialize a new Planet instance.
 
         Args:
-            config: the LayerConfig of the layer containing this data source
             item_type_id: the item type ID, like "PSScene" or "SkySatCollect".
             cache_dir: where to store downloaded assets, or None to just store it in
                 temporary directory before putting into tile store.
@@ -62,38 +59,22 @@
                 "-clear_percent" or "cloud_cover" (if it starts with minus sign then we
                 sort descending.)
             bands: what to call the bands in the asset.
+            context: the data source context.
         """
-        self.config = config
         self.item_type_id = item_type_id
-        self.cache_dir = cache_dir
         self.asset_type_id = asset_type_id
         self.range_filters = range_filters
        self.use_permission_filter = use_permission_filter
         self.sort_by = sort_by
         self.bands = bands
 
-
-
-
-
-
-
-
-            item_type_id=d["item_type_id"],
-        )
-        optional_keys = [
-            "asset_type_id",
-            "range_filters",
-            "use_permission_filter",
-            "sort_by",
-            "bands",
-        ]
-        for optional_key in optional_keys:
-            if optional_key in d:
-                kwargs[optional_key] = d[optional_key]
-        if "cache_dir" in d:
-            kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
-        return Planet(**kwargs)
+        if cache_dir is None:
+            self.cache_dir = None
+        else:
+            if context.ds_path is not None:
+                self.cache_dir = join_upath(context.ds_path, cache_dir)
+            else:
+                self.cache_dir = UPath(cache_dir)
 
     async def _search_items(self, geometry: STGeometry) -> list[dict[str, Any]]:
         wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
@@ -101,11 +82,10 @@
 
         async with planet.Session() as session:
             client = session.client("data")
-
+            gte = geometry.time_range[0] if geometry.time_range is not None else None
+            lte = geometry.time_range[1] if geometry.time_range is not None else None
             filter_list = [
-                planet.data_filter.date_range_filter(
-                    "acquired", gte=geometry.time_range[0], lte=geometry.time_range[1]
-                ),
+                planet.data_filter.date_range_filter("acquired", gte=gte, lte=lte),
                 planet.data_filter.geometry_filter(geojson_data),
                 planet.data_filter.asset_filter([self.asset_type_id]),
             ]
@@ -242,7 +222,7 @@
 
     def ingest(
         self,
-        tile_store:
+        tile_store: TileStoreWithLayer,
         items: list[Item],
         geometries: list[list[STGeometry]],
     ) -> None:
@@ -253,26 +233,10 @@
             items: the items to ingest
             geometries: a list of geometries needed for each item
         """
-        for item
+        for item in items:
+            if tile_store.is_raster_ready(item.name, self.bands):
+                continue
+
             with tempfile.TemporaryDirectory() as tmp_dir:
-
-
-                    tile_store, (item.name, "_".join(band_names))
-                )
-                needed_projections = get_needed_projections(
-                    cur_tile_store, band_names, self.config.band_sets, cur_geometries
-                )
-                if not needed_projections:
-                    continue
-
-                asset_path = asyncio.run(self._download_asset(item, tmp_dir))
-                with asset_path.open("rb") as f:
-                    with rasterio.open(f) as raster:
-                        for projection in needed_projections:
-                            ingest_raster(
-                                tile_store=cur_tile_store,
-                                raster=raster,
-                                projection=projection,
-                                time_range=item.geometry.time_range,
-                                layer_config=self.config,
-                            )
+                asset_path = asyncio.run(self._download_asset(item, Path(tmp_dir)))
+                tile_store.write_raster_file(item.name, self.bands, asset_path)