rslearn 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,5 @@
1
1
  """Data on Planetary Computer."""
2
2
 
3
- import json
4
3
  import os
5
4
  import tempfile
6
5
  import xml.etree.ElementTree as ET
@@ -10,18 +9,14 @@ from typing import Any
10
9
  import affine
11
10
  import numpy.typing as npt
12
11
  import planetary_computer
13
- import pystac
14
- import pystac_client
15
12
  import rasterio
16
13
  import requests
17
- import shapely
18
14
  from rasterio.enums import Resampling
19
15
  from upath import UPath
20
16
 
21
- from rslearn.config import LayerConfig, QueryConfig
22
- from rslearn.const import WGS84_PROJECTION
23
- from rslearn.data_sources import DataSource, DataSourceContext, Item
24
- from rslearn.data_sources.utils import match_candidate_items_to_window
17
+ from rslearn.config import LayerConfig
18
+ from rslearn.data_sources import DataSourceContext
19
+ from rslearn.data_sources.stac import SourceItem, StacDataSource
25
20
  from rslearn.dataset import Window
26
21
  from rslearn.dataset.materialize import RasterMaterializer
27
22
  from rslearn.log_utils import get_logger
@@ -35,38 +30,7 @@ from .copernicus import get_harmonize_callback
35
30
  logger = get_logger(__name__)
36
31
 
37
32
 
38
- class PlanetaryComputerItem(Item):
39
- """An item in the PlanetaryComputer data source."""
40
-
41
- def __init__(self, name: str, geometry: STGeometry, asset_urls: dict[str, str]):
42
- """Creates a new PlanetaryComputerItem.
43
-
44
- Args:
45
- name: unique name of the item
46
- geometry: the spatial and temporal extent of the item
47
- asset_urls: map from asset key to the unsigned asset URL.
48
- """
49
- super().__init__(name, geometry)
50
- self.asset_urls = asset_urls
51
-
52
- def serialize(self) -> dict[str, Any]:
53
- """Serializes the item to a JSON-encodable dictionary."""
54
- d = super().serialize()
55
- d["asset_urls"] = self.asset_urls
56
- return d
57
-
58
- @staticmethod
59
- def deserialize(d: dict[str, Any]) -> "PlanetaryComputerItem":
60
- """Deserializes an item from a JSON-decoded dictionary."""
61
- item = super(PlanetaryComputerItem, PlanetaryComputerItem).deserialize(d)
62
- return PlanetaryComputerItem(
63
- name=item.name,
64
- geometry=item.geometry,
65
- asset_urls=d["asset_urls"],
66
- )
67
-
68
-
69
- class PlanetaryComputer(DataSource, TileStore):
33
+ class PlanetaryComputer(StacDataSource, TileStore):
70
34
  """Modality-agnostic data source for data on Microsoft Planetary Computer.
71
35
 
72
36
  If there is a subclass available for a modality, it is recommended to use the
@@ -83,10 +47,6 @@ class PlanetaryComputer(DataSource, TileStore):
83
47
 
84
48
  STAC_ENDPOINT = "https://planetarycomputer.microsoft.com/api/stac/v1"
85
49
 
86
- # Default threshold for recreating the STAC client to prevent memory leaks
87
- # from the pystac Catalog's resolved objects cache growing unbounded
88
- DEFAULT_MAX_ITEMS_PER_CLIENT = 1000
89
-
90
50
  def __init__(
91
51
  self,
92
52
  collection_name: str,
@@ -97,7 +57,6 @@ class PlanetaryComputer(DataSource, TileStore):
97
57
  timeout: timedelta = timedelta(seconds=10),
98
58
  skip_items_missing_assets: bool = False,
99
59
  cache_dir: str | None = None,
100
- max_items_per_client: int | None = None,
101
60
  context: DataSourceContext = DataSourceContext(),
102
61
  ):
103
62
  """Initialize a new PlanetaryComputer instance.
@@ -115,228 +74,40 @@ class PlanetaryComputer(DataSource, TileStore):
115
74
  cache_dir: optional directory to cache items by name, including asset URLs.
116
75
  If not set, there will be no cache and instead STAC requests will be
117
76
  needed each time.
118
- max_items_per_client: number of STAC items to process before recreating
119
- the client to prevent memory leaks from the resolved objects cache.
120
- Defaults to DEFAULT_MAX_ITEMS_PER_CLIENT.
121
77
  context: the data source context.
