rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/data_sources/google_earth_engine.py

@@ -6,10 +6,12 @@ import json
 import os
 import tempfile
 import time
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import Any

 import ee
+import numpy as np
+import numpy.typing as npt
 import rasterio
 import rasterio.merge
 import shapely
@@ -18,53 +20,96 @@ from google.cloud import storage
 from upath import UPath

 import rslearn.data_sources.utils
-
-from rslearn.config import DType, LayerConfig, RasterLayerConfig
+from rslearn.config import DType, LayerConfig
 from rslearn.const import WGS84_PROJECTION
-from rslearn.…
-from rslearn.…
+from rslearn.dataset.materialize import RasterMaterializer
+from rslearn.dataset.window import Window
+from rslearn.log_utils import get_logger
+from rslearn.tile_stores import TileStore, TileStoreWithLayer
+from rslearn.utils.array import copy_spatial_array
 from rslearn.utils.fsspec import join_upath
-from rslearn.utils.…
+from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
+from rslearn.utils.raster_format import (
+    Resampling,
+    get_raster_projection_and_bounds_from_transform,
+    get_transform_from_projection_and_bounds,
+)
+from rslearn.utils.rtree_index import RtreeIndex, get_cached_rtree

-from .data_source import DataSource, Item, QueryConfig
-from .raster_source import ArrayWithTransform, get_needed_projections, ingest_raster
+from .data_source import DataSource, DataSourceContext, Item, QueryConfig

+logger = get_logger(__name__)

-class GEE(DataSource):
+
+class NoValidPixelsException(Exception):
+    """Exception when GEE API reports that export failed due to no valid pixels."""
+
+    # Expected GEE error_message when the task fails.
+    GEE_MESSAGE = "No valid (un-masked) pixels in export region."
+
+
+class ExportException(Exception):
+    """GEE API export error."""
+
+    pass
+
+
+class GEE(DataSource, TileStore):
     """A data source for ingesting images from Google Earth Engine."""

     def __init__(
         self,
-        config: LayerConfig,
         collection_name: str,
         gcs_bucket_name: str,
-        index_cache_dir: UPath,
+        index_cache_dir: str,
         service_account_name: str,
         service_account_credentials: str,
+        bands: list[str] | None = None,
         filters: list[tuple[str, Any]] | None = None,
         dtype: DType | None = None,
+        context: DataSourceContext = DataSourceContext(),
     ) -> None:
         """Initialize a new GEE instance.

         Args:
-            …
-            collection_name: the Earth Engine collection to ingest images from
+            collection_name: the Earth Engine ImageCollection to ingest images from
             gcs_bucket_name: the Cloud Storage bucket to export GEE images to
             index_cache_dir: cache directory to store rtree index
             service_account_name: name of the service account to use for authentication
             service_account_credentials: service account credentials filename
+            bands: the list of bands to ingest, in case the layer config is not present
+                in the context.
             filters: optional list of tuples (property_name, property_value) to filter
                 images (using ee.Filter.eq)
             dtype: optional desired array data type. If the data obtained from GEE does
                 not match this type, then it is converted.
+            context: the data source context.
         """
-        self.config = config
         self.collection_name = collection_name
         self.gcs_bucket_name = gcs_bucket_name
-        self.index_cache_dir = index_cache_dir
         self.filters = filters
         self.dtype = dtype

+        # Get index cache dir depending on dataset path.
+        if context.ds_path is not None:
+            self.index_cache_dir = join_upath(context.ds_path, index_cache_dir)
+        else:
+            self.index_cache_dir = UPath(index_cache_dir)
+
+        # Get bands we need to export.
+        if context.layer_config is not None:
+            self.bands = [
+                band
+                for band_set in context.layer_config.band_sets
+                for band in band_set.bands
+            ]
+        elif bands is not None:
+            self.bands = bands
+        else:
+            raise ValueError(
+                "bands must be specified if layer_config is not present in the context"
+            )
+
         self.bucket = storage.Client().bucket(self.gcs_bucket_name)

         credentials = ee.ServiceAccountCredentials(
@@ -72,44 +117,27 @@ class GEE(DataSource):
         )
         ee.Initialize(credentials)

-        self.rtree_tmp_dir = tempfile.TemporaryDirectory()
-        self.rtree_index = get_cached_rtree(
-            self.index_cache_dir, self.rtree_tmp_dir.name, self._build_index
-        )
+        self.index_cache_dir.mkdir(parents=True, exist_ok=True)
+        self.rtree_index = get_cached_rtree(self.index_cache_dir, self._build_index)

-    @staticmethod
-    def from_config(config: LayerConfig, ds_path: UPath) -> "GEE":
-        """Creates a new GEE instance from a configuration dictionary."""
-        d = config.data_source.config_dict
-        kwargs = {
-            "config": config,
-            "collection_name": d["collection_name"],
-            "gcs_bucket_name": d["gcs_bucket_name"],
-            "service_account_name": d["service_account_name"],
-            "service_account_credentials": d["service_account_credentials"],
-            "filters": d.get("filters"),
-            "index_cache_dir": join_upath(ds_path, d["index_cache_dir"]),
-        }
-        if "dtype" in d:
-            kwargs["dtype"] = DType(d["dtype"])
-
-        return GEE(**kwargs)
-
-    def get_collection(self):
+    def get_collection(self) -> ee.ImageCollection:
         """Returns the Earth Engine image collection for this data source."""
         image_collection = ee.ImageCollection(self.collection_name)
+        if self.filters is None:
+            return image_collection
+
         for k, v in self.filters:
             cur_filter = ee.Filter.eq(k, v)
             image_collection = image_collection.filter(cur_filter)
         return image_collection

-    def _build_index(self, rtree_index):
+    def _build_index(self, rtree_index: RtreeIndex) -> None:
         csv_blob = self.bucket.blob(f"{self.collection_name}/index.csv")

         if not csv_blob.exists():
             # Export feature collection of image metadata to GCS.
-            def image_to_feature(image):
-                geometry = image.geometry().transform(proj="EPSG:4326")
+            def image_to_feature(image: ee.Image) -> ee.Feature:
+                geometry = image.geometry().transform(proj="EPSG:4326", maxError=0.001)
                 return ee.Feature(geometry, {"time": image.date().format()})

             fc = self.get_collection().map(image_to_feature)
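
With from_config removed, a GEE source is now constructed through __init__ directly, with DataSourceContext supplying the dataset path and layer config when available. A minimal construction sketch (the collection, bucket, and credential values here are placeholders, not part of the diff); with the default context there is no layer_config, so bands must be given explicitly, and index_cache_dir is treated as a plain path since context.ds_path is None:

    from rslearn.data_sources.google_earth_engine import GEE

    # Note: constructing GEE immediately creates a GCS client, initializes
    # Earth Engine with the service account, and builds/loads the rtree index.
    gee = GEE(
        collection_name="COPERNICUS/S2_SR_HARMONIZED",  # placeholder collection
        gcs_bucket_name="my-gee-exports",  # placeholder bucket
        index_cache_dir="cache/gee",
        service_account_name="svc@my-project.iam.gserviceaccount.com",
        service_account_credentials="/path/to/credentials.json",
        bands=["B2", "B3", "B4"],
    )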
@@ -121,17 +149,23 @@ class GEE(DataSource):
                 fileFormat="CSV",
             )
             task.start()
-…
-                "…
-…
+            logger.info(
+                "Started task to export GEE index for image collection %s",
+                self.collection_name,
             )
             while True:
                 time.sleep(10)
                 status_dict = task.status()
-…
+                logger.debug(
+                    "Waiting for export task to complete, current status is %s",
+                    status_dict,
+                )
                 if status_dict["state"] in ["UNSUBMITTED", "READY", "RUNNING"]:
                     continue
-…
+                elif status_dict["state"] != "COMPLETED":
+                    raise ValueError(
+                        f"got unexpected GEE task state {status_dict['state']}"
+                    )
                 break

         # Read the CSV and add rows into the rtree index.
@@ -141,15 +175,31 @@ class GEE(DataSource):
             shp = shapely.geometry.shape(json.loads(row[".geo"]))
             if "E" in row["time"]:
                 unix_time = float(row["time"]) / 1000
-                ts = datetime.fromtimestamp(unix_time, tz=timezone.utc)
+                ts = datetime.fromtimestamp(unix_time, tz=UTC)
             else:
-                ts = datetime.fromisoformat(row["time"]).replace(
-                    tzinfo=timezone.utc
-                )
+                ts = datetime.fromisoformat(row["time"]).replace(tzinfo=UTC)
             geometry = STGeometry(WGS84_PROJECTION, shp, (ts, ts))
             item = Item(row["system:index"], geometry)
             rtree_index.insert(shp.bounds, json.dumps(item.serialize()))

+    def get_item_by_name(self, name: str) -> Item:
+        """Gets an item by name.
+
+        Args:
+            name: the name of the item to get
+
+        Returns:
+            the item object
+        """
+        filtered = self.get_collection().filter(ee.Filter.eq("system:index", name))
+        image = filtered.first()
+        shp = shapely.geometry.shape(
+            image.geometry().transform(proj="EPSG:4326", maxError=0.001).getInfo()
+        )
+        ts = datetime.fromisoformat(image.date().format().getInfo()).replace(tzinfo=UTC)
+        geometry = STGeometry(WGS84_PROJECTION, shp, (ts, ts))
+        return Item(name, geometry)
+
     def get_items(
         self, geometries: list[STGeometry], query_config: QueryConfig
     ) -> list[list[list[Item]]]:
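
The index CSV that _build_index reads stores each image footprint as GeoJSON in the .geo column and its timestamp in the time column, which arrives either as epoch milliseconds in scientific notation (hence the check for "E" in the value) or as an ISO datetime string. A small sketch of both parsing branches, with made-up sample values:

    from datetime import UTC, datetime

    for raw in ["1.5119712E12", "2017-12-01T10:23:45"]:
        if "E" in raw:
            # Scientific-notation epoch milliseconds.
            ts = datetime.fromtimestamp(float(raw) / 1000, tz=UTC)
        else:
            # ISO datetime string, assumed to be UTC.
            ts = datetime.fromisoformat(raw).replace(tzinfo=UTC)
        print(ts.isoformat())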
@@ -176,7 +226,7 @@ class GEE(DataSource):
                     continue
                 cur_items.append(item)

-            cur_items.sort(key=lambda item: item.geometry.time_range[0])
+            cur_items.sort(key=lambda item: item.geometry.time_range[0])  # type: ignore

             cur_groups = rslearn.data_sources.utils.match_candidate_items_to_window(
                 geometry, cur_items, query_config
@@ -190,9 +240,143 @@ class GEE(DataSource):
         assert isinstance(serialized_item, dict)
         return Item.deserialize(serialized_item)

+    def item_to_image(self, item: Item) -> ee.image.Image:
+        """Get the Image corresponding to the Item.
+
+        This function is separated so it can be overriden if subclasses want to add
+        modifications to the image.
+        """
+        filtered = self.get_collection().filter(ee.Filter.eq("system:index", item.name))
+        image = filtered.first()
+        image = image.select(self.bands)
+        return image
+
+    def export_item(
+        self,
+        item: Item,
+        blob_prefix: str,
+        projection_and_bounds: tuple[Projection, PixelBounds] | None = None,
+    ) -> None:
+        """Export the item to the specified folder.
+
+        Args:
+            item: the item to export.
+            blob_prefix: the prefix (folder) to use.
+            projection_and_bounds: optionally use this projection and bounds instead of
+                the extent of the image.
+        """
+        image = self.item_to_image(item)
+        projection = image.select(self.bands[0]).projection().getInfo()
+        logger.info("Starting task to retrieve image %s", item.name)
+
+        extent_kwargs: dict[str, Any]
+        if projection_and_bounds is not None:
+            projection, bounds = projection_and_bounds
+            transform = get_transform_from_projection_and_bounds(projection, bounds)
+            width = bounds[2] - bounds[0]
+            height = bounds[3] - bounds[1]
+            extent_kwargs = dict(
+                crs=str(projection.crs),
+                crsTransform=[
+                    transform.a,
+                    transform.b,
+                    transform.c,
+                    transform.d,
+                    transform.e,
+                    transform.f,
+                ],
+                dimensions=f"{width}x{height}",
+            )
+        else:
+            # Use the native projection of the image.
+            # We pass scale instead of crsTransform since some images have positive y
+            # resolution which means they are upside down and rasterio cannot merge
+            # them.
+            extent_kwargs = dict(
+                crs=projection["crs"],
+                scale=projection["transform"][0],
+            )
+
+        task = ee.batch.Export.image.toCloudStorage(
+            image=image,
+            description=item.name,
+            bucket=self.gcs_bucket_name,
+            fileNamePrefix=blob_prefix,
+            maxPixels=10000000000,
+            fileFormat="GeoTIFF",
+            skipEmptyTiles=True,
+            **extent_kwargs,
+        )
+        task.start()
+        while True:
+            time.sleep(10)
+            status_dict = task.status()
+            if status_dict["state"] in ["UNSUBMITTED", "READY", "RUNNING"]:
+                continue
+            if status_dict["state"] == "COMPLETED":
+                break
+            if status_dict["state"] != "FAILED":
+                raise ValueError(
+                    f"got unexpected GEE task state {status_dict['state']}"
+                )
+            # The task failed. We see if it is an okay failure case or if we need to
+            # raise exception.
+            if status_dict["error_message"] == NoValidPixelsException.GEE_MESSAGE:
+                raise NoValidPixelsException()
+            raise ExportException(f"GEE task failed: {status_dict['error_message']}")
+
+    def _merge_rasters(
+        self,
+        blobs: list[storage.Blob],
+        crs_bounds: tuple[float, float, float, float] | None = None,
+        res: float | None = None,
+    ) -> tuple[npt.NDArray, Projection, PixelBounds]:
+        """Merge multiple rasters split up during export by GEE.
+
+        GEE can produce multiple rasters if it determines the file size exceeds its
+        internal limit. So in this case we stitch them back together.
+
+        Args:
+            blobs: the list of GCS blobs where the rasters were written.
+            crs_bounds: generate merged output under this bounds, in CRS coordinates
+                (not pixel units).
+            res: generate merged output under this resolution.
+
+        Returns:
+            a tuple (array, projection, bounds) where the projection and bounds
+            indicate the extent of the array.
+        """
+        with tempfile.TemporaryDirectory() as tmp_dir_name:
+            rasterio_datasets = []
+            for blob in blobs:
+                local_fname = os.path.join(tmp_dir_name, blob.name.split("/")[-1])
+                blob.download_to_filename(local_fname)
+                src = rasterio.open(local_fname)
+                rasterio_datasets.append(src)
+
+            merge_kwargs: dict[str, Any] = dict(
+                sources=rasterio_datasets,
+                bounds=crs_bounds,
+                res=res,
+            )
+            if self.dtype:
+                merge_kwargs["dtype"] = self.dtype.value
+            array, transform = rasterio.merge.merge(**merge_kwargs)
+            projection, bounds = get_raster_projection_and_bounds_from_transform(
+                rasterio_datasets[0].crs,
+                transform,
+                array.shape[2],
+                array.shape[1],
+            )
+
+            for ds in rasterio_datasets:
+                ds.close()
+
+            return array, projection, bounds
+
     def ingest(
         self,
-        tile_store: …,
+        tile_store: TileStoreWithLayer,
         items: list[Item],
         geometries: list[list[STGeometry]],
     ) -> None:
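
export_item passes Earth Engine the flattened affine transform [a, b, c, d, e, f] for the requested pixel bounds. get_transform_from_projection_and_bounds itself is not shown in this diff; the sketch below illustrates its assumed semantics (CRS coordinate = pixel coordinate × resolution, the same arithmetic read_raster uses for crs_bounds in the next hunk), with plain x_res/y_res arguments standing in for the Projection fields:

    from rasterio.transform import Affine

    def transform_sketch(
        x_res: float, y_res: float, bounds: tuple[int, int, int, int]
    ) -> Affine:
        # bounds is (minx, miny, maxx, maxy) in pixel units; y_res is negative
        # for north-up rasters, so bounds[1] * y_res is the top-left y offset.
        return Affine(x_res, 0.0, bounds[0] * x_res, 0.0, y_res, bounds[1] * y_res)

    t = transform_sketch(10.0, -10.0, (100, 200, 164, 264))
    print([t.a, t.b, t.c, t.d, t.e, t.f])  # [10.0, 0.0, 1000.0, 0.0, -10.0, -2000.0]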
@@ -203,96 +387,238 @@ class GEE(DataSource):
             items: the items to ingest
             geometries: a list of geometries needed for each item
         """
-…
-…
-        for band_set in self.config.band_sets:
-            for band in band_set.bands:
-                if band in bands:
-                    continue
-                bands.append(band)
-
-        for item, cur_geometries in zip(items, geometries):
-            cur_tile_store = PrefixedTileStore(tile_store, (item.name, "_".join(bands)))
-            needed_projections = get_needed_projections(
-                cur_tile_store, bands, self.config.band_sets, cur_geometries
-            )
-            if not needed_projections:
+        for item in items:
+            if tile_store.is_raster_ready(item.name, self.bands):
                 continue

-            filtered = self.get_collection().filter(
-                ee.Filter.eq("system:index", item.name)
-            )
-            image = filtered.first()
-            image = image.select(bands)
-
-            # Use the native projection of the image to obtain the raster.
-            projection = image.select(bands[0]).projection().getInfo()
-            print(f"starting task to retrieve image {item.name}")
-            blob_path = f"{self.collection_name}/{item.name}.{os.getpid()}/"
-            task = ee.batch.Export.image.toCloudStorage(
-                image=image,
-                description=item.name,
-                bucket=self.gcs_bucket_name,
-                fileNamePrefix=blob_path,
-                fileFormat="GeoTIFF",
-                crs=projection["crs"],
-                crsTransform=projection["transform"],
-                maxPixels=10000000000,
-            )
-            task.start()
-            while True:
-                time.sleep(10)
-                status_dict = task.status()
-                if status_dict["state"] in ["UNSUBMITTED", "READY", "RUNNING"]:
-                    continue
-                assert status_dict["state"] == "COMPLETED"
-                break
+            # Export the item to GCS.
+            blob_prefix = f"{self.collection_name}/{item.name}.{os.getpid()}/"
+            self.export_item(item, blob_prefix)

             # See what files the export produced.
             # If there are multiple, then we merge them into one file since that's the
             # simplest way to handle it.
-            blobs = self.bucket.list_blobs(prefix=blob_path)
-…
-…
-…
-…
-…
-…
-…
-…
-…
-…
-            with tempfile.TemporaryDirectory() as tmp_dir_name:
-                rasterio_datasets = []
-                for blob in blobs:
-                    local_fname = os.path.join(
-                        tmp_dir_name, blob.name.split("/")[-1]
-                    )
-                    blob.download_to_filename(local_fname)
-                    src = rasterio.open(local_fname)
-                    rasterio_datasets.append(src)
-
-                merge_kwargs = {"datasets": rasterio_datasets}
-                if self.dtype:
-                    merge_kwargs["dtype"] = self.dtype.value
-                array, transform = rasterio.merge.merge(**merge_kwargs)
-                crs = rasterio_datasets[0].crs
-
-                for ds in rasterio_datasets:
-                    ds.close()
-
-                raster = ArrayWithTransform(array, crs, transform)
-
-                for projection in needed_projections:
-                    ingest_raster(
-                        tile_store=cur_tile_store,
-                        raster=raster,
-                        projection=projection,
-                        time_range=item.geometry.time_range,
-                        layer_config=self.config,
-                    )
+            blobs = list(self.bucket.list_blobs(prefix=blob_prefix))
+
+            with tempfile.TemporaryDirectory() as tmp_dir_name:
+                if len(blobs) == 1:
+                    local_fname = os.path.join(
+                        tmp_dir_name, blobs[0].name.split("/")[-1]
+                    )
+                    blobs[0].download_to_filename(local_fname)
+                    tile_store.write_raster_file(
+                        item.name, self.bands, UPath(local_fname)
+                    )

-…
+                else:
+                    array, projection, bounds = self._merge_rasters(blobs)
+                    tile_store.write_raster(
+                        item.name, self.bands, projection, bounds, array
+                    )

             for blob in blobs:
                 blob.delete()
+
+    def is_raster_ready(
+        self, layer_name: str, item_name: str, bands: list[str]
+    ) -> bool:
+        """Checks if this raster has been written to the store.
+
+        Args:
+            layer_name: the layer name or alias.
+            item_name: the item.
+            bands: the list of bands identifying which specific raster to read.
+
+        Returns:
+            whether there is a raster in the store matching the source, item, and
+            bands.
+        """
+        # Always ready since we wrap accesses to Planetary Computer.
+        return True
+
+    def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
+        """Get the sets of bands that have been stored for the specified item.
+
+        Args:
+            layer_name: the layer name or alias.
+            item_name: the item.
+
+        Returns:
+            a list of lists of bands that are in the tile store (with one raster
+            stored corresponding to each inner list). If no rasters are ready for
+            this item, returns empty list.
+        """
+        return [self.bands]
+
+    def get_raster_bounds(
+        self, layer_name: str, item_name: str, bands: list[str], projection: Projection
+    ) -> PixelBounds:
+        """Get the bounds of the raster in the specified projection.
+
+        Args:
+            layer_name: the layer name or alias.
+            item_name: the item to check.
+            bands: the list of bands identifying which specific raster to read. These
+                bands must match the bands of a stored raster.
+            projection: the projection to get the raster's bounds in.
+
+        Returns:
+            the bounds of the raster in the projection.
+        """
+        item = self.get_item_by_name(item_name)
+        geom = item.geometry.to_projection(projection)
+        return (
+            int(geom.shp.bounds[0]),
+            int(geom.shp.bounds[1]),
+            int(geom.shp.bounds[2]),
+            int(geom.shp.bounds[3]),
+        )
+
+    def read_raster(
+        self,
+        layer_name: str,
+        item_name: str,
+        bands: list[str],
+        projection: Projection,
+        bounds: PixelBounds,
+        resampling: Resampling = Resampling.bilinear,
+    ) -> npt.NDArray[Any]:
+        """Read raster data from the store.
+
+        Args:
+            layer_name: the layer name or alias.
+            item_name: the item to read.
+            bands: the list of bands identifying which specific raster to read. These
+                bands must match the bands of a stored raster.
+            projection: the projection to read in.
+            bounds: the bounds to read.
+            resampling: the resampling method to use in case reprojection is needed.
+
+        Returns:
+            the raster data
+        """
+        # Extract the requested extent and export to GCS.
+        bounds_str = f"{bounds[0]}_{bounds[1]}_{bounds[2]}_{bounds[3]}"
+        item = self.get_item_by_name(item_name)
+        blob_prefix = f"{self.collection_name}/{item.name}.{bounds_str}.{os.getpid()}/"
+
+        try:
+            self.export_item(
+                item, blob_prefix, projection_and_bounds=(projection, bounds)
+            )
+        except NoValidPixelsException:
+            # No valid pixels means the result should be empty.
+            logger.info(
+                f"No valid pixels in item {item.name} with projection={projection}, bounds={bounds}, returning empty image"
+            )
+            return np.zeros(
+                (len(bands), bounds[3] - bounds[1], bounds[2] - bounds[0]),
+                dtype=np.float32,
+            )
+
+        wanted_transform = get_transform_from_projection_and_bounds(projection, bounds)
+        crs_bounds = (
+            bounds[0] * projection.x_resolution,
+            bounds[3] * projection.y_resolution,
+            bounds[2] * projection.x_resolution,
+            bounds[1] * projection.y_resolution,
+        )
+
+        blobs = list(self.bucket.list_blobs(prefix=blob_prefix))
+
+        if len(blobs) == 1:
+            # With a single output, we can simply read it with vrt.
+            buf = io.BytesIO()
+            blobs[0].download_to_file(buf)
+            buf.seek(0)
+            with rasterio.open(buf) as src:
+                with rasterio.vrt.WarpedVRT(
+                    src,
+                    crs=projection.crs,
+                    transform=wanted_transform,
+                    width=bounds[2] - bounds[0],
+                    height=bounds[3] - bounds[1],
+                    resampling=resampling,
+                ) as vrt:
+                    return vrt.read()
+
+        else:
+            # With multiple outputs, we need to merge them together.
+            # We can set the bounds in CRS coordinates when we do the merging.
+            if projection.x_resolution != -projection.y_resolution:
+                raise NotImplementedError(
+                    "Only projection with x_res=-y_res is supported for GEE direct materialization"
+                )
+            src_array, _, src_bounds = self._merge_rasters(
+                blobs, crs_bounds=crs_bounds, res=projection.x_resolution
+            )
+
+            # We copy the array if its bounds don't match exactly.
+            if src_bounds == bounds:
+                return src_array
+            dst_array = np.zeros(
+                (src_array.shape[0], bounds[3] - bounds[1], bounds[2] - bounds[0]),
+                dtype=src_array.dtype,
+            )
+            copy_spatial_array(src_array, dst_array, src_bounds[0:2], bounds[0:2])
+            return dst_array
+
+    def materialize(
+        self,
+        window: Window,
+        item_groups: list[list[Item]],
+        layer_name: str,
+        layer_cfg: LayerConfig,
+    ) -> None:
+        """Materialize data for the window.
+
+        Args:
+            window: the window to materialize
+            item_groups: the items from get_items
+            layer_name: the name of this layer
+            layer_cfg: the config of this layer
+        """
+        RasterMaterializer().materialize(
+            TileStoreWithLayer(self, layer_name),
+            window,
+            layer_name,
+            layer_cfg,
+            item_groups,
+        )
+
+
+class GoogleSatelliteEmbeddings(GEE):
+    """GEE data source for the Google Satellite Embeddings.
+
+    See here for details:
+    https://developers.google.com/earth-engine/datasets/catalog/GOOGLE_SATELLITE_EMBEDDING_V1_ANNUAL
+    """
+
+    COLLECTION_NAME = "GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL"
+
+    def __init__(
+        self,
+        gcs_bucket_name: str,
+        index_cache_dir: str,
+        service_account_name: str,
+        service_account_credentials: str,
+        context: DataSourceContext = DataSourceContext(),
+    ):
+        """Create a new GoogleSatelliteEmbeddings. See GEE for the arguments."""
+        super().__init__(
+            bands=[f"A{idx:02d}" for idx in range(64)],
+            collection_name=self.COLLECTION_NAME,
+            gcs_bucket_name=gcs_bucket_name,
+            index_cache_dir=index_cache_dir,
+            service_account_name=service_account_name,
+            service_account_credentials=service_account_credentials,
+            context=context,
+        )
+
+    # Override to add conversion to uint16.
+    def item_to_image(self, item: Item) -> ee.image.Image:
+        """Get the Image corresponding to the Item."""
+        filtered = self.get_collection().filter(ee.Filter.eq("system:index", item.name))
+        image = filtered.first()
+        image = image.select(self.bands)
+        return image.multiply(8192).add(8192).toUint16()