rslearn 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,6 +11,8 @@ from enum import Enum
11
11
  from typing import Any, BinaryIO
12
12
 
13
13
  import boto3
14
+ import botocore
15
+ import botocore.client
14
16
  import dateutil.parser
15
17
  import fiona
16
18
  import fiona.transform
@@ -405,7 +407,8 @@ class Sentinel2(
405
407
  See https://aws.amazon.com/marketplace/pp/prodview-2ostsvrguftb2 for details about
406
408
  the buckets.
407
409
 
408
- AWS credentials must be configured for use with boto3.
410
+ The buckets were previously requester pays but now appear to allow anonymous/free
411
+ access, so AWS credentials are not needed.
409
412
  """
410
413
 
411
414
  bucket_names = {
@@ -478,7 +481,9 @@ class Sentinel2(
478
481
  self.harmonize = harmonize
479
482
 
480
483
  bucket_name = self.bucket_names[modality]
481
- self.bucket = boto3.resource("s3").Bucket(bucket_name)
484
+ self.bucket = boto3.resource(
485
+ "s3", config=botocore.client.Config(signature_version=botocore.UNSIGNED)
486
+ ).Bucket(bucket_name)
482
487
 
483
488
  def _read_products(
484
489
  self, needed_cell_months: set[tuple[str, int, int]]
@@ -505,15 +510,11 @@ class Sentinel2(
505
510
  )
506
511
 
507
512
  products = []
508
- for obj in self.bucket.objects.filter(
509
- Prefix=prefix, RequestPayer="requester"
510
- ):
513
+ for obj in self.bucket.objects.filter(Prefix=prefix):
511
514
  if not obj.key.endswith("tileInfo.json"):
512
515
  continue
513
516
  buf = io.BytesIO()
514
- self.bucket.download_fileobj(
515
- obj.key, buf, ExtraArgs={"RequestPayer": "requester"}
516
- )
517
+ self.bucket.download_fileobj(obj.key, buf)
517
518
  buf.seek(0)
518
519
  product = json.load(buf)
519
520
  if "tileDataGeometry" not in product:
@@ -649,9 +650,7 @@ class Sentinel2(
649
650
  """Retrieves the rasters corresponding to an item as file streams."""
650
651
  for fname, _ in self.band_fnames[self.modality]:
651
652
  buf = io.BytesIO()
652
- self.bucket.download_fileobj(
653
- item.blob_path + fname, buf, ExtraArgs={"RequestPayer": "requester"}
654
- )
653
+ self.bucket.download_fileobj(item.blob_path + fname, buf)
655
654
  buf.seek(0)
656
655
  yield (fname, buf)
657
656
 
@@ -677,9 +676,7 @@ class Sentinel2(
677
676
  f"products/{ts.year}/{ts.month}/{ts.day}/{item.name}/metadata.xml"
678
677
  )
679
678
  buf = io.BytesIO()
680
- self.bucket.download_fileobj(
681
- metadata_fname, buf, ExtraArgs={"RequestPayer": "requester"}
682
- )
679
+ self.bucket.download_fileobj(metadata_fname, buf)
683
680
  buf.seek(0)
684
681
  tree: ET.ElementTree[ET.Element[str]] = ET.ElementTree(
685
682
  ET.fromstring(buf.getvalue())
@@ -711,7 +708,6 @@ class Sentinel2(
711
708
  self.bucket.download_file(
712
709
  item.blob_path + fname,
713
710
  local_fname,
714
- ExtraArgs={"RequestPayer": "requester"},
715
711
  )
716
712
  except Exception as e:
717
713
  # TODO: sometimes for some reason object doesn't exist
@@ -0,0 +1,374 @@
1
+ """Data source for Sentinel-2 from public AWS bucket maintained by Element 84."""
2
+
3
+ import os
4
+ import tempfile
5
+ from collections.abc import Callable
6
+ from datetime import timedelta
7
+ from typing import Any
8
+
9
+ import affine
10
+ import numpy as np
11
+ import numpy.typing as npt
12
+ import rasterio
13
+ import requests
14
+ from rasterio.enums import Resampling
15
+ from upath import UPath
16
+
17
+ from rslearn.config import LayerConfig
18
+ from rslearn.data_sources.stac import SourceItem, StacDataSource
19
+ from rslearn.dataset import Window
20
+ from rslearn.dataset.manage import RasterMaterializer
21
+ from rslearn.log_utils import get_logger
22
+ from rslearn.tile_stores import TileStore, TileStoreWithLayer
23
+ from rslearn.utils import Projection, STGeometry
24
+ from rslearn.utils.fsspec import join_upath
25
+ from rslearn.utils.geometry import PixelBounds
26
+ from rslearn.utils.raster_format import get_raster_projection_and_bounds
27
+
28
+ from .data_source import (
29
+ DataSourceContext,
30
+ )
31
+
32
+ logger = get_logger(__name__)
33
+
34
+
35
class Sentinel2(StacDataSource, TileStore):
    """Sentinel-2 L2A imagery served from the s3://sentinel-cogs bucket on AWS.

    The bucket stores cloud-optimized GeoTIFFs, so besides regular ingestion this
    data source also supports direct materialization. The bucket allows anonymous
    free access, so no credentials are needed.

    See https://aws.amazon.com/marketplace/pp/prodview-ykj5gyumkzlme for details.
    """

    # STAC API endpoint and collection that are searched for items.
    STAC_ENDPOINT = "https://earth-search.aws.element84.com/v1"
    COLLECTION_NAME = "sentinel-2-l2a"

    # Maps each STAC asset key to the band names carried by that asset.
    ASSET_BANDS = {
        "coastal": ["B01"],
        "blue": ["B02"],
        "green": ["B03"],
        "red": ["B04"],
        "rededge1": ["B05"],
        "rededge2": ["B06"],
        "rededge3": ["B07"],
        "nir": ["B08"],
        "nir09": ["B09"],
        "swir16": ["B11"],
        "swir22": ["B12"],
        "nir08": ["B8A"],
        "visual": ["R", "G", "B"],
    }

    # Offset assumed when harmonizing pixel values, and the item property that is
    # consulted to decide whether harmonization has to be performed.
    HARMONIZE_OFFSET = -1000
    HARMONIZE_PROPERTY_NAME = "earthsearch:boa_offset_applied"
63
+
64
def __init__(
    self,
    assets: list[str] | None = None,
    query: dict[str, Any] | None = None,
    sort_by: str | None = None,
    sort_ascending: bool = True,
    cache_dir: str | None = None,
    harmonize: bool = False,
    timeout: timedelta = timedelta(seconds=10),
    context: DataSourceContext = DataSourceContext(),
) -> None:
    """Create a new Sentinel2 data source.

    Args:
        assets: asset names to ingest. Only consulted when context.layer_config
            is not set. If neither is set, every asset in ASSET_BANDS is used.
        query: optional STAC query filter to use.
        sort_by: STAC item property to sort by, e.g. "eo:cloud_cover" to sort by
            cloud cover.
        sort_ascending: whether to sort ascending or descending.
        cache_dir: directory to cache discovered items. It is joined onto
            context.ds_path when that is set, and created if it does not exist.
        harmonize: harmonize pixel values across different processing baselines,
            see https://developers.google.com/earth-engine/datasets/catalog/COPERNICUS_S2_SR_HARMONIZED
        timeout: timeout to use for requests.
        context: the data source context.

    Raises:
        KeyError: if assets contains a name that is not a key of ASSET_BANDS.
    """  # noqa: E501
    # Resolve (and create) the cache location, if caching was requested.
    cache_upath: UPath | None = None
    if cache_dir is not None:
        cache_upath = (
            UPath(cache_dir)
            if context.ds_path is None
            else join_upath(context.ds_path, cache_dir)
        )
        cache_upath.mkdir(parents=True, exist_ok=True)

    # Work out which assets are needed. The layer config takes precedence over
    # the explicit assets list.
    self.asset_bands: dict[str, list[str]]
    layer_config = context.layer_config
    if layer_config is not None:
        # Keep an asset when its bands overlap with at least one band set.
        self.asset_bands = {
            asset_key: band_names
            for asset_key, band_names in self.ASSET_BANDS.items()
            if any(
                set(band_set.bands).intersection(set(band_names))
                for band_set in layer_config.band_sets
            )
        }
    elif assets is not None:
        self.asset_bands = {name: self.ASSET_BANDS[name] for name in assets}
    else:
        self.asset_bands = self.ASSET_BANDS

    super().__init__(
        endpoint=self.STAC_ENDPOINT,
        collection_name=self.COLLECTION_NAME,
        query=query,
        sort_by=sort_by,
        sort_ascending=sort_ascending,
        required_assets=list(self.asset_bands.keys()),
        cache_dir=cache_upath,
        properties_to_record=[self.HARMONIZE_PROPERTY_NAME],
    )

    self.harmonize = harmonize
    self.timeout = timeout
131
+
132
def _get_harmonize_callback(
    self, item: SourceItem
) -> Callable[[npt.NDArray], npt.NDArray] | None:
    """Build the callback that removes the offset from newly processed scenes.

    copernicus.get_harmonize_callback is not used here because the S3 bucket does
    not seem to provide the product metadata XML file. The decision is instead
    based on the item's earthsearch:boa_offset_applied property.

    Args:
        item: the item whose recorded properties are inspected.

    Returns:
        None when the property is falsy (nothing has to be subtracted), otherwise
        a function that takes a single-band uint16 array and returns the
        harmonized array.
    """
    # NOTE(review): a truthy property is taken to mean the pixel values still
    # carry the offset -- confirm against the Earth Search documentation.
    offset_flag = item.properties[self.HARMONIZE_PROPERTY_NAME]
    if not offset_flag:
        return None

    def harmonize_callback(array: npt.NDArray) -> npt.NDArray:
        assert array.shape[0] == 1 and array.dtype == np.uint16
        # The offset is assumed to be the standard -1000. Values are clamped to
        # at least 1000 first so the subtraction cannot go below zero in uint16.
        floor = -self.HARMONIZE_OFFSET
        clamped = np.clip(array, floor, None)
        return clamped - floor

    return harmonize_callback
154
+
155
def ingest(
    self,
    tile_store: TileStoreWithLayer,
    items: list[SourceItem],
    geometries: list[list[STGeometry]],
) -> None:
    """Download the needed assets of each item and write them to the tile store.

    Assets that an item does not provide, and rasters that the tile store already
    reports as ready, are skipped.

    Args:
        tile_store: the tile store to ingest into
        items: the items to ingest
        geometries: a list of geometries needed for each item
    """
    request_timeout = self.timeout.total_seconds()

    for item in items:
        for asset_key, band_names in self.asset_bands.items():
            missing = asset_key not in item.asset_urls
            if missing or tile_store.is_raster_ready(item.name, band_names):
                continue

            asset_url = item.asset_urls[asset_key]

            # The download only needs to live until it is handed to the store.
            with tempfile.TemporaryDirectory() as tmp_dir:
                local_fname = os.path.join(tmp_dir, f"{asset_key}.tif")
                logger.debug(
                    "Download item %s asset %s to %s",
                    item.name,
                    asset_key,
                    local_fname,
                )
                with requests.get(
                    asset_url, stream=True, timeout=request_timeout
                ) as response:
                    response.raise_for_status()
                    with open(local_fname, "wb") as out_file:
                        out_file.writelines(response.iter_content(chunk_size=8192))

                logger.debug("Ingest item %s asset %s", item.name, asset_key)

                # Harmonize values if needed. TCI does not need harmonization.
                wants_harmonize = self.harmonize and asset_key != "visual"
                harmonize_callback = (
                    self._get_harmonize_callback(item) if wants_harmonize else None
                )

                if harmonize_callback is None:
                    # Nothing to change, so the file can be handed over as-is.
                    tile_store.write_raster_file(
                        item.name, band_names, UPath(local_fname)
                    )
                else:
                    # The pixel values must be converted, so read the array and
                    # pass the modified array directly to the tile store.
                    with rasterio.open(local_fname) as src:
                        array = src.read()
                        projection, bounds = get_raster_projection_and_bounds(src)
                    tile_store.write_raster(
                        item.name,
                        band_names,
                        projection,
                        bounds,
                        harmonize_callback(array),
                    )

                logger.debug(
                    "Done ingesting item %s asset %s", item.name, asset_key
                )
226
+
227
def is_raster_ready(
    self, layer_name: str, item_name: str, bands: list[str]
) -> bool:
    """Report whether the raster is available in this store.

    Args:
        layer_name: the layer name or alias.
        item_name: the item.
        bands: the list of bands identifying which specific raster to read.

    Returns:
        always True: reads are served by wrapping accesses to the underlying API,
        so no raster ever has to be written first.
    """
    ready = True
    return ready
243
+
244
def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
    """List the band sets that can be read for the specified item.

    Args:
        layer_name: the layer name or alias.
        item_name: the item.

    Returns:
        a list of lists of bands, one inner list per configured asset. The result
        depends only on the configured assets, not on the layer or item.
    """
    return [*self.asset_bands.values()]
257
+
258
def _get_asset_by_band(self, bands: list[str]) -> str:
    """Find the configured asset whose band list equals the given bands.

    Args:
        bands: the band names to look up; they must match an asset's band list
            exactly, including order.

    Returns:
        the key of the first matching asset.

    Raises:
        ValueError: if no configured asset has exactly these bands.
    """
    matching_keys = (
        asset_key
        for asset_key, asset_bands in self.asset_bands.items()
        if bands == asset_bands
    )
    found = next(matching_keys, None)
    if found is None:
        raise ValueError(f"no known asset with bands {bands}")
    return found
265
+
266
def get_raster_bounds(
    self, layer_name: str, item_name: str, bands: list[str], projection: Projection
) -> PixelBounds:
    """Compute the item's bounds in the specified projection.

    The bounds come from the item geometry, so they do not depend on which bands
    are requested.

    Args:
        layer_name: the layer name or alias.
        item_name: the item to check.
        bands: the list of bands identifying which specific raster to read. These
            bands must match the bands of a stored raster.
        projection: the projection to get the raster's bounds in.

    Returns:
        the bounds of the raster in the projection.
    """
    item = self.get_item_by_name(item_name)
    extent = item.geometry.to_projection(projection).shp.bounds
    # NOTE(review): int() truncates toward zero rather than flooring the minimum
    # and ceiling the maximum -- confirm this is intended for negative values.
    return (int(extent[0]), int(extent[1]), int(extent[2]), int(extent[3]))
289
+
290
def read_raster(
    self,
    layer_name: str,
    item_name: str,
    bands: list[str],
    projection: Projection,
    bounds: PixelBounds,
    resampling: Resampling = Resampling.bilinear,
) -> npt.NDArray[Any]:
    """Read raster data for an item directly from its remote asset.

    Args:
        layer_name: the layer name or alias.
        item_name: the item to read.
        bands: the list of bands identifying which specific raster to read. These
            bands must match the bands of a configured asset.
        projection: the projection to read in.
        bounds: the bounds to read.
        resampling: the resampling method to use in case reprojection is needed.

    Returns:
        the raster data, harmonized when harmonization is enabled and applies to
        this item and asset.

    Raises:
        ValueError: if no configured asset has exactly the requested bands.
    """
    asset_key = self._get_asset_by_band(bands)
    # Look the item up once: it provides both the asset URL and, further down,
    # the property that decides whether harmonization is needed.
    item = self.get_item_by_name(item_name)
    asset_url = item.asset_urls[asset_key]

    # Construct the transform to use for the warped dataset. Its origin is the
    # top-left corner of bounds converted from pixel units to CRS units.
    wanted_transform = affine.Affine(
        projection.x_resolution,
        0,
        bounds[0] * projection.x_resolution,
        0,
        projection.y_resolution,
        bounds[1] * projection.y_resolution,
    )

    # Read from the raster under the specified projection/bounds.
    with rasterio.open(asset_url) as src:
        with rasterio.vrt.WarpedVRT(
            src,
            crs=projection.crs,
            transform=wanted_transform,
            width=bounds[2] - bounds[0],
            height=bounds[3] - bounds[1],
            resampling=resampling,
        ) as vrt:
            raw_data = vrt.read()

    # We can return the data now if harmonization is not needed. The visual
    # (TCI) asset is never harmonized.
    if not self.harmonize or bands == self.ASSET_BANDS["visual"]:
        return raw_data

    # Otherwise apply the harmonize callback, if this item needs one.
    harmonize_callback = self._get_harmonize_callback(item)
    if harmonize_callback is None:
        return raw_data

    return harmonize_callback(raw_data)
352
+
353
def materialize(
    self,
    window: Window,
    item_groups: list[list[SourceItem]],
    layer_name: str,
    layer_cfg: LayerConfig,
) -> None:
    """Materialize data for the window.

    This data source is passed as the tile store, so the materializer reads from
    it directly.

    Args:
        window: the window to materialize
        item_groups: the items from get_items
        layer_name: the name of this layer
        layer_cfg: the config of this layer
    """
    materializer = RasterMaterializer()
    store = TileStoreWithLayer(self, layer_name)
    materializer.materialize(store, window, layer_name, layer_cfg, item_groups)
@@ -765,6 +765,22 @@ class Sentinel2(DataSource):
765
765
  for item in self._read_bigquery(
766
766
  time_range=geometry.time_range, wgs84_bbox=wgs84_bbox
767
767
  ):
768
+ # Get the item from XML to get its exact geometry (BigQuery only knows
769
+ # the bounding box of the item).
770
+ try:
771
+ item = self.get_item_by_name(item.name)
772
+ except CorruptItemException as e:
773
+ logger.warning("skipping corrupt item %s: %s", item.name, e.message)
774
+ continue
775
+ except MissingXMLException:
776
+ # Sometimes a scene that appears in the BigQuery index does not
777
+ # actually have an XML file on GCS. Since we know this happens
778
+ # occasionally, we ignore the error here.
779
+ logger.warning(
780
+ "skipping item %s that is missing XML file", item.name
781
+ )
782
+ continue
783
+
768
784
  candidates[idx].append(item)
769
785
 
770
786
  return candidates