122
78
  """
123
- self.collection_name = collection_name
124
- self.asset_bands = asset_bands
125
- self.query = query
126
- self.sort_by = sort_by
127
- self.sort_ascending = sort_ascending
128
- self.timeout = timeout
129
- self.skip_items_missing_assets = skip_items_missing_assets
130
- self.max_items_per_client = (
131
- max_items_per_client or self.DEFAULT_MAX_ITEMS_PER_CLIENT
132
- )
133
-
79
+ # Determine the cache_dir to use.
80
+ cache_upath: UPath | None = None
134
81
  if cache_dir is not None:
135
82
  if context.ds_path is not None:
136
- self.cache_dir = join_upath(context.ds_path, cache_dir)
83
+ cache_upath = join_upath(context.ds_path, cache_dir)
137
84
  else:
138
- self.cache_dir = UPath(cache_dir)
139
-
140
- self.cache_dir.mkdir(parents=True, exist_ok=True)
141
- else:
142
- self.cache_dir = None
143
-
144
- self.client: pystac_client.Client | None = None
145
- self._client_item_count = 0
85
+ cache_upath = UPath(cache_dir)
146
86
 
147
- def _load_client(
148
- self,
149
- ) -> pystac_client.Client:
150
- """Lazily load pystac client.
87
+ cache_upath.mkdir(parents=True, exist_ok=True)
151
88
 
152
- We don't load it when creating the data source because it takes time and caller
153
- may not be calling get_items. Additionally, loading it during the get_items
154
- call enables leveraging the retry loop functionality in
155
- prepare_dataset_windows.
89
+ # We pass required_assets to StacDataSource if skip_items_missing_assets is set.
90
+ required_assets: list[str] | None = None
91
+ if skip_items_missing_assets:
92
+ required_assets = list(asset_bands.keys())
156
93
 
