rslearn 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. rslearn/config/dataset.py +30 -23
  2. rslearn/data_sources/__init__.py +2 -0
  3. rslearn/data_sources/aws_landsat.py +44 -161
  4. rslearn/data_sources/aws_open_data.py +2 -4
  5. rslearn/data_sources/aws_sentinel1.py +1 -3
  6. rslearn/data_sources/aws_sentinel2_element84.py +54 -165
  7. rslearn/data_sources/climate_data_store.py +1 -3
  8. rslearn/data_sources/copernicus.py +1 -2
  9. rslearn/data_sources/data_source.py +1 -1
  10. rslearn/data_sources/direct_materialize_data_source.py +336 -0
  11. rslearn/data_sources/earthdaily.py +52 -155
  12. rslearn/data_sources/earthdatahub.py +425 -0
  13. rslearn/data_sources/eurocrops.py +1 -2
  14. rslearn/data_sources/gcp_public_data.py +1 -2
  15. rslearn/data_sources/google_earth_engine.py +1 -2
  16. rslearn/data_sources/hf_srtm.py +595 -0
  17. rslearn/data_sources/local_files.py +3 -3
  18. rslearn/data_sources/openstreetmap.py +1 -1
  19. rslearn/data_sources/planet.py +1 -2
  20. rslearn/data_sources/planet_basemap.py +1 -2
  21. rslearn/data_sources/planetary_computer.py +183 -186
  22. rslearn/data_sources/soilgrids.py +3 -3
  23. rslearn/data_sources/stac.py +1 -2
  24. rslearn/data_sources/usda_cdl.py +1 -3
  25. rslearn/data_sources/usgs_landsat.py +7 -254
  26. rslearn/data_sources/utils.py +204 -64
  27. rslearn/data_sources/worldcereal.py +1 -1
  28. rslearn/data_sources/worldcover.py +1 -1
  29. rslearn/data_sources/worldpop.py +1 -1
  30. rslearn/data_sources/xyz_tiles.py +5 -9
  31. rslearn/dataset/materialize.py +5 -1
  32. rslearn/models/clay/clay.py +3 -3
  33. rslearn/models/concatenate_features.py +6 -1
  34. rslearn/models/detr/detr.py +4 -1
  35. rslearn/models/dinov3.py +0 -1
  36. rslearn/models/olmoearth_pretrain/model.py +3 -1
  37. rslearn/models/pooling_decoder.py +1 -1
  38. rslearn/models/prithvi.py +0 -1
  39. rslearn/models/simple_time_series.py +97 -35
  40. rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
  41. rslearn/train/data_module.py +32 -27
  42. rslearn/train/dataset.py +260 -117
  43. rslearn/train/dataset_index.py +156 -0
  44. rslearn/train/lightning_module.py +1 -1
  45. rslearn/train/model_context.py +19 -3
  46. rslearn/train/prediction_writer.py +69 -41
  47. rslearn/train/tasks/classification.py +1 -1
  48. rslearn/train/tasks/detection.py +5 -5
  49. rslearn/train/tasks/per_pixel_regression.py +13 -13
  50. rslearn/train/tasks/regression.py +1 -1
  51. rslearn/train/tasks/segmentation.py +26 -13
  52. rslearn/train/transforms/concatenate.py +17 -27
  53. rslearn/train/transforms/crop.py +8 -19
  54. rslearn/train/transforms/flip.py +4 -10
  55. rslearn/train/transforms/mask.py +9 -15
  56. rslearn/train/transforms/normalize.py +31 -82
  57. rslearn/train/transforms/pad.py +7 -13
  58. rslearn/train/transforms/resize.py +5 -22
  59. rslearn/train/transforms/select_bands.py +16 -36
  60. rslearn/train/transforms/sentinel1.py +4 -16
  61. rslearn/utils/__init__.py +2 -0
  62. rslearn/utils/geometry.py +21 -0
  63. rslearn/utils/m2m_api.py +251 -0
  64. rslearn/utils/retry_session.py +43 -0
  65. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
  66. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/RECORD +71 -66
  67. rslearn/data_sources/earthdata_srtm.py +0 -282
  68. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
  69. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
  70. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
  71. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
  72. {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
@@ -4,12 +4,9 @@
4
4
  """
5
5
 
6
6
  import io
7
- import json
8
7
  import os
9
8
  import shutil
10
9
  import tempfile
11
- import time
12
- import uuid
13
10
  from collections.abc import Generator
14
11
  from datetime import UTC, datetime, timedelta
15
12
  from typing import Any, BinaryIO
@@ -24,246 +21,7 @@ from rslearn.data_sources import DataSource, DataSourceContext, Item
24
21
  from rslearn.data_sources.utils import match_candidate_items_to_window
25
22
  from rslearn.tile_stores import TileStoreWithLayer
26
23
  from rslearn.utils import STGeometry
27
-
28
-
29
- class APIException(Exception):
30
- """Exception raised for M2M API errors."""
31
-
32
- pass
33
-
34
-
35
- class M2MAPIClient:
36
- """An API client for interacting with the USGS M2M API."""
37
-
38
- api_url = "https://m2m.cr.usgs.gov/api/api/json/stable/"
39
- pagination_size = 1000
40
-
41
- def __init__(
42
- self,
43
- username: str,
44
- password: str | None = None,
45
- token: str | None = None,
46
- timeout: timedelta = timedelta(seconds=120),
47
- ) -> None:
48
- """Initialize a new M2MAPIClient.
49
-
50
- Args:
51
- username: the EROS username
52
- password: the EROS password
53
- token: the application token. One of password or token must be specified.
54
- timeout: timeout for requests.
55
- """
56
- self.username = username
57
- self.timeout = timeout
58
-
59
- if password is not None and token is not None:
60
- raise ValueError("only one of password or token can be specified")
61
-
62
- if password is not None:
63
- json_data = json.dumps({"username": self.username, "password": password})
64
- response = requests.post(
65
- self.api_url + "login",
66
- data=json_data,
67
- timeout=self.timeout.total_seconds(),
68
- )
69
-
70
- elif token is not None:
71
- json_data = json.dumps({"username": username, "token": token})
72
- response = requests.post(
73
- self.api_url + "login-token",
74
- data=json_data,
75
- timeout=self.timeout.total_seconds(),
76
- )
77
-
78
- else:
79
- raise ValueError("one of password or token must be specified")
80
-
81
- response.raise_for_status()
82
- self.auth_token = response.json()["data"]
83
-
84
- def request(
85
- self, endpoint: str, data: dict[str, Any] | None = None
86
- ) -> dict[str, Any] | None:
87
- """Make a request to the API.
88
-
89
- Args:
90
- endpoint: the endpoint to call
91
- data: POST data to pass
92
-
93
- Returns:
94
- JSON response data if any
95
- """
96
- response = requests.post(
97
- self.api_url + endpoint,
98
- headers={"X-Auth-Token": self.auth_token},
99
- data=json.dumps(data),
100
- timeout=self.timeout.total_seconds(),
101
- )
102
- response.raise_for_status()
103
- if response.text:
104
- response_dict = response.json()
105
-
106
- if response_dict["errorMessage"]:
107
- raise APIException(response_dict["errorMessage"])
108
- return response_dict
109
- return None
110
-
111
- def close(self) -> None:
112
- """Logout from the API."""
113
- self.request("logout")
114
-
115
- def __enter__(self) -> "M2MAPIClient":
116
- """Enter function to provide with semantics."""
117
- return self
118
-
119
- def __exit__(self) -> None:
120
- """Exit function to provide with semantics.
121
-
122
- Logs out the API.
123
- """
124
- self.close()
125
-
126
- def get_filters(self, dataset_name: str) -> list[dict[str, Any]]:
127
- """Returns filters available for the given dataset.
128
-
129
- Args:
130
- dataset_name: the dataset name e.g. landsat_ot_c2_l1
131
-
132
- Returns:
133
- list of filter objects
134
- """
135
- response_dict = self.request("dataset-filters", {"datasetName": dataset_name})
136
- if response_dict is None:
137
- raise APIException("No response from API")
138
- return response_dict["data"]
139
-
140
- def scene_search(
141
- self,
142
- dataset_name: str,
143
- acquisition_time_range: tuple[datetime, datetime] | None = None,
144
- cloud_cover_range: tuple[int, int] | None = None,
145
- bbox: tuple[float, float, float, float] | None = None,
146
- metadata_filter: dict[str, Any] | None = None,
147
- ) -> list[dict[str, Any]]:
148
- """Search for scenes matching the arguments.
149
-
150
- Args:
151
- dataset_name: the dataset name e.g. landsat_ot_c2_l1
152
- acquisition_time_range: optional filter on the acquisition time
153
- cloud_cover_range: optional filter on the cloud cover
154
- bbox: optional spatial filter
155
- metadata_filter: optional metadata filter dict
156
- """
157
- base_data: dict[str, Any] = {"datasetName": dataset_name, "sceneFilter": {}}
158
- if acquisition_time_range:
159
- base_data["sceneFilter"]["acquisitionFilter"] = {
160
- "start": acquisition_time_range[0].isoformat(),
161
- "end": acquisition_time_range[1].isoformat(),
162
- }
163
- if cloud_cover_range:
164
- base_data["sceneFilter"]["cloudCoverFilter"] = {
165
- "min": cloud_cover_range[0],
166
- "max": cloud_cover_range[1],
167
- "includeUnknown": False,
168
- }
169
- if bbox:
170
- base_data["sceneFilter"]["spatialFilter"] = {
171
- "filterType": "mbr",
172
- "lowerLeft": {"longitude": bbox[0], "latitude": bbox[1]},
173
- "upperRight": {"longitude": bbox[2], "latitude": bbox[3]},
174
- }
175
- if metadata_filter:
176
- base_data["sceneFilter"]["metadataFilter"] = metadata_filter
177
-
178
- starting_number = 1
179
- results = []
180
- while True:
181
- cur_data = base_data.copy()
182
- cur_data["startingNumber"] = starting_number
183
- cur_data["maxResults"] = self.pagination_size
184
- response_dict = self.request("scene-search", cur_data)
185
- if response_dict is None:
186
- raise APIException("No response from API")
187
- data = response_dict["data"]
188
- results.extend(data["results"])
189
- if data["recordsReturned"] < self.pagination_size:
190
- break
191
- starting_number += self.pagination_size
192
-
193
- return results
194
-
195
- def get_scene_metadata(self, dataset_name: str, entity_id: str) -> dict[str, Any]:
196
- """Get detailed metadata for a scene.
197
-
198
- Args:
199
- dataset_name: the dataset name in which to search
200
- entity_id: the entity ID of the scene
201
-
202
- Returns:
203
- full scene metadata
204
- """
205
- response_dict = self.request(
206
- "scene-metadata",
207
- {
208
- "datasetName": dataset_name,
209
- "entityId": entity_id,
210
- "metadataType": "full",
211
- },
212
- )
213
- if response_dict is None:
214
- raise APIException("No response from API")
215
- return response_dict["data"]
216
-
217
- def get_downloadable_products(
218
- self, dataset_name: str, entity_id: str
219
- ) -> list[dict[str, Any]]:
220
- """Get the downloadable products for a given scene.
221
-
222
- Args:
223
- dataset_name: the dataset name
224
- entity_id: the entity ID of the scene
225
-
226
- Returns:
227
- list of downloadable products
228
- """
229
- data = {"datasetName": dataset_name, "entityIds": [entity_id]}
230
- response_dict = self.request("download-options", data)
231
- if response_dict is None:
232
- raise APIException("No response from API")
233
- return response_dict["data"]
234
-
235
- def get_download_url(self, entity_id: str, product_id: str) -> str:
236
- """Get the download URL for a given product.
237
-
238
- Args:
239
- entity_id: the entity ID of the product
240
- product_id: the product ID of the product
241
-
242
- Returns:
243
- the download URL
244
- """
245
- label = str(uuid.uuid4())
246
- data = {
247
- "downloads": [
248
- {"label": label, "entityId": entity_id, "productId": product_id}
249
- ]
250
- }
251
- response_dict = self.request("download-request", data)
252
- if response_dict is None:
253
- raise APIException("No response from API")
254
- response = response_dict["data"]
255
- while True:
256
- response_dict = self.request("download-retrieve", {"label": label})
257
- if response_dict is None:
258
- raise APIException("No response from API")
259
- response = response_dict["data"]
260
- if len(response["available"]) > 0:
261
- return response["available"][0]["url"]
262
- if len(response["requested"]) == 0:
263
- raise Exception("Did not get download URL")
264
- if response["requested"][0].get("url"):
265
- return response["requested"][0]["url"]
266
- time.sleep(10)
24
+ from rslearn.utils.m2m_api import APIException, M2MAPIClient
267
25
 
268
26
 
269
27
  class LandsatOliTirsItem(Item):
@@ -314,30 +72,26 @@ class LandsatOliTirs(DataSource):
314
72
 
315
73
  def __init__(
316
74
  self,
317
- username: str,
318
- sort_by: str | None = None,
319
- password: str | None = None,
75
+ username: str | None = None,
320
76
  token: str | None = None,
77
+ sort_by: str | None = None,
321
78
  timeout: timedelta = timedelta(seconds=10),
322
79
  context: DataSourceContext = DataSourceContext(),
323
80
  ):
324
81
  """Initialize a new LandsatOliTirs instance.
325
82
 
326
83
  Args:
327
- username: EROS username
84
+ username: EROS username (see M2MAPIClient).
85
+ token: EROS application token (see M2MAPIClient).
328
86
  sort_by: can be "cloud_cover", default arbitrary order; only has effect for
329
87
  SpaceMode.WITHIN.
330
- password: EROS password (see M2MAPIClient).
331
- token: EROS application token (see M2MAPIClient).
332
88
  timeout: timeout for requests.
333
89
  context: the data source context.
334
90
  """
335
91
  self.sort_by = sort_by
336
92
  self.timeout = timeout
337
93
 
338
- self.client = M2MAPIClient(
339
- username, password=password, token=token, timeout=timeout
340
- )
94
+ self.client = M2MAPIClient(username=username, token=token, timeout=timeout)
341
95
 
342
96
  def _scene_metadata_to_item(self, result: dict[str, Any]) -> LandsatOliTirsItem:
343
97
  """Convert scene metadata from the API to a LandsatOliTirsItem."""
@@ -429,9 +183,8 @@ class LandsatOliTirs(DataSource):
429
183
  )
430
184
  return self._scene_metadata_to_item(scene_metadata)
431
185
 
432
- def deserialize_item(self, serialized_item: Any) -> Item:
186
+ def deserialize_item(self, serialized_item: dict) -> Item:
433
187
  """Deserializes an item from JSON-decoded data."""
434
- assert isinstance(serialized_item, dict)
435
188
  return LandsatOliTirsItem.deserialize(serialized_item)
436
189
 
437
190
  def _get_download_urls(self, item: Item) -> dict[str, tuple[str, str]]:
@@ -1,7 +1,9 @@
1
1
  """Utilities shared by data sources."""
2
2
 
3
+ import warnings
4
+ from collections.abc import Callable
3
5
  from dataclasses import dataclass
4
- from datetime import UTC, datetime, timedelta
6
+ from datetime import UTC, datetime
5
7
  from typing import TypeVar
6
8
 
7
9
  import shapely
@@ -40,13 +42,13 @@ class PendingMosaic:
40
42
  completed: bool = False
41
43
 
42
44
 
43
- def mosaic_matching(
45
+ def _create_single_coverage_mosaics(
44
46
  window_geometry: STGeometry,
45
47
  items: list[ItemType],
46
48
  item_shps: list[shapely.Geometry],
47
- max_matches: int,
49
+ max_mosaics: int,
48
50
  ) -> list[list[ItemType]]:
49
- """Spatial item matching for mosaic space mode.
51
+ """Create mosaics where each mosaic covers the window geometry once.
50
52
 
51
53
  This attempts to piece together items into mosaics that fully cover the window
52
54
  geometry. If there are items leftover that only partially cover the window
@@ -56,15 +58,16 @@ def mosaic_matching(
56
58
  window_geometry: the geometry of the window.
57
59
  items: list of items.
58
60
  item_shps: the item shapes projected to the window's projection.
59
- max_matches: the maximum number of matches (mosaics) to create.
61
+ max_mosaics: the maximum number of mosaics to create.
60
62
 
61
63
  Returns:
62
- list of item groups, each one corresponding to a different mosaic.
64
+ list of item groups, each one corresponding to a different single-coverage
65
+ mosaic.
63
66
  """
64
67
  # To create mosaics, we iterate over the items in order, and add each item to
65
68
  # the first mosaic that the new item adds coverage to.
66
69
 
67
- # max_matches could be very high if the user just wants us to create as many
70
+ # max_mosaics could be very high if the user just wants us to create as many
68
71
  # mosaics as possible, so we initialize the list here as empty and just add
69
72
  # more pending mosaics when it is necessary.
70
73
  pending_mosaics: list[PendingMosaic] = []
@@ -108,7 +111,7 @@ def mosaic_matching(
108
111
 
109
112
  # See if we can add a new mosaic based on this item. There must be room for
110
113
  # more mosaics, but the item must also intersect the requested geometry.
111
- if len(pending_mosaics) >= max_matches:
114
+ if len(pending_mosaics) >= max_mosaics:
112
115
  continue
113
116
  intersect_area = item_shp.intersection(window_geometry.shp).area
114
117
  if (
@@ -127,18 +130,148 @@ def mosaic_matching(
127
130
  return [pending_mosaic.items for pending_mosaic in pending_mosaics]
128
131
 
129
132
 
130
- def per_period_mosaic_matching(
131
- window_geometry: STGeometry,
132
- item_list: list[ItemType],
133
- period_duration: timedelta,
134
- max_matches: int,
133
+ def _consolidate_mosaics_by_overlaps(
134
+ mosaics: list[list[ItemType]],
135
+ overlaps: int,
136
+ max_groups: int,
137
+ ) -> list[list[ItemType]]:
138
+ """Consolidate single-coverage mosaics into groups based on desired overlaps.
139
+
140
+ Args:
141
+ mosaics: list of single-coverage mosaics (each mosaic is a list of items).
142
+ overlaps: the number of overlapping coverages wanted per group.
143
+ max_groups: the maximum number of groups to return.
144
+
145
+ Returns:
146
+ list of item groups, where each group contains items from multiple mosaics
147
+ to achieve the desired number of overlapping coverages.
148
+ """
149
+ if overlaps <= 0:
150
+ overlaps = 1
151
+
152
+ groups: list[list[ItemType]] = []
153
+ for i in range(0, len(mosaics), overlaps):
154
+ if len(groups) >= max_groups:
155
+ break
156
+ # Combine overlaps consecutive mosaics into one group
157
+ combined_items: list[ItemType] = []
158
+ for mosaic in mosaics[i : i + overlaps]:
159
+ combined_items.extend(mosaic)
160
+ if combined_items:
161
+ groups.append(combined_items)
162
+
163
+ return groups
164
+
165
+
166
+ def match_with_space_mode_contains(
167
+ geometry: STGeometry,
168
+ items: list[ItemType],
169
+ item_shps: list[shapely.Geometry],
170
+ query_config: QueryConfig,
171
+ ) -> list[list[ItemType]]:
172
+ """Match items that fully contain the window geometry.
173
+
174
+ Args:
175
+ geometry: the window's geometry.
176
+ items: list of items.
177
+ item_shps: the item shapes projected to the window's projection.
178
+ query_config: the query configuration.
179
+
180
+ Returns:
181
+ list of matched item groups, where each group contains a single item.
182
+ """
183
+ groups: list[list[ItemType]] = []
184
+ for item, item_shp in zip(items, item_shps):
185
+ if not item_shp.contains(geometry.shp):
186
+ continue
187
+ groups.append([item])
188
+ if len(groups) >= query_config.max_matches:
189
+ break
190
+ return groups
191
+
192
+
193
+ def match_with_space_mode_intersects(
194
+ geometry: STGeometry,
195
+ items: list[ItemType],
196
+ item_shps: list[shapely.Geometry],
197
+ query_config: QueryConfig,
198
+ ) -> list[list[ItemType]]:
199
+ """Match items that intersect any portion of the window geometry.
200
+
201
+ Args:
202
+ geometry: the window's geometry.
203
+ items: list of items.
204
+ item_shps: the item shapes projected to the window's projection.
205
+ query_config: the query configuration.
206
+
207
+ Returns:
208
+ list of matched item groups, where each group contains a single item.
209
+ """
210
+ groups: list[list[ItemType]] = []
211
+ for item, item_shp in zip(items, item_shps):
212
+ if not shp_intersects(item_shp, geometry.shp):
213
+ continue
214
+ groups.append([item])
215
+ if len(groups) >= query_config.max_matches:
216
+ break
217
+ return groups
218
+
219
+
220
+ def match_with_space_mode_mosaic(
221
+ geometry: STGeometry,
222
+ items: list[ItemType],
223
+ item_shps: list[shapely.Geometry],
224
+ query_config: QueryConfig,
225
+ ) -> list[list[ItemType]]:
226
+ """Match items into mosaic groups that cover the window geometry.
227
+
228
+ Creates groups of items that together cover the window geometry. The number of
229
+ overlapping coverages in each group is controlled by mosaic_compositing_overlaps.
230
+
231
+ Args:
232
+ geometry: the window's geometry.
233
+ items: list of items.
234
+ item_shps: the item shapes projected to the window's projection.
235
+ query_config: the query configuration.
236
+
237
+ Returns:
238
+ list of matched item groups, where each group forms a mosaic covering the
239
+ window.
240
+ """
241
+ overlaps = query_config.mosaic_compositing_overlaps
242
+
243
+ # Calculate how many single-coverage mosaics we need to create.
244
+ # We need enough mosaics to consolidate into max_matches groups with the
245
+ # desired number of overlaps per group.
246
+ max_single_mosaics = query_config.max_matches * overlaps
247
+
248
+ # Create single-coverage mosaics
249
+ single_mosaics = _create_single_coverage_mosaics(
250
+ geometry, items, item_shps, max_single_mosaics
251
+ )
252
+
253
+ # Consolidate into groups based on overlaps
254
+ return _consolidate_mosaics_by_overlaps(
255
+ single_mosaics, overlaps, query_config.max_matches
256
+ )
257
+
258
+
259
+ def match_with_space_mode_per_period_mosaic(
260
+ geometry: STGeometry,
261
+ items: list[ItemType],
262
+ item_shps: list[shapely.Geometry],
263
+ query_config: QueryConfig,
135
264
  ) -> list[list[ItemType]]:
136
265
  """Match items to the geometry with one mosaic per period.
137
266
 
138
267
  We divide the time range of the geometry into shorter periods. Within each period,
139
268
  we use the items corresponding to that period to create a mosaic. The returned item
140
- groups include one group per period, starting from the most recent periods, up to
141
- the provided max_matches.
269
+ groups include one group per period, up to the provided max_matches.
270
+
271
+ By default (reverse_time_order=True), groups are returned starting from the most
272
+ recent periods. When reverse_time_order=False, groups are returned in chronological
273
+ order (oldest first). reverse_time_order should always be set to False; a
274
+ FutureWarning is emitted if it is left as True.
142
275
 
143
276
  The periods are also bounded to the window's time range, and aligned with the end
144
277
  of that time range, i.e. the most recent window is
@@ -159,42 +292,59 @@ def per_period_mosaic_matching(
159
292
  max_matches*period_duration is not equivalent to a longer window duration.
160
293
 
161
294
  Args:
162
- window_geometry: the window geometry to match items to.
163
- item_list: the list of items.
164
- period_duration: the duration of one period.
165
- max_matches: the number of per-period mosaics to create.
295
+ geometry: the window's geometry.
296
+ items: list of items.
297
+ item_shps: the item shapes projected to the window's projection (unused here)
298
+ query_config: the query configuration.
166
299
 
167
300
  Returns:
168
- the matched item groups, where each group contains items that yield a
169
- per-period mosaic.
301
+ list of matched item groups, where each group contains items that yield a
302
+ per-period mosaic.
170
303
  """
171
- if window_geometry.time_range is None:
304
+ if geometry.time_range is None:
172
305
  raise ValueError(
173
306
  "all windows must have time range for per period mosaic matching"
174
307
  )
175
308
 
309
+ # Emit warning if per_period_mosaic_reverse_time_order is True (the default).
310
+ if query_config.per_period_mosaic_reverse_time_order:
311
+ warnings.warn(
312
+ "QueryConfig.per_period_mosaic_reverse_time_order defaults to True, which "
313
+ "returns item groups in reverse temporal order (most recent first) for "
314
+ "PER_PERIOD_MOSAIC mode. This default will change to False (chronological "
315
+ "order) after 2026-04-01. To silence this warning, explicitly set "
316
+ "per_period_mosaic_reverse_time_order=False.",
317
+ FutureWarning,
318
+ stacklevel=3,
319
+ )
320
+
321
+ period_duration = query_config.period_duration
322
+
176
323
  # For each period, we create an STGeometry with modified time range matching that
177
324
  # period, and use it with match_candidate_items_to_window to get a mosaic.
178
325
  cur_groups: list[list[ItemType]] = []
179
- period_start = window_geometry.time_range[1] - period_duration
326
+ period_start = geometry.time_range[1] - period_duration
180
327
  while (
181
- period_start >= window_geometry.time_range[0] and len(cur_groups) < max_matches
328
+ period_start >= geometry.time_range[0]
329
+ and len(cur_groups) < query_config.max_matches
182
330
  ):
183
331
  period_time_range = (
184
332
  period_start,
185
333
  period_start + period_duration,
186
334
  )
187
335
  period_start -= period_duration
188
- period_geom = STGeometry(
189
- window_geometry.projection, window_geometry.shp, period_time_range
190
- )
336
+ period_geom = STGeometry(geometry.projection, geometry.shp, period_time_range)
191
337
 
192
338
  # We modify the QueryConfig here since caller should be asking for
193
339
  # multiple mosaics, but we just want one mosaic per period.
194
340
  period_groups = match_candidate_items_to_window(
195
341
  period_geom,
196
- item_list,
197
- QueryConfig(space_mode=SpaceMode.MOSAIC, max_matches=1),
342
+ items,
343
+ QueryConfig(
344
+ space_mode=SpaceMode.MOSAIC,
345
+ max_matches=1,
346
+ mosaic_compositing_overlaps=query_config.mosaic_compositing_overlaps,
347
+ ),
198
348
  )
199
349
 
200
350
  # There should be zero or one group depending on whether there were
@@ -204,9 +354,29 @@ def per_period_mosaic_matching(
204
354
  continue
205
355
  cur_groups.append(period_groups[0])
206
356
 
357
+ # Currently the item groups are in reverse chronological order.
358
+ # Reverse it to correct chronological order if requested.
359
+ if not query_config.per_period_mosaic_reverse_time_order:
360
+ cur_groups.reverse()
361
+
207
362
  return cur_groups
208
363
 
209
364
 
365
+ # Type alias for space mode handler functions
366
+ SpaceModeHandler = Callable[
367
+ [STGeometry, list[ItemType], list[shapely.Geometry], QueryConfig],
368
+ list[list[ItemType]],
369
+ ]
370
+
371
+ # Dict mapping SpaceMode values to their handler functions
372
+ space_mode_handlers: dict[SpaceMode, SpaceModeHandler] = {
373
+ SpaceMode.CONTAINS: match_with_space_mode_contains,
374
+ SpaceMode.INTERSECTS: match_with_space_mode_intersects,
375
+ SpaceMode.MOSAIC: match_with_space_mode_mosaic,
376
+ SpaceMode.PER_PERIOD_MOSAIC: match_with_space_mode_per_period_mosaic,
377
+ }
378
+
379
+
210
380
  def match_candidate_items_to_window(
211
381
  geometry: STGeometry, items: list[ItemType], query_config: QueryConfig
212
382
  ) -> list[list[ItemType]]:
@@ -262,43 +432,13 @@ def match_candidate_items_to_window(
262
432
  item_geom = item_geom.to_projection(geometry.projection)
263
433
  item_shps.append(item_geom.shp)
264
434
 
265
- if query_config.space_mode == SpaceMode.CONTAINS:
266
- groups = []
267
- for item, item_shp in zip(items, item_shps):
268
- if not item_shp.contains(geometry.shp):
269
- continue
270
- groups.append([item])
271
- if len(groups) >= query_config.max_matches:
272
- break
273
-
274
- elif query_config.space_mode == SpaceMode.INTERSECTS:
275
- groups = []
276
- for item, item_shp in zip(items, item_shps):
277
- if not shp_intersects(item_shp, geometry.shp):
278
- continue
279
- groups.append([item])
280
- if len(groups) >= query_config.max_matches:
281
- break
282
-
283
- elif query_config.space_mode == SpaceMode.MOSAIC:
284
- groups = mosaic_matching(geometry, items, item_shps, query_config.max_matches)
285
-
286
- elif query_config.space_mode == SpaceMode.PER_PERIOD_MOSAIC:
287
- groups = per_period_mosaic_matching(
288
- geometry, items, query_config.period_duration, query_config.max_matches
289
- )
290
-
291
- elif query_config.space_mode == SpaceMode.COMPOSITE:
292
- group = []
293
- for item, item_shp in zip(items, item_shps):
294
- if not shp_intersects(item_shp, geometry.shp):
295
- continue
296
- group.append(item)
297
- groups = [group]
298
-
299
- else:
435
+ # Dispatch to the appropriate space mode handler
436
+ handler = space_mode_handlers.get(query_config.space_mode)
437
+ if handler is None:
300
438
  raise ValueError(f"invalid space mode {query_config.space_mode}")
301
439
 
440
+ groups = handler(geometry, items, item_shps, query_config)
441
+
302
442
  # Enforce minimum matches if set.
303
443
  if len(groups) < query_config.min_matches:
304
444
  logger.warning(
@@ -291,7 +291,7 @@ class WorldCereal(LocalFiles):
291
291
  raise ValueError(f"No AEZ files found for {self.band}")
292
292
 
293
293
  super().__init__(
294
- src_dir=tif_dir,
294
+ src_dir=tif_dir.absolute().as_uri(),
295
295
  raster_item_specs=item_specs,
296
296
  layer_type=LayerType.RASTER,
297
297
  context=context,
@@ -75,7 +75,7 @@ class WorldCover(LocalFiles):
75
75
  tif_dir = self.download_worldcover_data(worldcover_upath)
76
76
 
77
77
  super().__init__(
78
- src_dir=tif_dir,
78
+ src_dir=tif_dir.absolute().as_uri(),
79
79
  layer_type=LayerType.RASTER,
80
80
  context=context,
81
81
  )