rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/data_sources/aws_open_data.py
@@ -2,10 +2,11 @@

 import io
 import json
+import os
 import tempfile
 import xml.etree.ElementTree as ET
 from collections.abc import Callable, Generator
-from datetime import …
+from datetime import UTC, datetime
 from enum import Enum
 from typing import Any, BinaryIO

@@ -14,7 +15,6 @@ import dateutil.parser
 import fiona
 import fiona.transform
 import numpy.typing as npt
-import pytimeparse
 import rasterio
 import shapely
 import tqdm
@@ -22,27 +22,21 @@ from rasterio.crs import CRS
 from upath import UPath

 import rslearn.data_sources.utils
-import rslearn.utils.mgrs
-from rslearn.config import LayerConfig, RasterLayerConfig
 from rslearn.const import SHAPEFILE_AUX_EXTENSIONS, WGS84_EPSG, WGS84_PROJECTION
-from rslearn.tile_stores import …
-from rslearn.utils import (
-    GridIndex,
-    Projection,
-    STGeometry,
-    daterange,
-)
+from rslearn.tile_stores import TileStoreWithLayer
+from rslearn.utils import GridIndex, Projection, STGeometry, daterange
 from rslearn.utils.fsspec import get_upath_local, join_upath, open_atomic
+from rslearn.utils.raster_format import get_raster_projection_and_bounds

-from .copernicus import get_harmonize_callback
+from .copernicus import get_harmonize_callback, get_sentinel2_tiles
 from .data_source import (
     DataSource,
+    DataSourceContext,
     Item,
     ItemLookupDataSource,
     QueryConfig,
     RetrieveItemDataSource,
 )
-from .raster_source import get_needed_projections, ingest_raster


 class NaipItem(Item):
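Note: the replaced imports preview the rest of this diff. The reprojection helpers from rslearn/data_sources/raster_source.py (deleted in this release, per the file list) give way to a TileStoreWithLayer handle, and rslearn/utils/mgrs.py (also deleted) gives way to get_sentinel2_tiles from the copernicus module. The tile-store surface the rewritten ingest methods below rely on looks roughly like this; it is a sketch inferred from the call sites in this diff, with guessed parameter names, and the real interface lives in rslearn/tile_stores/tile_store.py:

    from typing import Any, Protocol

    import numpy.typing as npt
    from upath import UPath


    class TileStoreLike(Protocol):
        """Sketch of the calls made on TileStoreWithLayer in this file."""

        def is_raster_ready(self, item_name: str, bands: list[str]) -> bool:
            """Whether this item/band set was already ingested (makes ingest idempotent)."""
            ...

        def write_raster_file(self, item_name: str, bands: list[str], fname: UPath) -> None:
            """Ingest a raster by handing over a file on disk, unmodified."""
            ...

        def write_raster(
            self, item_name: str, bands: list[str], projection: Any, bounds: Any, array: npt.NDArray
        ) -> None:
            """Ingest an in-memory array, e.g. after a callback rewrote pixel values."""
            ...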
@@ -66,7 +60,7 @@ class NaipItem(Item):
         return d

     @staticmethod
-    def deserialize(d: dict) -> …
+    def deserialize(d: dict) -> "NaipItem":
         """Deserializes an item from a JSON-decoded dictionary."""
         item = super(NaipItem, NaipItem).deserialize(d)
         return NaipItem(
@@ -89,16 +83,15 @@ class Naip(DataSource):

     def __init__(
         self,
-        …
-        index_cache_dir: UPath,
+        index_cache_dir: str,
         use_rtree_index: bool = False,
         states: list[str] | None = None,
         years: list[int] | None = None,
+        context: DataSourceContext = DataSourceContext(),
     ) -> None:
         """Initialize a new Naip instance.

         Args:
-            config: the LayerConfig of the layer containing this data source.
             index_cache_dir: directory to cache index shapefiles.
             use_rtree_index: whether to create an rtree index to enable faster lookups
                 (default false)
@@ -106,40 +99,30 @@
                 the search. If use_rtree_index is enabled, the rtree will only be
                 populated with data from these states.
             years: optional list of years to restrict the search
+            context: the data source context.
         """
-        …
-        …
+        # If context is provided, we join the directory with the dataset path,
+        # otherwise we treat it directly as UPath.
+        if context.ds_path is not None:
+            self.index_cache_dir = join_upath(context.ds_path, index_cache_dir)
+        else:
+            self.index_cache_dir = UPath(index_cache_dir)
+
         self.states = states
         self.years = years

-        self.…
+        self.index_cache_dir.mkdir(parents=True, exist_ok=True)

+        self.bucket = boto3.resource("s3").Bucket(self.bucket_name)
+        self.rtree_index: Any | None = None
         if use_rtree_index:
             from rslearn.utils.rtree_index import RtreeIndex, get_cached_rtree

-            def build_fn(index: RtreeIndex):
+            def build_fn(index: RtreeIndex) -> None:
                 for item in self._read_index_shapefiles(desc="Building rtree index"):
                     index.insert(item.geometry.shp.bounds, json.dumps(item.serialize()))

-            self.…
-            self.rtree_index = get_cached_rtree(
-                self.index_cache_dir, self.rtree_tmp_dir.name, build_fn
-            )
-        else:
-            self.rtree_index = None
-
-    @staticmethod
-    def from_config(config: LayerConfig, ds_path: UPath) -> "Naip":
-        """Creates a new Naip instance from a configuration dictionary."""
-        assert isinstance(config, RasterLayerConfig)
-        d = config.data_source.config_dict
-        kwargs = dict(
-            config=config,
-            index_cache_dir=join_upath(ds_path, d["index_cache_dir"]),
-        )
-        if "use_rtree_index" in d:
-            kwargs["use_rtree_index"] = d["use_rtree_index"]
-        return Naip(**kwargs)
+            self.rtree_index = get_cached_rtree(self.index_cache_dir, build_fn)

     def _download_manifest(self) -> UPath:
         """Download the manifest that enumerates files in the bucket.
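Note: with from_config removed, a Naip source is now constructed directly, and dataset-relative directories are resolved through the context instead of being pre-joined by a factory. A hypothetical example follows; the module path is taken from the file list, and whether ds_path is accepted by the DataSourceContext constructor is an assumption, since this diff only shows it being read:

    from upath import UPath

    from rslearn.data_sources.aws_open_data import Naip
    from rslearn.data_sources.data_source import DataSourceContext

    # 0.0.1 (removed above): Naip.from_config(layer_config, ds_path)
    # 0.0.21: plain keyword arguments; "cache/naip_index" resolves against ds_path.
    naip = Naip(
        index_cache_dir="cache/naip_index",
        use_rtree_index=True,
        context=DataSourceContext(ds_path=UPath("/data/my_dataset")),  # assumed kwarg
    )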
@@ -149,7 +132,7 @@ class Naip(DataSource):
         """
         manifest_path = self.index_cache_dir / self.manifest_fname
         if not manifest_path.exists():
-            with manifest_path…
+            with open_atomic(manifest_path, "wb") as dst:
                 self.bucket.download_fileobj(
                     self.manifest_fname,
                     dst,
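Note: the manifest download now goes through open_atomic rather than writing to manifest_path directly, so a crash mid-download cannot leave a partial manifest that later passes the exists() check. The general shape of such a helper, as a sketch for local filesystems (not rslearn's implementation, which also has to work across fsspec/UPath backends):

    import os
    import tempfile


    def atomic_write(path: str, data: bytes) -> None:
        # Write to a temporary sibling, then rename over the target. Readers see
        # either the old file, no file, or the complete new file - never a prefix.
        fd, tmp = tempfile.mkstemp(dir=os.path.dirname(path) or ".")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(data)
            os.replace(tmp, path)  # atomic within one filesystem on POSIX
        except BaseException:
            os.unlink(tmp)
            raise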
@@ -195,7 +178,9 @@
                 blob_path, dst, ExtraArgs={"RequestPayer": "requester"}
             )

-    def _read_index_shapefiles(…
+    def _read_index_shapefiles(
+        self, desc: str | None = None
+    ) -> Generator[NaipItem, None, None]:
         """Read the index shapefiles and yield NaipItems corresponding to each image."""
         self._download_index_shapefiles()

@@ -275,7 +260,7 @@ class Naip(DataSource):
             else:
                 src_img_date = fname_parts[5]
             time = datetime.strptime(src_img_date, "%Y%m%d").replace(
-                tzinfo=…
+                tzinfo=UTC
             )

             geometry = STGeometry(WGS84_PROJECTION, shp, (time, time))
@@ -288,7 +273,7 @@ class Naip(DataSource):

     def get_items(
         self, geometries: list[STGeometry], query_config: QueryConfig
-    ) -> list[list[list[…
+    ) -> list[list[list[NaipItem]]]:
         """Get a list of items in the data source intersecting the given geometries.

         Args:
@@ -302,7 +287,7 @@ class Naip(DataSource):
             geometry.to_projection(WGS84_PROJECTION) for geometry in geometries
         ]

-        items = [[] for _ in geometries]
+        items: list = [[] for _ in geometries]
         if self.rtree_index:
             for idx, geometry in enumerate(wgs84_geometries):
                 encoded_items = self.rtree_index.query(geometry.shp.bounds)
@@ -331,15 +316,15 @@
             groups.append(cur_groups)
         return groups

-    def deserialize_item(self, serialized_item: Any) -> …
+    def deserialize_item(self, serialized_item: Any) -> NaipItem:
         """Deserializes an item from JSON-decoded data."""
         assert isinstance(serialized_item, dict)
         return NaipItem.deserialize(serialized_item)

     def ingest(
         self,
-        tile_store: …
-        items: list[…
+        tile_store: TileStoreWithLayer,
+        items: list[NaipItem],
         geometries: list[list[STGeometry]],
     ) -> None:
         """Ingest items into the given tile store.
@@ -349,29 +334,17 @@ class Naip(DataSource):
             items: the items to ingest
             geometries: a list of geometries needed for each item
         """
-        for item …
+        for item in items:
             bands = ["R", "G", "B", "IR"]
-            …
-            needed_projections = get_needed_projections(
-                cur_tile_store, bands, self.config.band_sets, cur_geometries
-            )
-            if not needed_projections:
+            if tile_store.is_raster_ready(item.name, bands):
                 continue

-            …
-            for projection in needed_projections:
-                ingest_raster(
-                    tile_store=cur_tile_store,
-                    raster=raster,
-                    projection=projection,
-                    time_range=item.geometry.time_range,
-                    layer_config=self.config,
-                )
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                fname = os.path.join(tmp_dir, item.blob_path.split("/")[-1])
+                self.bucket.download_file(
+                    item.blob_path, fname, ExtraArgs={"RequestPayer": "requester"}
+                )
+                tile_store.write_raster_file(item.name, bands, UPath(fname))


 class Sentinel2Modality(Enum):
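Note: the rewritten Naip.ingest inverts responsibilities. In 0.0.1 the data source consulted the layer's band_sets, computed which projections were still needed, and reprojected during ingestion; in 0.0.21 it writes each source GeoTIFF once, presumably deferring reprojection to materialization (rslearn/dataset/materialize.py also changes substantially in this release, per the file list). The contract the new loop relies on is just download-unless-done, sketched here with a hypothetical download callable:

    from upath import UPath


    def ingest_one(tile_store, item, bands: list[str], download) -> None:
        # is_raster_ready makes re-runs cheap: completed items are skipped.
        if tile_store.is_raster_ready(item.name, bands):
            return
        local_path: UPath = download(item)  # fetch the source GeoTIFF (hypothetical)
        # No reprojection here anymore; the tile store stores the file as-is.
        tile_store.write_raster_file(item.name, bands, local_path)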
@@ -407,7 +380,7 @@ class Sentinel2Item(Item):
         return d

     @staticmethod
-    def deserialize(d: dict) -> …
+    def deserialize(d: dict) -> "Sentinel2Item":
         """Deserializes an item from a JSON-decoded dictionary."""
         if "name" not in d:
             d["name"] = d["blob_path"].split("/")[-1].split(".tif")[0]
@@ -420,7 +393,9 @@ class Sentinel2Item(Item):
         )


-class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
+class Sentinel2(
+    ItemLookupDataSource[Sentinel2Item], RetrieveItemDataSource[Sentinel2Item]
+):
     """A data source for Sentinel-2 L1C and L2A imagery on AWS.

     Specifically, uses the sentinel-s2-l1c and sentinel-s2-l2a S3 buckets maintained by
@@ -474,61 +449,39 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):

     def __init__(
         self,
-        config: LayerConfig,
         modality: Sentinel2Modality,
-        metadata_cache_dir: …
-        max_time_delta: timedelta = timedelta(days=30),
+        metadata_cache_dir: str,
         sort_by: str | None = None,
         harmonize: bool = False,
+        context: DataSourceContext = DataSourceContext(),
     ) -> None:
         """Initialize a new Sentinel2 instance.

         Args:
-            config: the LayerConfig of the layer containing this data source.
             modality: L1C or L2A.
             metadata_cache_dir: directory to cache product metadata files.
-            max_time_delta: maximum time before a query start time or after a
-                query end time to look for products. This is required due to the large
-                number of available products, and defaults to 30 days.
             sort_by: can be "cloud_cover", default arbitrary order; only has effect for
                 SpaceMode.WITHIN.
             harmonize: harmonize pixel values across different processing baselines,
                 see https://developers.google.com/earth-engine/datasets/catalog/COPERNICUS_S2_SR_HARMONIZED
+            context: the data source context.
         """  # noqa: E501
-        …
+        # If context is provided, we join the directory with the dataset path,
+        # otherwise we treat it directly as UPath.
+        if context.ds_path is not None:
+            self.metadata_cache_dir = join_upath(context.ds_path, metadata_cache_dir)
+        else:
+            self.metadata_cache_dir = UPath(metadata_cache_dir)
+
         self.modality = modality
-        self.metadata_cache_dir = metadata_cache_dir
-        self.max_time_delta = max_time_delta
         self.sort_by = sort_by
         self.harmonize = harmonize

         bucket_name = self.bucket_names[modality]
         self.bucket = boto3.resource("s3").Bucket(bucket_name)

-    @staticmethod
-    def from_config(config: LayerConfig, ds_path: UPath) -> "Sentinel2":
-        """Creates a new Sentinel2 instance from a configuration dictionary."""
-        assert isinstance(config, RasterLayerConfig)
-        d = config.data_source.config_dict
-        kwargs = dict(
-            config=config,
-            modality=Sentinel2Modality(d["modality"]),
-            metadata_cache_dir=join_upath(ds_path, d["metadata_cache_dir"]),
-        )
-
-        if "max_time_delta" in d:
-            kwargs["max_time_delta"] = timedelta(
-                seconds=pytimeparse.parse(d["max_time_delta"])
-            )
-        simple_optionals = ["sort_by", "harmonize"]
-        for k in simple_optionals:
-            if k in d:
-                kwargs[k] = d[k]
-
-        return Sentinel2(**kwargs)
-
     def _read_products(
-        self, needed_cell_months: set[tuple[str, int, int…
+        self, needed_cell_months: set[tuple[str, int, int]]
     ) -> Generator[Sentinel2Item, None, None]:
         """Read productInfo.json files and yield relevant Sentinel2Items.

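Note: Sentinel2 loses max_time_delta along with from_config; this also explains the pytimeparse import removed at the top of the file, since it existed only to parse duration strings out of the config dict, roughly:

    from datetime import timedelta

    import pytimeparse

    # How the removed from_config turned a config string into a timedelta
    # ("30d" is an example value):
    max_time_delta = timedelta(seconds=pytimeparse.parse("30d"))
    assert max_time_delta == timedelta(days=30)

The truncated removed lines in the get_items hunk below (wgs84_geometry.time_range[0]…, …[1]…) are consistent with the old code padding the query time range by max_time_delta before enumerating cell-months; the new code uses the range as given.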
@@ -603,7 +556,7 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):

     def get_items(
         self, geometries: list[STGeometry], query_config: QueryConfig
-    ) -> list[list[list[…
+    ) -> list[list[list[Sentinel2Item]]]:
         """Get a list of items in the data source intersecting the given geometries.

         Args:
@@ -626,14 +579,14 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
                 raise ValueError(
                     "Sentinel2 on AWS requires geometry time ranges to be set"
                 )
-            for cell_id in …
+            for cell_id in get_sentinel2_tiles(wgs84_geometry, self.metadata_cache_dir):
                 for ts in daterange(
-                    wgs84_geometry.time_range[0]…
-                    wgs84_geometry.time_range[1]…
+                    wgs84_geometry.time_range[0],
+                    wgs84_geometry.time_range[1],
                 ):
                     needed_cell_months.add((cell_id, ts.year, ts.month))

-        items_by_cell = {}
+        items_by_cell: dict[str, list[Sentinel2Item]] = {}
         for item in self._read_products(needed_cell_months):
             cell_id = "".join(item.blob_path.split("/")[1:4])
             if cell_id not in items_by_cell:
@@ -643,7 +596,7 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
         groups = []
         for geometry, wgs84_geometry in zip(geometries, wgs84_geometries):
             items = []
-            for cell_id in …
+            for cell_id in get_sentinel2_tiles(wgs84_geometry, self.metadata_cache_dir):
                 for item in items_by_cell.get(cell_id, []):
                     try:
                         item_geom = item.geometry.to_projection(geometry.projection)
@@ -666,7 +619,7 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):

         return groups

-    def get_item_by_name(self, name: str) -> …
+    def get_item_by_name(self, name: str) -> Sentinel2Item:
         """Gets an item by name."""
         # Product name is like:
         # S2B_MSIL1C_20240201T230819_N0510_R015_T51CWM_20240202T012755.
@@ -685,12 +638,14 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
             return item
         raise ValueError(f"item {name} not found")

-    def deserialize_item(self, serialized_item: Any) -> …
+    def deserialize_item(self, serialized_item: Any) -> Sentinel2Item:
         """Deserializes an item from JSON-decoded data."""
         assert isinstance(serialized_item, dict)
         return Sentinel2Item.deserialize(serialized_item)

-    def retrieve_item(…
+    def retrieve_item(
+        self, item: Sentinel2Item
+    ) -> Generator[tuple[str, BinaryIO], None, None]:
         """Retrieves the rasters corresponding to an item as file streams."""
         for fname, _ in self.band_fnames[self.modality]:
             buf = io.BytesIO()
@@ -701,7 +656,7 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
             yield (fname, buf)

     def _get_harmonize_callback(
-        self, item: …
+        self, item: Sentinel2Item
     ) -> Callable[[npt.NDArray], npt.NDArray] | None:
         """Gets the harmonization callback for the given item.

@@ -715,6 +670,8 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
             return None
         # Search metadata XML for the RADIO_ADD_OFFSET tag.
         # This contains the per-band offset, but we assume all bands have the same offset.
+        if item.geometry.time_range is None:
+            raise ValueError("Sentinel2 on AWS requires geometry time ranges to be set")
         ts = item.geometry.time_range[0]
         metadata_fname = (
             f"products/{ts.year}/{ts.month}/{ts.day}/{item.name}/metadata.xml"
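Note, for context on the harmonization above: starting with processing baseline 04.00, Sentinel-2 products encode a fixed radiometric offset in the RADIO_ADD_OFFSET metadata tag (typically -1000), so the callback's job is to shift newer scenes back onto the old digital-number scale. A sketch of what such a callback plausibly does; the real one is built by get_harmonize_callback in copernicus.py, which this diff does not show:

    from collections.abc import Callable

    import numpy as np
    import numpy.typing as npt


    def make_harmonize_callback(offset: int) -> Callable[[npt.NDArray], npt.NDArray]:
        def callback(array: npt.NDArray) -> npt.NDArray:
            # offset is e.g. -1000; adding it undoes the +1000 shift that newer
            # processing baselines apply, clamping at zero to avoid wraparound.
            shifted = array.astype(np.int32) + offset
            return np.clip(shifted, 0, None).astype(array.dtype)

        return callback

The ingest hunk below shows why a callback changes the write path: harmonized pixels must be rewritten in memory, so the source calls write_raster with the modified array instead of handing the downloaded file to write_raster_file.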
@@ -724,13 +681,15 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
             metadata_fname, buf, ExtraArgs={"RequestPayer": "requester"}
         )
         buf.seek(0)
-        tree…
+        tree: ET.ElementTree[ET.Element[str]] = ET.ElementTree(
+            ET.fromstring(buf.getvalue())
+        )
         return get_harmonize_callback(tree)

     def ingest(
         self,
-        tile_store: …
-        items: list[…
+        tile_store: TileStoreWithLayer,
+        items: list[Sentinel2Item],
         geometries: list[list[STGeometry]],
     ) -> None:
         """Ingest items into the given tile store.
@@ -740,42 +699,43 @@ class Sentinel2(ItemLookupDataSource, RetrieveItemDataSource):
             items: the items to ingest
             geometries: a list of geometries needed for each item
         """
-        for item …
-            harmonize_callback = self._get_harmonize_callback(item)
-
+        for item in items:
             for fname, band_names in self.band_fnames[self.modality]:
-                …
-                    tile_store, (item.name, "_".join(band_names))
-                )
-                needed_projections = get_needed_projections(
-                    cur_tile_store, band_names, self.config.band_sets, cur_geometries
-                )
-                if not needed_projections:
+                if tile_store.is_raster_ready(item.name, band_names):
                     continue

-                …
+                with tempfile.TemporaryDirectory() as tmp_dir:
+                    local_fname = os.path.join(tmp_dir, fname.split("/")[-1])
+
+                    try:
+                        self.bucket.download_file(
+                            item.blob_path + fname,
+                            local_fname,
+                            ExtraArgs={"RequestPayer": "requester"},
+                        )
+                    except Exception as e:
+                        # TODO: sometimes for some reason object doesn't exist
+                        # we should probably investigate further why it happens
+                        # and then should create the layer here and mark it completed
+                        print(
+                            f"warning: got error {e} downloading {item.blob_path + fname}"
+                        )
+                        continue
+
+                    harmonize_callback = self._get_harmonize_callback(item)
+
+                    if harmonize_callback is not None:
+                        # In this case we need to read the array, convert the pixel
+                        # values, and pass modified array directly to the TileStore.
+                        with rasterio.open(local_fname) as src:
+                            array = src.read()
+                            projection, bounds = get_raster_projection_and_bounds(src)
+                        array = harmonize_callback(array)
+                        tile_store.write_raster(
+                            item.name, band_names, projection, bounds, array
+                        )
+
+                    else:
+                        tile_store.write_raster_file(
+                            item.name, band_names, UPath(local_fname)
                         )
rslearn/data_sources/aws_sentinel1.py (new file)
@@ -0,0 +1,131 @@
+"""Data source for Sentinel-1 on AWS."""
+
+import os
+import tempfile
+from typing import Any
+
+import boto3
+from upath import UPath
+
+from rslearn.data_sources.copernicus import (
+    CopernicusItem,
+    Sentinel1OrbitDirection,
+    Sentinel1Polarisation,
+    Sentinel1ProductType,
+)
+from rslearn.data_sources.copernicus import Sentinel1 as CopernicusSentinel1
+from rslearn.log_utils import get_logger
+from rslearn.tile_stores import TileStore, TileStoreWithLayer
+from rslearn.utils.geometry import STGeometry
+
+from .data_source import DataSource, DataSourceContext, QueryConfig
+
+WRS2_GRID_SIZE = 1.0
+
+logger = get_logger(__name__)
+
+
+class Sentinel1(DataSource, TileStore):
+    """A data source for Sentinel-1 GRD imagery on AWS.
+
+    Specifically, uses the sentinel-s1-l1c S3 bucket maintained by Sinergise. See
+    https://aws.amazon.com/marketplace/pp/prodview-uxrsbvhd35ifw for details about the
+    bucket.
+
+    We use the Copernicus API for metadata search. So the bucket is only used for
+    downloading the images.
+
+    Currently, it only supports GRD IW DV scenes.
+    """
+
+    bucket_name = "sentinel-s1-l1c"
+    bands = ["vv", "vh"]
+
+    def __init__(
+        self,
+        orbit_direction: Sentinel1OrbitDirection | None = None,
+        context: DataSourceContext = DataSourceContext(),
+    ) -> None:
+        """Initialize a new Sentinel1 instance.
+
+        Args:
+            orbit_direction: optional orbit direction to filter by.
+            context: the data source context.
+        """
+        self.client = boto3.client("s3")
+        self.bucket = boto3.resource("s3").Bucket(self.bucket_name)
+        self.sentinel1 = CopernicusSentinel1(
+            product_type=Sentinel1ProductType.IW_GRDH,
+            polarisation=Sentinel1Polarisation.VV_VH,
+            orbit_direction=orbit_direction,
+        )
+
+    def get_items(
+        self, geometries: list[STGeometry], query_config: QueryConfig
+    ) -> list[list[list[CopernicusItem]]]:
+        """Get a list of items in the data source intersecting the given geometries.
+
+        Args:
+            geometries: the spatiotemporal geometries
+            query_config: the query configuration
+
+        Returns:
+            List of groups of items that should be retrieved for each geometry.
+        """
+        return self.sentinel1.get_items(geometries, query_config)
+
+    def get_item_by_name(self, name: str) -> CopernicusItem:
+        """Gets an item by name."""
+        return self.sentinel1.get_item_by_name(name)
+
+    def deserialize_item(self, serialized_item: Any) -> CopernicusItem:
+        """Deserializes an item from JSON-decoded data."""
+        assert isinstance(serialized_item, dict)
+        return CopernicusItem.deserialize(serialized_item)
+
+    def ingest(
+        self,
+        tile_store: TileStoreWithLayer,
+        items: list[CopernicusItem],
+        geometries: list[list[STGeometry]],
+    ) -> None:
+        """Ingest items into the given tile store.
+
+        Args:
+            tile_store: the tile store to ingest into
+            items: the items to ingest
+            geometries: a list of geometries needed for each item
+        """
+        for item in items:
+            for band in self.bands:
+                band_names = [band]
+                if tile_store.is_raster_ready(item.name, band_names):
+                    continue
+
+                # Item name is like "S1C_IW_GRDH_1SDV_20250528T172106_20250528T172131_002534_00545C_B433.SAFE".
+                item_name_prefix = item.name.split(".")[0]
+                time_str = item_name_prefix.split("_")[4]
+                if len(time_str) != 15:
+                    raise ValueError(
+                        f"expected 15-character time string but got {time_str}"
+                    )
+                # We convert to int here since path in bucket isn't padded with leading 0s.
+                year = int(time_str[0:4])
+                month = int(time_str[4:6])
+                day = int(time_str[6:8])
+                blob_path = f"GRD/{year}/{month}/{day}/IW/DV/{item_name_prefix}/measurement/iw-{band}.tiff"
+
+                with tempfile.TemporaryDirectory() as tmp_dir:
+                    fname = os.path.join(tmp_dir, f"{band}.tif")
+                    try:
+                        self.bucket.download_file(
+                            blob_path,
+                            fname,
+                            ExtraArgs={"RequestPayer": "requester"},
+                        )
+                    except:
+                        logger.error(
+                            f"encountered error while downloading s3://{self.bucket_name}/{blob_path}"
+                        )
+                        raise
+                    tile_store.write_raster_file(item.name, band_names, UPath(fname))