157
- Note: We periodically recreate the client to prevent memory leaks from the
158
- pystac Catalog's resolved objects cache, which grows unbounded as STAC items
159
- are deserialized and cached. The cache cannot be cleared or disabled.
160
- """
161
- if self.client is None:
162
- logger.info("Creating initial STAC client")
163
- self.client = pystac_client.Client.open(self.STAC_ENDPOINT)
164
- return self.client
165
-
166
- if self._client_item_count < self.max_items_per_client:
167
- return self.client
168
-
169
- # Recreate client to clear the resolved objects cache
170
- current_client = self.client
171
- logger.debug(
172
- "Recreating STAC client after processing %d items (threshold: %d)",
173
- self._client_item_count,
174
- self.max_items_per_client,
94
+ super().__init__(
95
+ endpoint=self.STAC_ENDPOINT,
96
+ collection_name=collection_name,
97
+ query=query,
98
+ sort_by=sort_by,
99
+ sort_ascending=sort_ascending,
100
+ required_assets=required_assets,
101
+ cache_dir=cache_upath,
175
102
  )
176
- client_root = current_client.get_root()
177
- client_root.clear_links()
178
- client_root.clear_items()
179
- client_root.clear_children()
180
- self._client_item_count = 0
181
- self.client = pystac_client.Client.open(self.STAC_ENDPOINT)
182
- return self.client
183
-
184
- def _stac_item_to_item(self, stac_item: pystac.Item) -> PlanetaryComputerItem:
185
- shp = shapely.geometry.shape(stac_item.geometry)
186
-
187
- # Get time range.
188
- metadata = stac_item.common_metadata
189
- if metadata.start_datetime is not None and metadata.end_datetime is not None:
190
- time_range = (
191
- metadata.start_datetime,
192
- metadata.end_datetime,
193
- )
194
- elif stac_item.datetime is not None:
195
- time_range = (stac_item.datetime, stac_item.datetime)
196
- else:
197
- raise ValueError(
198
- f"item {stac_item.id} unexpectedly missing start_datetime, end_datetime, and datetime"
199
- )
200
-
201
- geom = STGeometry(WGS84_PROJECTION, shp, time_range)
202
- asset_urls = {
203
- asset_key: asset_obj.href
204
- for asset_key, asset_obj in stac_item.assets.items()
205
- }
206
- return PlanetaryComputerItem(stac_item.id, geom, asset_urls)
207
-
208
- def get_item_by_name(self, name: str) -> PlanetaryComputerItem:
209
- """Gets an item by name.
210
-
211
- Args:
212
- name: the name of the item to get
213
-
214
- Returns:
215
- the item object
216
- """
217
- # If cache_dir is set, we cache the item. First here we check if it is already
218
- # in the cache.
219
- cache_fname: UPath | None = None
220
- if self.cache_dir:
221
- cache_fname = self.cache_dir / f"{name}.json"
222
- if cache_fname is not None and cache_fname.exists():
223
- with cache_fname.open() as f:
224
- return PlanetaryComputerItem.deserialize(json.load(f))
225
-
226
- # No cache or not in cache, so we need to make the STAC request.
227
- logger.debug("Getting STAC item {name}")
228
- client = self._load_client()
229
-
230
- search_result = client.search(ids=[name], collections=[self.collection_name])
231
- stac_items = list(search_result.items())
232
-
233
- if not stac_items:
234
- raise ValueError(
235
- f"Item {name} not found in collection {self.collection_name}"
236
- )
237
- if len(stac_items) > 1:
238
- raise ValueError(
239
- f"Multiple items found for ID {name} in collection {self.collection_name}"
240
- )
241
-
242
- stac_item = stac_items[0]
243
- item = self._stac_item_to_item(stac_item)
244
-
245
- # Track items processed for client recreation threshold (after deserialization)
246
- self._client_item_count += 1
247
-
248
- # Finally we cache it if cache_dir is set.
249
- if cache_fname is not None:
250
- with cache_fname.open("w") as f:
251
- json.dump(item.serialize(), f)
252
-
253
- return item
254
-
255
- def get_items(
256
- self, geometries: list[STGeometry], query_config: QueryConfig
257
- ) -> list[list[list[PlanetaryComputerItem]]]:
258
- """Get a list of items in the data source intersecting the given geometries.
259
-
260
- Args:
261
- geometries: the spatiotemporal geometries
262
- query_config: the query configuration
263
-
264
- Returns:
265
- List of groups of items that should be retrieved for each geometry.
266
- """
267
- client = self._load_client()
268
-
269
- groups = []
270
- for geometry in geometries:
271
- # Get potentially relevant items from the collection by performing one search
272
- # for each requested geometry.
273
- wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
274
- logger.debug("performing STAC search for geometry %s", wgs84_geometry)
275
- result = client.search(
276
- collections=[self.collection_name],
277
- intersects=shapely.to_geojson(wgs84_geometry.shp),
278
- datetime=wgs84_geometry.time_range,
279
- query=self.query,
280
- )
281
- stac_items = [item for item in result.items()]
282
- # Track items processed for client recreation threshold (after deserialization)
283
- self._client_item_count += len(stac_items)
284
- logger.debug("STAC search yielded %d items", len(stac_items))
285
-
286
- if self.skip_items_missing_assets:
287
- # Filter out items that are missing any of the assets in self.asset_bands.
288
- good_stac_items = []
289
- for stac_item in stac_items:
290
- good = True
291
- for asset_key in self.asset_bands.keys():
292
- if asset_key in stac_item.assets:
293
- continue
294
- good = False
295
- break
296
- if good:
297
- good_stac_items.append(stac_item)
298
- logger.debug(
299
- "skip_items_missing_assets filter from %d to %d items",
300
- len(stac_items),
301
- len(good_stac_items),
302
- )
303
- stac_items = good_stac_items
304
-
305
- if self.sort_by is not None:
306
- stac_items.sort(
307
- key=lambda stac_item: stac_item.properties[self.sort_by],
308
- reverse=not self.sort_ascending,
309
- )
310
-
311
- candidate_items = [
312
- self._stac_item_to_item(stac_item) for stac_item in stac_items
313
- ]
314
-
315
- # Since we made the STAC request, might as well save these to the cache.
316
- if self.cache_dir is not None:
317
- for item in candidate_items:
318
- cache_fname = self.cache_dir / f"{item.name}.json"
319
- if cache_fname.exists():
320
- continue
321
- with cache_fname.open("w") as f:
322
- json.dump(item.serialize(), f)
323
-
324
- cur_groups = match_candidate_items_to_window(
325
- geometry, candidate_items, query_config
326
- )
327
- groups.append(cur_groups)
328
-
329
- return groups
330
-
331
- def deserialize_item(self, serialized_item: Any) -> PlanetaryComputerItem:
332
- """Deserializes an item from JSON-decoded data."""
333
- assert isinstance(serialized_item, dict)
334
- return PlanetaryComputerItem.deserialize(serialized_item)
103
+ self.asset_bands = asset_bands
104
+ self.timeout = timeout
105
+ self.skip_items_missing_assets = skip_items_missing_assets
335
106
 
336
107
  def ingest(
337
108
  self,
338
109
  tile_store: TileStoreWithLayer,
339
- items: list[PlanetaryComputerItem],
110
+ items: list[SourceItem],
340
111
  geometries: list[list[STGeometry]],
341
112
  ) -> None:
342
113
  """Ingest items into the given tile store.
