rslearn 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +30 -23
- rslearn/data_sources/__init__.py +2 -0
- rslearn/data_sources/aws_landsat.py +44 -161
- rslearn/data_sources/aws_open_data.py +2 -4
- rslearn/data_sources/aws_sentinel1.py +1 -3
- rslearn/data_sources/aws_sentinel2_element84.py +54 -165
- rslearn/data_sources/climate_data_store.py +1 -3
- rslearn/data_sources/copernicus.py +1 -2
- rslearn/data_sources/data_source.py +1 -1
- rslearn/data_sources/direct_materialize_data_source.py +336 -0
- rslearn/data_sources/earthdaily.py +52 -155
- rslearn/data_sources/earthdatahub.py +425 -0
- rslearn/data_sources/eurocrops.py +1 -2
- rslearn/data_sources/gcp_public_data.py +1 -2
- rslearn/data_sources/google_earth_engine.py +1 -2
- rslearn/data_sources/hf_srtm.py +595 -0
- rslearn/data_sources/local_files.py +3 -3
- rslearn/data_sources/openstreetmap.py +1 -1
- rslearn/data_sources/planet.py +1 -2
- rslearn/data_sources/planet_basemap.py +1 -2
- rslearn/data_sources/planetary_computer.py +183 -186
- rslearn/data_sources/soilgrids.py +3 -3
- rslearn/data_sources/stac.py +1 -2
- rslearn/data_sources/usda_cdl.py +1 -3
- rslearn/data_sources/usgs_landsat.py +7 -254
- rslearn/data_sources/utils.py +204 -64
- rslearn/data_sources/worldcereal.py +1 -1
- rslearn/data_sources/worldcover.py +1 -1
- rslearn/data_sources/worldpop.py +1 -1
- rslearn/data_sources/xyz_tiles.py +5 -9
- rslearn/dataset/materialize.py +5 -1
- rslearn/models/clay/clay.py +3 -3
- rslearn/models/concatenate_features.py +6 -1
- rslearn/models/detr/detr.py +4 -1
- rslearn/models/dinov3.py +0 -1
- rslearn/models/olmoearth_pretrain/model.py +3 -1
- rslearn/models/pooling_decoder.py +1 -1
- rslearn/models/prithvi.py +0 -1
- rslearn/models/simple_time_series.py +97 -35
- rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
- rslearn/train/data_module.py +32 -27
- rslearn/train/dataset.py +260 -117
- rslearn/train/dataset_index.py +156 -0
- rslearn/train/lightning_module.py +1 -1
- rslearn/train/model_context.py +19 -3
- rslearn/train/prediction_writer.py +69 -41
- rslearn/train/tasks/classification.py +1 -1
- rslearn/train/tasks/detection.py +5 -5
- rslearn/train/tasks/per_pixel_regression.py +13 -13
- rslearn/train/tasks/regression.py +1 -1
- rslearn/train/tasks/segmentation.py +26 -13
- rslearn/train/transforms/concatenate.py +17 -27
- rslearn/train/transforms/crop.py +8 -19
- rslearn/train/transforms/flip.py +4 -10
- rslearn/train/transforms/mask.py +9 -15
- rslearn/train/transforms/normalize.py +31 -82
- rslearn/train/transforms/pad.py +7 -13
- rslearn/train/transforms/resize.py +5 -22
- rslearn/train/transforms/select_bands.py +16 -36
- rslearn/train/transforms/sentinel1.py +4 -16
- rslearn/utils/__init__.py +2 -0
- rslearn/utils/geometry.py +21 -0
- rslearn/utils/m2m_api.py +251 -0
- rslearn/utils/retry_session.py +43 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/RECORD +71 -66
- rslearn/data_sources/earthdata_srtm.py +0 -282
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
|
@@ -4,12 +4,9 @@
|
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
import io
|
|
7
|
-
import json
|
|
8
7
|
import os
|
|
9
8
|
import shutil
|
|
10
9
|
import tempfile
|
|
11
|
-
import time
|
|
12
|
-
import uuid
|
|
13
10
|
from collections.abc import Generator
|
|
14
11
|
from datetime import UTC, datetime, timedelta
|
|
15
12
|
from typing import Any, BinaryIO
|
|
@@ -24,246 +21,7 @@ from rslearn.data_sources import DataSource, DataSourceContext, Item
|
|
|
24
21
|
from rslearn.data_sources.utils import match_candidate_items_to_window
|
|
25
22
|
from rslearn.tile_stores import TileStoreWithLayer
|
|
26
23
|
from rslearn.utils import STGeometry
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
class APIException(Exception):
|
|
30
|
-
"""Exception raised for M2M API errors."""
|
|
31
|
-
|
|
32
|
-
pass
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
class M2MAPIClient:
|
|
36
|
-
"""An API client for interacting with the USGS M2M API."""
|
|
37
|
-
|
|
38
|
-
api_url = "https://m2m.cr.usgs.gov/api/api/json/stable/"
|
|
39
|
-
pagination_size = 1000
|
|
40
|
-
|
|
41
|
-
def __init__(
|
|
42
|
-
self,
|
|
43
|
-
username: str,
|
|
44
|
-
password: str | None = None,
|
|
45
|
-
token: str | None = None,
|
|
46
|
-
timeout: timedelta = timedelta(seconds=120),
|
|
47
|
-
) -> None:
|
|
48
|
-
"""Initialize a new M2MAPIClient.
|
|
49
|
-
|
|
50
|
-
Args:
|
|
51
|
-
username: the EROS username
|
|
52
|
-
password: the EROS password
|
|
53
|
-
token: the application token. One of password or token must be specified.
|
|
54
|
-
timeout: timeout for requests.
|
|
55
|
-
"""
|
|
56
|
-
self.username = username
|
|
57
|
-
self.timeout = timeout
|
|
58
|
-
|
|
59
|
-
if password is not None and token is not None:
|
|
60
|
-
raise ValueError("only one of password or token can be specified")
|
|
61
|
-
|
|
62
|
-
if password is not None:
|
|
63
|
-
json_data = json.dumps({"username": self.username, "password": password})
|
|
64
|
-
response = requests.post(
|
|
65
|
-
self.api_url + "login",
|
|
66
|
-
data=json_data,
|
|
67
|
-
timeout=self.timeout.total_seconds(),
|
|
68
|
-
)
|
|
69
|
-
|
|
70
|
-
elif token is not None:
|
|
71
|
-
json_data = json.dumps({"username": username, "token": token})
|
|
72
|
-
response = requests.post(
|
|
73
|
-
self.api_url + "login-token",
|
|
74
|
-
data=json_data,
|
|
75
|
-
timeout=self.timeout.total_seconds(),
|
|
76
|
-
)
|
|
77
|
-
|
|
78
|
-
else:
|
|
79
|
-
raise ValueError("one of password or token must be specified")
|
|
80
|
-
|
|
81
|
-
response.raise_for_status()
|
|
82
|
-
self.auth_token = response.json()["data"]
|
|
83
|
-
|
|
84
|
-
def request(
|
|
85
|
-
self, endpoint: str, data: dict[str, Any] | None = None
|
|
86
|
-
) -> dict[str, Any] | None:
|
|
87
|
-
"""Make a request to the API.
|
|
88
|
-
|
|
89
|
-
Args:
|
|
90
|
-
endpoint: the endpoint to call
|
|
91
|
-
data: POST data to pass
|
|
92
|
-
|
|
93
|
-
Returns:
|
|
94
|
-
JSON response data if any
|
|
95
|
-
"""
|
|
96
|
-
response = requests.post(
|
|
97
|
-
self.api_url + endpoint,
|
|
98
|
-
headers={"X-Auth-Token": self.auth_token},
|
|
99
|
-
data=json.dumps(data),
|
|
100
|
-
timeout=self.timeout.total_seconds(),
|
|
101
|
-
)
|
|
102
|
-
response.raise_for_status()
|
|
103
|
-
if response.text:
|
|
104
|
-
response_dict = response.json()
|
|
105
|
-
|
|
106
|
-
if response_dict["errorMessage"]:
|
|
107
|
-
raise APIException(response_dict["errorMessage"])
|
|
108
|
-
return response_dict
|
|
109
|
-
return None
|
|
110
|
-
|
|
111
|
-
def close(self) -> None:
|
|
112
|
-
"""Logout from the API."""
|
|
113
|
-
self.request("logout")
|
|
114
|
-
|
|
115
|
-
def __enter__(self) -> "M2MAPIClient":
|
|
116
|
-
"""Enter function to provide with semantics."""
|
|
117
|
-
return self
|
|
118
|
-
|
|
119
|
-
def __exit__(self) -> None:
|
|
120
|
-
"""Exit function to provide with semantics.
|
|
121
|
-
|
|
122
|
-
Logs out the API.
|
|
123
|
-
"""
|
|
124
|
-
self.close()
|
|
125
|
-
|
|
126
|
-
def get_filters(self, dataset_name: str) -> list[dict[str, Any]]:
|
|
127
|
-
"""Returns filters available for the given dataset.
|
|
128
|
-
|
|
129
|
-
Args:
|
|
130
|
-
dataset_name: the dataset name e.g. landsat_ot_c2_l1
|
|
131
|
-
|
|
132
|
-
Returns:
|
|
133
|
-
list of filter objects
|
|
134
|
-
"""
|
|
135
|
-
response_dict = self.request("dataset-filters", {"datasetName": dataset_name})
|
|
136
|
-
if response_dict is None:
|
|
137
|
-
raise APIException("No response from API")
|
|
138
|
-
return response_dict["data"]
|
|
139
|
-
|
|
140
|
-
def scene_search(
|
|
141
|
-
self,
|
|
142
|
-
dataset_name: str,
|
|
143
|
-
acquisition_time_range: tuple[datetime, datetime] | None = None,
|
|
144
|
-
cloud_cover_range: tuple[int, int] | None = None,
|
|
145
|
-
bbox: tuple[float, float, float, float] | None = None,
|
|
146
|
-
metadata_filter: dict[str, Any] | None = None,
|
|
147
|
-
) -> list[dict[str, Any]]:
|
|
148
|
-
"""Search for scenes matching the arguments.
|
|
149
|
-
|
|
150
|
-
Args:
|
|
151
|
-
dataset_name: the dataset name e.g. landsat_ot_c2_l1
|
|
152
|
-
acquisition_time_range: optional filter on the acquisition time
|
|
153
|
-
cloud_cover_range: optional filter on the cloud cover
|
|
154
|
-
bbox: optional spatial filter
|
|
155
|
-
metadata_filter: optional metadata filter dict
|
|
156
|
-
"""
|
|
157
|
-
base_data: dict[str, Any] = {"datasetName": dataset_name, "sceneFilter": {}}
|
|
158
|
-
if acquisition_time_range:
|
|
159
|
-
base_data["sceneFilter"]["acquisitionFilter"] = {
|
|
160
|
-
"start": acquisition_time_range[0].isoformat(),
|
|
161
|
-
"end": acquisition_time_range[1].isoformat(),
|
|
162
|
-
}
|
|
163
|
-
if cloud_cover_range:
|
|
164
|
-
base_data["sceneFilter"]["cloudCoverFilter"] = {
|
|
165
|
-
"min": cloud_cover_range[0],
|
|
166
|
-
"max": cloud_cover_range[1],
|
|
167
|
-
"includeUnknown": False,
|
|
168
|
-
}
|
|
169
|
-
if bbox:
|
|
170
|
-
base_data["sceneFilter"]["spatialFilter"] = {
|
|
171
|
-
"filterType": "mbr",
|
|
172
|
-
"lowerLeft": {"longitude": bbox[0], "latitude": bbox[1]},
|
|
173
|
-
"upperRight": {"longitude": bbox[2], "latitude": bbox[3]},
|
|
174
|
-
}
|
|
175
|
-
if metadata_filter:
|
|
176
|
-
base_data["sceneFilter"]["metadataFilter"] = metadata_filter
|
|
177
|
-
|
|
178
|
-
starting_number = 1
|
|
179
|
-
results = []
|
|
180
|
-
while True:
|
|
181
|
-
cur_data = base_data.copy()
|
|
182
|
-
cur_data["startingNumber"] = starting_number
|
|
183
|
-
cur_data["maxResults"] = self.pagination_size
|
|
184
|
-
response_dict = self.request("scene-search", cur_data)
|
|
185
|
-
if response_dict is None:
|
|
186
|
-
raise APIException("No response from API")
|
|
187
|
-
data = response_dict["data"]
|
|
188
|
-
results.extend(data["results"])
|
|
189
|
-
if data["recordsReturned"] < self.pagination_size:
|
|
190
|
-
break
|
|
191
|
-
starting_number += self.pagination_size
|
|
192
|
-
|
|
193
|
-
return results
|
|
194
|
-
|
|
195
|
-
def get_scene_metadata(self, dataset_name: str, entity_id: str) -> dict[str, Any]:
|
|
196
|
-
"""Get detailed metadata for a scene.
|
|
197
|
-
|
|
198
|
-
Args:
|
|
199
|
-
dataset_name: the dataset name in which to search
|
|
200
|
-
entity_id: the entity ID of the scene
|
|
201
|
-
|
|
202
|
-
Returns:
|
|
203
|
-
full scene metadata
|
|
204
|
-
"""
|
|
205
|
-
response_dict = self.request(
|
|
206
|
-
"scene-metadata",
|
|
207
|
-
{
|
|
208
|
-
"datasetName": dataset_name,
|
|
209
|
-
"entityId": entity_id,
|
|
210
|
-
"metadataType": "full",
|
|
211
|
-
},
|
|
212
|
-
)
|
|
213
|
-
if response_dict is None:
|
|
214
|
-
raise APIException("No response from API")
|
|
215
|
-
return response_dict["data"]
|
|
216
|
-
|
|
217
|
-
def get_downloadable_products(
|
|
218
|
-
self, dataset_name: str, entity_id: str
|
|
219
|
-
) -> list[dict[str, Any]]:
|
|
220
|
-
"""Get the downloadable products for a given scene.
|
|
221
|
-
|
|
222
|
-
Args:
|
|
223
|
-
dataset_name: the dataset name
|
|
224
|
-
entity_id: the entity ID of the scene
|
|
225
|
-
|
|
226
|
-
Returns:
|
|
227
|
-
list of downloadable products
|
|
228
|
-
"""
|
|
229
|
-
data = {"datasetName": dataset_name, "entityIds": [entity_id]}
|
|
230
|
-
response_dict = self.request("download-options", data)
|
|
231
|
-
if response_dict is None:
|
|
232
|
-
raise APIException("No response from API")
|
|
233
|
-
return response_dict["data"]
|
|
234
|
-
|
|
235
|
-
def get_download_url(self, entity_id: str, product_id: str) -> str:
|
|
236
|
-
"""Get the download URL for a given product.
|
|
237
|
-
|
|
238
|
-
Args:
|
|
239
|
-
entity_id: the entity ID of the product
|
|
240
|
-
product_id: the product ID of the product
|
|
241
|
-
|
|
242
|
-
Returns:
|
|
243
|
-
the download URL
|
|
244
|
-
"""
|
|
245
|
-
label = str(uuid.uuid4())
|
|
246
|
-
data = {
|
|
247
|
-
"downloads": [
|
|
248
|
-
{"label": label, "entityId": entity_id, "productId": product_id}
|
|
249
|
-
]
|
|
250
|
-
}
|
|
251
|
-
response_dict = self.request("download-request", data)
|
|
252
|
-
if response_dict is None:
|
|
253
|
-
raise APIException("No response from API")
|
|
254
|
-
response = response_dict["data"]
|
|
255
|
-
while True:
|
|
256
|
-
response_dict = self.request("download-retrieve", {"label": label})
|
|
257
|
-
if response_dict is None:
|
|
258
|
-
raise APIException("No response from API")
|
|
259
|
-
response = response_dict["data"]
|
|
260
|
-
if len(response["available"]) > 0:
|
|
261
|
-
return response["available"][0]["url"]
|
|
262
|
-
if len(response["requested"]) == 0:
|
|
263
|
-
raise Exception("Did not get download URL")
|
|
264
|
-
if response["requested"][0].get("url"):
|
|
265
|
-
return response["requested"][0]["url"]
|
|
266
|
-
time.sleep(10)
|
|
24
|
+
from rslearn.utils.m2m_api import APIException, M2MAPIClient
|
|
267
25
|
|
|
268
26
|
|
|
269
27
|
class LandsatOliTirsItem(Item):
|
|
@@ -314,30 +72,26 @@ class LandsatOliTirs(DataSource):
|
|
|
314
72
|
|
|
315
73
|
def __init__(
|
|
316
74
|
self,
|
|
317
|
-
username: str,
|
|
318
|
-
sort_by: str | None = None,
|
|
319
|
-
password: str | None = None,
|
|
75
|
+
username: str | None = None,
|
|
320
76
|
token: str | None = None,
|
|
77
|
+
sort_by: str | None = None,
|
|
321
78
|
timeout: timedelta = timedelta(seconds=10),
|
|
322
79
|
context: DataSourceContext = DataSourceContext(),
|
|
323
80
|
):
|
|
324
81
|
"""Initialize a new LandsatOliTirs instance.
|
|
325
82
|
|
|
326
83
|
Args:
|
|
327
|
-
username: EROS username
|
|
84
|
+
username: EROS username (see M2MAPIClient).
|
|
85
|
+
token: EROS application token (see M2MAPIClient).
|
|
328
86
|
sort_by: can be "cloud_cover", default arbitrary order; only has effect for
|
|
329
87
|
SpaceMode.WITHIN.
|
|
330
|
-
password: EROS password (see M2MAPIClient).
|
|
331
|
-
token: EROS application token (see M2MAPIClient).
|
|
332
88
|
timeout: timeout for requests.
|
|
333
89
|
context: the data source context.
|
|
334
90
|
"""
|
|
335
91
|
self.sort_by = sort_by
|
|
336
92
|
self.timeout = timeout
|
|
337
93
|
|
|
338
|
-
self.client = M2MAPIClient(
|
|
339
|
-
username, password=password, token=token, timeout=timeout
|
|
340
|
-
)
|
|
94
|
+
self.client = M2MAPIClient(username=username, token=token, timeout=timeout)
|
|
341
95
|
|
|
342
96
|
def _scene_metadata_to_item(self, result: dict[str, Any]) -> LandsatOliTirsItem:
|
|
343
97
|
"""Convert scene metadata from the API to a LandsatOliTirsItem."""
|
|
@@ -429,9 +183,8 @@ class LandsatOliTirs(DataSource):
|
|
|
429
183
|
)
|
|
430
184
|
return self._scene_metadata_to_item(scene_metadata)
|
|
431
185
|
|
|
432
|
-
def deserialize_item(self, serialized_item:
|
|
186
|
+
def deserialize_item(self, serialized_item: dict) -> Item:
|
|
433
187
|
"""Deserializes an item from JSON-decoded data."""
|
|
434
|
-
assert isinstance(serialized_item, dict)
|
|
435
188
|
return LandsatOliTirsItem.deserialize(serialized_item)
|
|
436
189
|
|
|
437
190
|
def _get_download_urls(self, item: Item) -> dict[str, tuple[str, str]]:
|
rslearn/data_sources/utils.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
"""Utilities shared by data sources."""
|
|
2
2
|
|
|
3
|
+
import warnings
|
|
4
|
+
from collections.abc import Callable
|
|
3
5
|
from dataclasses import dataclass
|
|
4
|
-
from datetime import UTC, datetime
|
|
6
|
+
from datetime import UTC, datetime
|
|
5
7
|
from typing import TypeVar
|
|
6
8
|
|
|
7
9
|
import shapely
|
|
@@ -40,13 +42,13 @@ class PendingMosaic:
|
|
|
40
42
|
completed: bool = False
|
|
41
43
|
|
|
42
44
|
|
|
43
|
-
def
|
|
45
|
+
def _create_single_coverage_mosaics(
|
|
44
46
|
window_geometry: STGeometry,
|
|
45
47
|
items: list[ItemType],
|
|
46
48
|
item_shps: list[shapely.Geometry],
|
|
47
|
-
|
|
49
|
+
max_mosaics: int,
|
|
48
50
|
) -> list[list[ItemType]]:
|
|
49
|
-
"""
|
|
51
|
+
"""Create mosaics where each mosaic covers the window geometry once.
|
|
50
52
|
|
|
51
53
|
This attempts to piece together items into mosaics that fully cover the window
|
|
52
54
|
geometry. If there are items leftover that only partially cover the window
|
|
@@ -56,15 +58,16 @@ def mosaic_matching(
|
|
|
56
58
|
window_geometry: the geometry of the window.
|
|
57
59
|
items: list of items.
|
|
58
60
|
item_shps: the item shapes projected to the window's projection.
|
|
59
|
-
|
|
61
|
+
max_mosaics: the maximum number of mosaics to create.
|
|
60
62
|
|
|
61
63
|
Returns:
|
|
62
|
-
list of item groups, each one corresponding to a different
|
|
64
|
+
list of item groups, each one corresponding to a different single-coverage
|
|
65
|
+
mosaic.
|
|
63
66
|
"""
|
|
64
67
|
# To create mosaics, we iterate over the items in order, and add each item to
|
|
65
68
|
# the first mosaic that the new item adds coverage to.
|
|
66
69
|
|
|
67
|
-
#
|
|
70
|
+
# max_mosaics could be very high if the user just wants us to create as many
|
|
68
71
|
# mosaics as possible, so we initialize the list here as empty and just add
|
|
69
72
|
# more pending mosaics when it is necessary.
|
|
70
73
|
pending_mosaics: list[PendingMosaic] = []
|
|
@@ -108,7 +111,7 @@ def mosaic_matching(
|
|
|
108
111
|
|
|
109
112
|
# See if we can add a new mosaic based on this item. There must be room for
|
|
110
113
|
# more mosaics, but the item must also intersect the requested geometry.
|
|
111
|
-
if len(pending_mosaics) >=
|
|
114
|
+
if len(pending_mosaics) >= max_mosaics:
|
|
112
115
|
continue
|
|
113
116
|
intersect_area = item_shp.intersection(window_geometry.shp).area
|
|
114
117
|
if (
|
|
@@ -127,18 +130,148 @@ def mosaic_matching(
|
|
|
127
130
|
return [pending_mosaic.items for pending_mosaic in pending_mosaics]
|
|
128
131
|
|
|
129
132
|
|
|
130
|
-
def
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
133
|
+
def _consolidate_mosaics_by_overlaps(
|
|
134
|
+
mosaics: list[list[ItemType]],
|
|
135
|
+
overlaps: int,
|
|
136
|
+
max_groups: int,
|
|
137
|
+
) -> list[list[ItemType]]:
|
|
138
|
+
"""Consolidate single-coverage mosaics into groups based on desired overlaps.
|
|
139
|
+
|
|
140
|
+
Args:
|
|
141
|
+
mosaics: list of single-coverage mosaics (each mosaic is a list of items).
|
|
142
|
+
overlaps: the number of overlapping coverages wanted per group.
|
|
143
|
+
max_groups: the maximum number of groups to return.
|
|
144
|
+
|
|
145
|
+
Returns:
|
|
146
|
+
list of item groups, where each group contains items from multiple mosaics
|
|
147
|
+
to achieve the desired number of overlapping coverages.
|
|
148
|
+
"""
|
|
149
|
+
if overlaps <= 0:
|
|
150
|
+
overlaps = 1
|
|
151
|
+
|
|
152
|
+
groups: list[list[ItemType]] = []
|
|
153
|
+
for i in range(0, len(mosaics), overlaps):
|
|
154
|
+
if len(groups) >= max_groups:
|
|
155
|
+
break
|
|
156
|
+
# Combine overlaps consecutive mosaics into one group
|
|
157
|
+
combined_items: list[ItemType] = []
|
|
158
|
+
for mosaic in mosaics[i : i + overlaps]:
|
|
159
|
+
combined_items.extend(mosaic)
|
|
160
|
+
if combined_items:
|
|
161
|
+
groups.append(combined_items)
|
|
162
|
+
|
|
163
|
+
return groups
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def match_with_space_mode_contains(
|
|
167
|
+
geometry: STGeometry,
|
|
168
|
+
items: list[ItemType],
|
|
169
|
+
item_shps: list[shapely.Geometry],
|
|
170
|
+
query_config: QueryConfig,
|
|
171
|
+
) -> list[list[ItemType]]:
|
|
172
|
+
"""Match items that fully contain the window geometry.
|
|
173
|
+
|
|
174
|
+
Args:
|
|
175
|
+
geometry: the window's geometry.
|
|
176
|
+
items: list of items.
|
|
177
|
+
item_shps: the item shapes projected to the window's projection.
|
|
178
|
+
query_config: the query configuration.
|
|
179
|
+
|
|
180
|
+
Returns:
|
|
181
|
+
list of matched item groups, where each group contains a single item.
|
|
182
|
+
"""
|
|
183
|
+
groups: list[list[ItemType]] = []
|
|
184
|
+
for item, item_shp in zip(items, item_shps):
|
|
185
|
+
if not item_shp.contains(geometry.shp):
|
|
186
|
+
continue
|
|
187
|
+
groups.append([item])
|
|
188
|
+
if len(groups) >= query_config.max_matches:
|
|
189
|
+
break
|
|
190
|
+
return groups
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def match_with_space_mode_intersects(
|
|
194
|
+
geometry: STGeometry,
|
|
195
|
+
items: list[ItemType],
|
|
196
|
+
item_shps: list[shapely.Geometry],
|
|
197
|
+
query_config: QueryConfig,
|
|
198
|
+
) -> list[list[ItemType]]:
|
|
199
|
+
"""Match items that intersect any portion of the window geometry.
|
|
200
|
+
|
|
201
|
+
Args:
|
|
202
|
+
geometry: the window's geometry.
|
|
203
|
+
items: list of items.
|
|
204
|
+
item_shps: the item shapes projected to the window's projection.
|
|
205
|
+
query_config: the query configuration.
|
|
206
|
+
|
|
207
|
+
Returns:
|
|
208
|
+
list of matched item groups, where each group contains a single item.
|
|
209
|
+
"""
|
|
210
|
+
groups: list[list[ItemType]] = []
|
|
211
|
+
for item, item_shp in zip(items, item_shps):
|
|
212
|
+
if not shp_intersects(item_shp, geometry.shp):
|
|
213
|
+
continue
|
|
214
|
+
groups.append([item])
|
|
215
|
+
if len(groups) >= query_config.max_matches:
|
|
216
|
+
break
|
|
217
|
+
return groups
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def match_with_space_mode_mosaic(
|
|
221
|
+
geometry: STGeometry,
|
|
222
|
+
items: list[ItemType],
|
|
223
|
+
item_shps: list[shapely.Geometry],
|
|
224
|
+
query_config: QueryConfig,
|
|
225
|
+
) -> list[list[ItemType]]:
|
|
226
|
+
"""Match items into mosaic groups that cover the window geometry.
|
|
227
|
+
|
|
228
|
+
Creates groups of items that together cover the window geometry. The number of
|
|
229
|
+
overlapping coverages in each group is controlled by mosaic_compositing_overlaps.
|
|
230
|
+
|
|
231
|
+
Args:
|
|
232
|
+
geometry: the window's geometry.
|
|
233
|
+
items: list of items.
|
|
234
|
+
item_shps: the item shapes projected to the window's projection.
|
|
235
|
+
query_config: the query configuration.
|
|
236
|
+
|
|
237
|
+
Returns:
|
|
238
|
+
list of matched item groups, where each group forms a mosaic covering the
|
|
239
|
+
window.
|
|
240
|
+
"""
|
|
241
|
+
overlaps = query_config.mosaic_compositing_overlaps
|
|
242
|
+
|
|
243
|
+
# Calculate how many single-coverage mosaics we need to create.
|
|
244
|
+
# We need enough mosaics to consolidate into max_matches groups with the
|
|
245
|
+
# desired number of overlaps per group.
|
|
246
|
+
max_single_mosaics = query_config.max_matches * overlaps
|
|
247
|
+
|
|
248
|
+
# Create single-coverage mosaics
|
|
249
|
+
single_mosaics = _create_single_coverage_mosaics(
|
|
250
|
+
geometry, items, item_shps, max_single_mosaics
|
|
251
|
+
)
|
|
252
|
+
|
|
253
|
+
# Consolidate into groups based on overlaps
|
|
254
|
+
return _consolidate_mosaics_by_overlaps(
|
|
255
|
+
single_mosaics, overlaps, query_config.max_matches
|
|
256
|
+
)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def match_with_space_mode_per_period_mosaic(
|
|
260
|
+
geometry: STGeometry,
|
|
261
|
+
items: list[ItemType],
|
|
262
|
+
item_shps: list[shapely.Geometry],
|
|
263
|
+
query_config: QueryConfig,
|
|
135
264
|
) -> list[list[ItemType]]:
|
|
136
265
|
"""Match items to the geometry with one mosaic per period.
|
|
137
266
|
|
|
138
267
|
We divide the time range of the geometry into shorter periods. Within each period,
|
|
139
268
|
we use the items corresponding to that period to create a mosaic. The returned item
|
|
140
|
-
groups include one group per period,
|
|
141
|
-
|
|
269
|
+
groups include one group per period, up to the provided max_matches.
|
|
270
|
+
|
|
271
|
+
By default (reverse_time_order=True), groups are returned starting from the most
|
|
272
|
+
recent periods. When reverse_time_order=False, groups are returned in chronological
|
|
273
|
+
order (oldest first). reverse_time_order should always be set False, and
|
|
274
|
+
FutureWarning will be warned if it is not.
|
|
142
275
|
|
|
143
276
|
The periods are also bounded to the window's time range, and aligned with the end
|
|
144
277
|
of that time range, i.e. the most recent window is
|
|
@@ -159,42 +292,59 @@ def per_period_mosaic_matching(
|
|
|
159
292
|
max_matches*period_duration is not equivalent to a longer window duration.
|
|
160
293
|
|
|
161
294
|
Args:
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
295
|
+
geometry: the window's geometry.
|
|
296
|
+
items: list of items.
|
|
297
|
+
item_shps: the item shapes projected to the window's projection (unused here)
|
|
298
|
+
query_config: the query configuration.
|
|
166
299
|
|
|
167
300
|
Returns:
|
|
168
|
-
|
|
169
|
-
|
|
301
|
+
list of matched item groups, where each group contains items that yield a
|
|
302
|
+
per-period mosaic.
|
|
170
303
|
"""
|
|
171
|
-
if
|
|
304
|
+
if geometry.time_range is None:
|
|
172
305
|
raise ValueError(
|
|
173
306
|
"all windows must have time range for per period mosaic matching"
|
|
174
307
|
)
|
|
175
308
|
|
|
309
|
+
# Emit warning if per_period_mosaic_reverse_time_order is True (the default).
|
|
310
|
+
if query_config.per_period_mosaic_reverse_time_order:
|
|
311
|
+
warnings.warn(
|
|
312
|
+
"QueryConfig.per_period_mosaic_reverse_time_order defaults to True, which "
|
|
313
|
+
"returns item groups in reverse temporal order (most recent first) for "
|
|
314
|
+
"PER_PERIOD_MOSAIC mode. This default will change to False (chronological "
|
|
315
|
+
"order) after 2026-04-01. To silence this warning, explicitly set "
|
|
316
|
+
"per_period_mosaic_reverse_time_order=False.",
|
|
317
|
+
FutureWarning,
|
|
318
|
+
stacklevel=3,
|
|
319
|
+
)
|
|
320
|
+
|
|
321
|
+
period_duration = query_config.period_duration
|
|
322
|
+
|
|
176
323
|
# For each period, we create an STGeometry with modified time range matching that
|
|
177
324
|
# period, and use it with match_candidate_items_to_window to get a mosaic.
|
|
178
325
|
cur_groups: list[list[ItemType]] = []
|
|
179
|
-
period_start =
|
|
326
|
+
period_start = geometry.time_range[1] - period_duration
|
|
180
327
|
while (
|
|
181
|
-
period_start >=
|
|
328
|
+
period_start >= geometry.time_range[0]
|
|
329
|
+
and len(cur_groups) < query_config.max_matches
|
|
182
330
|
):
|
|
183
331
|
period_time_range = (
|
|
184
332
|
period_start,
|
|
185
333
|
period_start + period_duration,
|
|
186
334
|
)
|
|
187
335
|
period_start -= period_duration
|
|
188
|
-
period_geom = STGeometry(
|
|
189
|
-
window_geometry.projection, window_geometry.shp, period_time_range
|
|
190
|
-
)
|
|
336
|
+
period_geom = STGeometry(geometry.projection, geometry.shp, period_time_range)
|
|
191
337
|
|
|
192
338
|
# We modify the QueryConfig here since caller should be asking for
|
|
193
339
|
# multiple mosaics, but we just want one mosaic per period.
|
|
194
340
|
period_groups = match_candidate_items_to_window(
|
|
195
341
|
period_geom,
|
|
196
|
-
|
|
197
|
-
QueryConfig(
|
|
342
|
+
items,
|
|
343
|
+
QueryConfig(
|
|
344
|
+
space_mode=SpaceMode.MOSAIC,
|
|
345
|
+
max_matches=1,
|
|
346
|
+
mosaic_compositing_overlaps=query_config.mosaic_compositing_overlaps,
|
|
347
|
+
),
|
|
198
348
|
)
|
|
199
349
|
|
|
200
350
|
# There should be zero or one group depending on whether there were
|
|
@@ -204,9 +354,29 @@ def per_period_mosaic_matching(
|
|
|
204
354
|
continue
|
|
205
355
|
cur_groups.append(period_groups[0])
|
|
206
356
|
|
|
357
|
+
# Currently the item groups are in reverse chronologic order.
|
|
358
|
+
# Reverse it to correct chronological order if requested.
|
|
359
|
+
if not query_config.per_period_mosaic_reverse_time_order:
|
|
360
|
+
cur_groups.reverse()
|
|
361
|
+
|
|
207
362
|
return cur_groups
|
|
208
363
|
|
|
209
364
|
|
|
365
|
+
# Type alias for space mode handler functions
|
|
366
|
+
SpaceModeHandler = Callable[
|
|
367
|
+
[STGeometry, list[ItemType], list[shapely.Geometry], QueryConfig],
|
|
368
|
+
list[list[ItemType]],
|
|
369
|
+
]
|
|
370
|
+
|
|
371
|
+
# Dict mapping SpaceMode values to their handler functions
|
|
372
|
+
space_mode_handlers: dict[SpaceMode, SpaceModeHandler] = {
|
|
373
|
+
SpaceMode.CONTAINS: match_with_space_mode_contains,
|
|
374
|
+
SpaceMode.INTERSECTS: match_with_space_mode_intersects,
|
|
375
|
+
SpaceMode.MOSAIC: match_with_space_mode_mosaic,
|
|
376
|
+
SpaceMode.PER_PERIOD_MOSAIC: match_with_space_mode_per_period_mosaic,
|
|
377
|
+
}
|
|
378
|
+
|
|
379
|
+
|
|
210
380
|
def match_candidate_items_to_window(
|
|
211
381
|
geometry: STGeometry, items: list[ItemType], query_config: QueryConfig
|
|
212
382
|
) -> list[list[ItemType]]:
|
|
@@ -262,43 +432,13 @@ def match_candidate_items_to_window(
|
|
|
262
432
|
item_geom = item_geom.to_projection(geometry.projection)
|
|
263
433
|
item_shps.append(item_geom.shp)
|
|
264
434
|
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
if not item_shp.contains(geometry.shp):
|
|
269
|
-
continue
|
|
270
|
-
groups.append([item])
|
|
271
|
-
if len(groups) >= query_config.max_matches:
|
|
272
|
-
break
|
|
273
|
-
|
|
274
|
-
elif query_config.space_mode == SpaceMode.INTERSECTS:
|
|
275
|
-
groups = []
|
|
276
|
-
for item, item_shp in zip(items, item_shps):
|
|
277
|
-
if not shp_intersects(item_shp, geometry.shp):
|
|
278
|
-
continue
|
|
279
|
-
groups.append([item])
|
|
280
|
-
if len(groups) >= query_config.max_matches:
|
|
281
|
-
break
|
|
282
|
-
|
|
283
|
-
elif query_config.space_mode == SpaceMode.MOSAIC:
|
|
284
|
-
groups = mosaic_matching(geometry, items, item_shps, query_config.max_matches)
|
|
285
|
-
|
|
286
|
-
elif query_config.space_mode == SpaceMode.PER_PERIOD_MOSAIC:
|
|
287
|
-
groups = per_period_mosaic_matching(
|
|
288
|
-
geometry, items, query_config.period_duration, query_config.max_matches
|
|
289
|
-
)
|
|
290
|
-
|
|
291
|
-
elif query_config.space_mode == SpaceMode.COMPOSITE:
|
|
292
|
-
group = []
|
|
293
|
-
for item, item_shp in zip(items, item_shps):
|
|
294
|
-
if not shp_intersects(item_shp, geometry.shp):
|
|
295
|
-
continue
|
|
296
|
-
group.append(item)
|
|
297
|
-
groups = [group]
|
|
298
|
-
|
|
299
|
-
else:
|
|
435
|
+
# Dispatch to the appropriate space mode handler
|
|
436
|
+
handler = space_mode_handlers.get(query_config.space_mode)
|
|
437
|
+
if handler is None:
|
|
300
438
|
raise ValueError(f"invalid space mode {query_config.space_mode}")
|
|
301
439
|
|
|
440
|
+
groups = handler(geometry, items, item_shps, query_config)
|
|
441
|
+
|
|
302
442
|
# Enforce minimum matches if set.
|
|
303
443
|
if len(groups) < query_config.min_matches:
|
|
304
444
|
logger.warning(
|
|
@@ -291,7 +291,7 @@ class WorldCereal(LocalFiles):
|
|
|
291
291
|
raise ValueError(f"No AEZ files found for {self.band}")
|
|
292
292
|
|
|
293
293
|
super().__init__(
|
|
294
|
-
src_dir=tif_dir,
|
|
294
|
+
src_dir=tif_dir.absolute().as_uri(),
|
|
295
295
|
raster_item_specs=item_specs,
|
|
296
296
|
layer_type=LayerType.RASTER,
|
|
297
297
|
context=context,
|