@@ -512,7 +283,7 @@ class PlanetaryComputer(DataSource, TileStore):
512
283
  def materialize(
513
284
  self,
514
285
  window: Window,
515
- item_groups: list[list[Item]],
286
+ item_groups: list[list[SourceItem]],
516
287
  layer_name: str,
517
288
  layer_cfg: LayerConfig,
518
289
  ) -> None:
@@ -601,7 +372,7 @@ class Sentinel2(PlanetaryComputer):
601
372
  **kwargs,
602
373
  )
603
374
 
604
- def _get_product_xml(self, item: PlanetaryComputerItem) -> ET.Element:
375
+ def _get_product_xml(self, item: SourceItem) -> ET.Element:
605
376
  asset_url = planetary_computer.sign(item.asset_urls["product-metadata"])
606
377
  response = requests.get(asset_url, timeout=self.timeout.total_seconds())
607
378
  response.raise_for_status()
@@ -610,7 +381,7 @@ class Sentinel2(PlanetaryComputer):
610
381
  def ingest(
611
382
  self,
612
383
  tile_store: TileStoreWithLayer,
613
- items: list[PlanetaryComputerItem],
384
+ items: list[SourceItem],
614
385
  geometries: list[list[STGeometry]],
615
386
  ) -> None:
616
387
  """Ingest items into the given tile store.
@@ -796,3 +567,53 @@ class Naip(PlanetaryComputer):
796
567
  context=context,
797
568
  **kwargs,
798
569
  )
570
+
571
+
572
class CopDemGlo30(PlanetaryComputer):
    """A data source for Copernicus DEM GLO-30 (30m) on Microsoft Planetary Computer.

    The DEM is treated as time-invariant: STAC searches are not filtered by
    time, and the items produced carry no time range.

    See https://planetarycomputer.microsoft.com/dataset/cop-dem-glo-30.
    """

    # STAC collection that holds the GLO-30 tiles.
    COLLECTION_NAME = "cop-dem-glo-30"
    # Asset key on each STAC item that holds the elevation raster.
    DATA_ASSET = "data"

    def __init__(
        self,
        band_name: str = "DEM",
        context: DataSourceContext = DataSourceContext(),
        **kwargs: Any,
    ):
        """Initialize a new CopDemGlo30 instance.

        Args:
            band_name: name given to the single DEM band. It is only used when
                the context carries no layer config. Otherwise the layer config's
                sole band name takes precedence.
            context: the data source context.
            kwargs: additional arguments to pass to PlanetaryComputer.

        Raises:
            ValueError: if the layer config does not have exactly one band set,
                or that band set does not have exactly one band.
        """
        layer_cfg = context.layer_config
        if layer_cfg is not None:
            # The DEM exposes one asset with one band, so the layer must
            # describe exactly one band for the mapping below to make sense.
            band_sets = layer_cfg.band_sets
            if len(band_sets) != 1:
                raise ValueError("expected a single band set")
            sole_bands = band_sets[0].bands
            if len(sole_bands) != 1:
                raise ValueError("expected band set to have a single band")
            band_name = sole_bands[0]

        asset_to_bands = {self.DATA_ASSET: [band_name]}
        super().__init__(
            collection_name=self.COLLECTION_NAME,
            asset_bands=asset_to_bands,
            # Ask the parent to drop any item lacking the data asset.
            skip_items_missing_assets=True,
            context=context,
            **kwargs,
        )

    def _stac_item_to_item(self, stac_item: Any) -> SourceItem:
        """Convert a STAC item to a SourceItem with its time range removed.

        The base conversion runs first. The geometry is then rebuilt with the
        same projection and shape and a ``None`` time range. The intent, per
        the original comment, is that the static DEM matches any window.
        """
        converted = super()._stac_item_to_item(stac_item)
        orig_geom = converted.geometry
        converted.geometry = STGeometry(orig_geom.projection, orig_geom.shp, None)
        return converted

    def _get_search_time_range(self, geometry: STGeometry) -> None:
        """Return no time range, so STAC searches are not filtered by time."""
        return None