rslearn-0.0.26-py3-none-any.whl → rslearn-0.0.28-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56):
  1. rslearn/data_sources/__init__.py +2 -0
  2. rslearn/data_sources/aws_landsat.py +44 -161
  3. rslearn/data_sources/aws_open_data.py +2 -4
  4. rslearn/data_sources/aws_sentinel1.py +1 -3
  5. rslearn/data_sources/aws_sentinel2_element84.py +54 -165
  6. rslearn/data_sources/climate_data_store.py +1 -3
  7. rslearn/data_sources/copernicus.py +1 -2
  8. rslearn/data_sources/data_source.py +1 -1
  9. rslearn/data_sources/direct_materialize_data_source.py +336 -0
  10. rslearn/data_sources/earthdaily.py +52 -155
  11. rslearn/data_sources/earthdatahub.py +425 -0
  12. rslearn/data_sources/eurocrops.py +1 -2
  13. rslearn/data_sources/gcp_public_data.py +1 -2
  14. rslearn/data_sources/google_earth_engine.py +1 -2
  15. rslearn/data_sources/hf_srtm.py +595 -0
  16. rslearn/data_sources/local_files.py +1 -1
  17. rslearn/data_sources/openstreetmap.py +1 -1
  18. rslearn/data_sources/planet.py +1 -2
  19. rslearn/data_sources/planet_basemap.py +1 -2
  20. rslearn/data_sources/planetary_computer.py +183 -186
  21. rslearn/data_sources/soilgrids.py +3 -3
  22. rslearn/data_sources/stac.py +1 -2
  23. rslearn/data_sources/usda_cdl.py +1 -3
  24. rslearn/data_sources/usgs_landsat.py +7 -254
  25. rslearn/data_sources/worldcereal.py +1 -1
  26. rslearn/data_sources/worldcover.py +1 -1
  27. rslearn/data_sources/worldpop.py +1 -1
  28. rslearn/data_sources/xyz_tiles.py +5 -9
  29. rslearn/dataset/storage/file.py +16 -12
  30. rslearn/models/concatenate_features.py +6 -1
  31. rslearn/tile_stores/default.py +4 -2
  32. rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
  33. rslearn/train/data_module.py +36 -33
  34. rslearn/train/dataset.py +159 -68
  35. rslearn/train/lightning_module.py +60 -4
  36. rslearn/train/metrics.py +162 -0
  37. rslearn/train/model_context.py +3 -3
  38. rslearn/train/prediction_writer.py +69 -41
  39. rslearn/train/tasks/classification.py +14 -1
  40. rslearn/train/tasks/detection.py +5 -5
  41. rslearn/train/tasks/per_pixel_regression.py +19 -6
  42. rslearn/train/tasks/regression.py +19 -3
  43. rslearn/train/tasks/segmentation.py +17 -0
  44. rslearn/utils/__init__.py +2 -0
  45. rslearn/utils/fsspec.py +51 -1
  46. rslearn/utils/geometry.py +21 -0
  47. rslearn/utils/m2m_api.py +251 -0
  48. rslearn/utils/retry_session.py +43 -0
  49. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/METADATA +6 -3
  50. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/RECORD +55 -50
  51. rslearn/data_sources/earthdata_srtm.py +0 -282
  52. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/WHEEL +0 -0
  53. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/entry_points.txt +0 -0
  54. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/licenses/LICENSE +0 -0
  55. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/licenses/NOTICE +0 -0
  56. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/top_level.txt +0 -0
@@ -17,6 +17,7 @@ from .data_source import (
17
17
  ItemLookupDataSource,
18
18
  RetrieveItemDataSource,
19
19
  )
20
+ from .direct_materialize_data_source import DirectMaterializeDataSource
20
21
 
21
22
  __all__ = (
22
23
  "DataSource",
@@ -24,5 +25,6 @@ __all__ = (
24
25
  "Item",
25
26
  "ItemLookupDataSource",
26
27
  "RetrieveItemDataSource",
28
+ "DirectMaterializeDataSource",
27
29
  "data_source_from_config",
28
30
  )
@@ -9,32 +9,28 @@ import urllib.request
9
9
  import zipfile
10
10
  from collections.abc import Generator
11
11
  from datetime import datetime
12
- from typing import Any, BinaryIO
12
+ from typing import BinaryIO
13
13
 
14
- import affine
15
14
  import boto3
16
15
  import dateutil.parser
17
16
  import fiona
18
17
  import fiona.transform
19
- import numpy.typing as npt
20
- import rasterio
21
18
  import shapely
22
19
  import shapely.geometry
23
20
  import tqdm
24
- from rasterio.enums import Resampling
25
21
  from upath import UPath
26
22
 
27
23
  import rslearn.data_sources.utils
28
- from rslearn.config import LayerConfig
29
24
  from rslearn.const import SHAPEFILE_AUX_EXTENSIONS, WGS84_PROJECTION
30
- from rslearn.dataset import Window
31
- from rslearn.dataset.materialize import RasterMaterializer
32
- from rslearn.tile_stores import TileStore, TileStoreWithLayer
25
+ from rslearn.data_sources.direct_materialize_data_source import (
26
+ DirectMaterializeDataSource,
27
+ )
28
+ from rslearn.tile_stores import TileStoreWithLayer
33
29
  from rslearn.utils.fsspec import get_upath_local, join_upath, open_atomic
34
- from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
30
+ from rslearn.utils.geometry import STGeometry
35
31
  from rslearn.utils.grid_index import GridIndex
36
32
 
37
- from .data_source import DataSource, DataSourceContext, Item, QueryConfig
33
+ from .data_source import DataSourceContext, Item, QueryConfig
38
34
 
39
35
  WRS2_GRID_SIZE = 1.0
40
36
 
@@ -79,7 +75,7 @@ class LandsatOliTirsItem(Item):
79
75
  )
80
76
 
81
77
 
82
- class LandsatOliTirs(DataSource, TileStore):
78
+ class LandsatOliTirs(DirectMaterializeDataSource[LandsatOliTirsItem]):
83
79
  """A data source for Landsat 8/9 OLI-TIRS imagery on AWS.
84
80
 
85
81
  Specifically, uses the usgs-landsat S3 bucket maintained by USGS. The data includes
@@ -91,7 +87,7 @@ class LandsatOliTirs(DataSource, TileStore):
91
87
 
92
88
  bucket_name = "usgs-landsat"
93
89
  bucket_prefix = "collection02/level-1/standard/oli-tirs"
94
- bands = ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B9", "B10", "B11"]
90
+ BANDS = ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B9", "B10", "B11"]
95
91
 
96
92
  wrs2_url = "https://d9-wret.s3.us-west-2.amazonaws.com/assets/palladium/production/s3fs-public/atoms/files/WRS2_descending_0.zip" # noqa
97
93
  """URL to download shapefile specifying polygon of each (path, row)."""
@@ -110,6 +106,10 @@ class LandsatOliTirs(DataSource, TileStore):
110
106
  SpaceMode.WITHIN.
111
107
  context: the data source context.
112
108
  """
109
+ # Each band is a separate single-band asset.
110
+ asset_bands = {band: [band] for band in self.BANDS}
111
+ super().__init__(asset_bands=asset_bands)
112
+
113
113
  # If context is provided, we join the directory with the dataset path,
114
114
  # otherwise we treat it directly as UPath.
115
115
  if context.ds_path is not None:
@@ -342,16 +342,43 @@ class LandsatOliTirs(DataSource, TileStore):
342
342
  return item
343
343
  raise ValueError(f"item {name} not found")
344
344
 
345
- def deserialize_item(self, serialized_item: Any) -> LandsatOliTirsItem:
345
+ # --- DirectMaterializeDataSource implementation ---
346
+
347
+ def get_asset_url(self, item_name: str, asset_key: str) -> str:
348
+ """Get the presigned URL to read the asset for the given item and asset key.
349
+
350
+ Args:
351
+ item_name: the name of the item.
352
+ asset_key: the key identifying which asset to get (the band name).
353
+
354
+ Returns:
355
+ the presigned URL to read the asset from.
356
+ """
357
+ # Get the item since it has the blob path.
358
+ item = self.get_item_by_name(item_name)
359
+
360
+ # For Landsat, the asset_key is the band name (e.g., "B1", "B2", etc.).
361
+ blob_key = item.blob_path + f"{asset_key}.TIF"
362
+ return self.client.generate_presigned_url(
363
+ "get_object",
364
+ Params={
365
+ "Bucket": self.bucket_name,
366
+ "Key": blob_key,
367
+ "RequestPayer": "requester",
368
+ },
369
+ )
370
+
371
+ # --- DataSource implementation ---
372
+
373
+ def deserialize_item(self, serialized_item: dict) -> LandsatOliTirsItem:
346
374
  """Deserializes an item from JSON-decoded data."""
347
- assert isinstance(serialized_item, dict)
348
375
  return LandsatOliTirsItem.deserialize(serialized_item)
349
376
 
350
377
  def retrieve_item(
351
378
  self, item: LandsatOliTirsItem
352
379
  ) -> Generator[tuple[str, BinaryIO], None, None]:
353
380
  """Retrieves the rasters corresponding to an item as file streams."""
354
- for band in self.bands:
381
+ for band in self.BANDS:
355
382
  buf = io.BytesIO()
356
383
  self.bucket.download_fileobj(
357
384
  item.blob_path + f"{band}.TIF",
@@ -376,7 +403,7 @@ class LandsatOliTirs(DataSource, TileStore):
376
403
  geometries: a list of geometries needed for each item
377
404
  """
378
405
  for item, cur_geometries in zip(items, geometries):
379
- for band in self.bands:
406
+ for band in self.BANDS:
380
407
  band_names = [band]
381
408
  if tile_store.is_raster_ready(item.name, band_names):
382
409
  continue
@@ -389,147 +416,3 @@ class LandsatOliTirs(DataSource, TileStore):
389
416
  ExtraArgs={"RequestPayer": "requester"},
390
417
  )
391
418
  tile_store.write_raster_file(item.name, band_names, UPath(fname))
392
-
393
- # The functions below are to emulate TileStore functionality so we can easily
394
- # support materialization directly from the COGs.
395
- def is_raster_ready(
396
- self, layer_name: str, item_name: str, bands: list[str]
397
- ) -> bool:
398
- """Checks if this raster has been written to the store.
399
-
400
- Args:
401
- layer_name: the layer name or alias.
402
- item_name: the item.
403
- bands: the list of bands identifying which specific raster to read.
404
-
405
- Returns:
406
- whether there is a raster in the store matching the source, item, and
407
- bands.
408
- """
409
- # Always ready since we access it on AWS bucket.
410
- return True
411
-
412
- def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
413
- """Get the sets of bands that have been stored for the specified item.
414
-
415
- Args:
416
- layer_name: the layer name or alias.
417
- item_name: the item.
418
-
419
- Returns:
420
- a list of lists of bands that are in the tile store (with one raster
421
- stored corresponding to each inner list). If no rasters are ready for
422
- this item, returns empty list.
423
- """
424
- return [[band] for band in self.bands]
425
-
426
- def get_raster_bounds(
427
- self, layer_name: str, item_name: str, bands: list[str], projection: Projection
428
- ) -> PixelBounds:
429
- """Get the bounds of the raster in the specified projection.
430
-
431
- Args:
432
- layer_name: the layer name or alias.
433
- item_name: the item to check.
434
- bands: the list of bands identifying which specific raster to read. These
435
- bands must match the bands of a stored raster.
436
- projection: the projection to get the raster's bounds in.
437
-
438
- Returns:
439
- the bounds of the raster in the projection.
440
- """
441
- item = self.get_item_by_name(item_name)
442
- geom = item.geometry.to_projection(projection)
443
- return (
444
- int(geom.shp.bounds[0]),
445
- int(geom.shp.bounds[1]),
446
- int(geom.shp.bounds[2]),
447
- int(geom.shp.bounds[3]),
448
- )
449
-
450
- def read_raster(
451
- self,
452
- layer_name: str,
453
- item_name: str,
454
- bands: list[str],
455
- projection: Projection,
456
- bounds: PixelBounds,
457
- resampling: Resampling = Resampling.bilinear,
458
- ) -> npt.NDArray[Any]:
459
- """Read raster data from the store.
460
-
461
- Args:
462
- layer_name: the layer name or alias.
463
- item_name: the item to read.
464
- bands: the list of bands identifying which specific raster to read. These
465
- bands must match the bands of a stored raster.
466
- projection: the projection to read in.
467
- bounds: the bounds to read.
468
- resampling: the resampling method to use in case reprojection is needed.
469
-
470
- Returns:
471
- the raster data
472
- """
473
- # Landsat assets have single band per asset.
474
- assert len(bands) == 1
475
- band = bands[0]
476
-
477
- # Get the item since it has the blob path.
478
- item = self.get_item_by_name(item_name)
479
-
480
- # Create pre-signed URL for rasterio access.
481
- # We do this because accessing via URL is much faster since rasterio can use
482
- # the URL directly.
483
- blob_key = item.blob_path + f"{band}.TIF"
484
- url = self.client.generate_presigned_url(
485
- "get_object",
486
- Params={
487
- "Bucket": self.bucket_name,
488
- "Key": blob_key,
489
- "RequestPayer": "requester",
490
- },
491
- )
492
-
493
- # Construct the transform to use for the warped dataset.
494
- wanted_transform = affine.Affine(
495
- projection.x_resolution,
496
- 0,
497
- bounds[0] * projection.x_resolution,
498
- 0,
499
- projection.y_resolution,
500
- bounds[1] * projection.y_resolution,
501
- )
502
-
503
- with rasterio.open(url) as src:
504
- with rasterio.vrt.WarpedVRT(
505
- src,
506
- crs=projection.crs,
507
- transform=wanted_transform,
508
- width=bounds[2] - bounds[0],
509
- height=bounds[3] - bounds[1],
510
- resampling=resampling,
511
- ) as vrt:
512
- return vrt.read()
513
-
514
- def materialize(
515
- self,
516
- window: Window,
517
- item_groups: list[list[LandsatOliTirsItem]],
518
- layer_name: str,
519
- layer_cfg: LayerConfig,
520
- ) -> None:
521
- """Materialize data for the window.
522
-
523
- Args:
524
- window: the window to materialize
525
- item_groups: the items from get_items
526
- layer_name: the name of this layer
527
- layer_cfg: the config of this layer
528
- """
529
- RasterMaterializer().materialize(
530
- TileStoreWithLayer(self, layer_name),
531
- window,
532
- layer_name,
533
- layer_cfg,
534
- item_groups,
535
- )
@@ -318,9 +318,8 @@ class Naip(DataSource):
318
318
  groups.append(cur_groups)
319
319
  return groups
320
320
 
321
- def deserialize_item(self, serialized_item: Any) -> NaipItem:
321
+ def deserialize_item(self, serialized_item: dict) -> NaipItem:
322
322
  """Deserializes an item from JSON-decoded data."""
323
- assert isinstance(serialized_item, dict)
324
323
  return NaipItem.deserialize(serialized_item)
325
324
 
326
325
  def ingest(
@@ -639,9 +638,8 @@ class Sentinel2(
639
638
  return item
640
639
  raise ValueError(f"item {name} not found")
641
640
 
642
- def deserialize_item(self, serialized_item: Any) -> Sentinel2Item:
641
+ def deserialize_item(self, serialized_item: dict) -> Sentinel2Item:
643
642
  """Deserializes an item from JSON-decoded data."""
644
- assert isinstance(serialized_item, dict)
645
643
  return Sentinel2Item.deserialize(serialized_item)
646
644
 
647
645
  def retrieve_item(
@@ -2,7 +2,6 @@
2
2
 
3
3
  import os
4
4
  import tempfile
5
- from typing import Any
6
5
 
7
6
  import boto3
8
7
  from upath import UPath
@@ -78,9 +77,8 @@ class Sentinel1(DataSource, TileStore):
78
77
  """Gets an item by name."""
79
78
  return self.sentinel1.get_item_by_name(name)
80
79
 
81
- def deserialize_item(self, serialized_item: Any) -> CopernicusItem:
80
+ def deserialize_item(self, serialized_item: dict) -> CopernicusItem:
82
81
  """Deserializes an item from JSON-decoded data."""
83
- assert isinstance(serialized_item, dict)
84
82
  return CopernicusItem.deserialize(serialized_item)
85
83
 
86
84
  def ingest(
@@ -6,23 +6,20 @@ from collections.abc import Callable
6
6
  from datetime import timedelta
7
7
  from typing import Any
8
8
 
9
- import affine
10
9
  import numpy as np
11
10
  import numpy.typing as npt
12
11
  import rasterio
13
12
  import requests
14
- from rasterio.enums import Resampling
15
13
  from upath import UPath
16
14
 
17
- from rslearn.config import LayerConfig
15
+ from rslearn.data_sources.direct_materialize_data_source import (
16
+ DirectMaterializeDataSource,
17
+ )
18
18
  from rslearn.data_sources.stac import SourceItem, StacDataSource
19
- from rslearn.dataset import Window
20
- from rslearn.dataset.manage import RasterMaterializer
21
19
  from rslearn.log_utils import get_logger
22
- from rslearn.tile_stores import TileStore, TileStoreWithLayer
23
- from rslearn.utils import Projection, STGeometry
20
+ from rslearn.tile_stores import TileStoreWithLayer
21
+ from rslearn.utils import STGeometry
24
22
  from rslearn.utils.fsspec import join_upath
25
- from rslearn.utils.geometry import PixelBounds
26
23
  from rslearn.utils.raster_format import get_raster_projection_and_bounds
27
24
 
28
25
  from .data_source import (
@@ -32,7 +29,7 @@ from .data_source import (
32
29
  logger = get_logger(__name__)
33
30
 
34
31
 
35
- class Sentinel2(StacDataSource, TileStore):
32
+ class Sentinel2(DirectMaterializeDataSource[SourceItem], StacDataSource):
36
33
  """A data source for Sentinel-2 L2A imagery on AWS from s3://sentinel-cogs.
37
34
 
38
35
  The S3 bucket has COGs so this data source supports direct materialization. It also
@@ -97,31 +94,36 @@ class Sentinel2(StacDataSource, TileStore):
97
94
  cache_upath.mkdir(parents=True, exist_ok=True)
98
95
 
99
96
  # Determine which assets we need based on the bands in the layer config.
100
- self.asset_bands: dict[str, list[str]]
97
+ asset_bands: dict[str, list[str]]
101
98
  if context.layer_config is not None:
102
- self.asset_bands = {}
99
+ asset_bands = {}
103
100
  for asset_key, band_names in self.ASSET_BANDS.items():
104
101
  # See if the bands provided by this asset intersect with the bands in
105
102
  # at least one configured band set.
106
103
  for band_set in context.layer_config.band_sets:
107
104
  if not set(band_set.bands).intersection(set(band_names)):
108
105
  continue
109
- self.asset_bands[asset_key] = band_names
106
+ asset_bands[asset_key] = band_names
110
107
  break
111
108
  elif assets is not None:
112
- self.asset_bands = {
109
+ asset_bands = {
113
110
  asset_key: self.ASSET_BANDS[asset_key] for asset_key in assets
114
111
  }
115
112
  else:
116
- self.asset_bands = self.ASSET_BANDS
113
+ asset_bands = dict(self.ASSET_BANDS)
114
+
115
+ # Initialize DirectMaterializeDataSource with asset_bands
116
+ DirectMaterializeDataSource.__init__(self, asset_bands=asset_bands)
117
117
 
118
- super().__init__(
118
+ # Initialize StacDataSource
119
+ StacDataSource.__init__(
120
+ self,
119
121
  endpoint=self.STAC_ENDPOINT,
120
122
  collection_name=self.COLLECTION_NAME,
121
123
  query=query,
122
124
  sort_by=sort_by,
123
125
  sort_ascending=sort_ascending,
124
- required_assets=list(self.asset_bands.keys()),
126
+ required_assets=list(asset_bands.keys()),
125
127
  cache_dir=cache_upath,
126
128
  properties_to_record=[self.HARMONIZE_PROPERTY_NAME],
127
129
  )
@@ -129,6 +131,42 @@ class Sentinel2(StacDataSource, TileStore):
129
131
  self.harmonize = harmonize
130
132
  self.timeout = timeout
131
133
 
134
+ # --- DirectMaterializeDataSource implementation ---
135
+
136
+ def get_asset_url(self, item_name: str, asset_key: str) -> str:
137
+ """Get the URL to read the asset for the given item and asset key.
138
+
139
+ Args:
140
+ item_name: the name of the item.
141
+ asset_key: the key identifying which asset to get.
142
+
143
+ Returns:
144
+ the URL to read the asset from.
145
+ """
146
+ item = self.get_item_by_name(item_name)
147
+ return item.asset_urls[asset_key]
148
+
149
+ def get_read_callback(
150
+ self, item_name: str, asset_key: str
151
+ ) -> Callable[[npt.NDArray[Any]], npt.NDArray[Any]] | None:
152
+ """Return a callback to harmonize Sentinel-2 data if needed.
153
+
154
+ Args:
155
+ item_name: the name of the item being read.
156
+ asset_key: the key identifying which asset is being read.
157
+
158
+ Returns:
159
+ A callback function for harmonization, or None if not needed.
160
+ """
161
+ # Visual bands do not need harmonization.
162
+ if not self.harmonize or asset_key == "visual":
163
+ return None
164
+
165
+ item = self.get_item_by_name(item_name)
166
+ return self._get_harmonize_callback(item)
167
+
168
+ # --- Harmonization helpers ---
169
+
132
170
  def _get_harmonize_callback(
133
171
  self, item: SourceItem
134
172
  ) -> Callable[[npt.NDArray], npt.NDArray] | None:
@@ -223,152 +261,3 @@ class Sentinel2(StacDataSource, TileStore):
223
261
  item.name,
224
262
  asset_key,
225
263
  )
226
-
227
- def is_raster_ready(
228
- self, layer_name: str, item_name: str, bands: list[str]
229
- ) -> bool:
230
- """Checks if this raster has been written to the store.
231
-
232
- Args:
233
- layer_name: the layer name or alias.
234
- item_name: the item.
235
- bands: the list of bands identifying which specific raster to read.
236
-
237
- Returns:
238
- whether there is a raster in the store matching the source, item, and
239
- bands.
240
- """
241
- # Always ready since we wrap accesses to underlying API.
242
- return True
243
-
244
- def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
245
- """Get the sets of bands that have been stored for the specified item.
246
-
247
- Args:
248
- layer_name: the layer name or alias.
249
- item_name: the item.
250
-
251
- Returns:
252
- a list of lists of bands that are in the tile store (with one raster
253
- stored corresponding to each inner list). If no rasters are ready for
254
- this item, returns empty list.
255
- """
256
- return list(self.asset_bands.values())
257
-
258
- def _get_asset_by_band(self, bands: list[str]) -> str:
259
- """Get the name of the asset based on the band names."""
260
- for asset_key, asset_bands in self.asset_bands.items():
261
- if bands == asset_bands:
262
- return asset_key
263
-
264
- raise ValueError(f"no known asset with bands {bands}")
265
-
266
- def get_raster_bounds(
267
- self, layer_name: str, item_name: str, bands: list[str], projection: Projection
268
- ) -> PixelBounds:
269
- """Get the bounds of the raster in the specified projection.
270
-
271
- Args:
272
- layer_name: the layer name or alias.
273
- item_name: the item to check.
274
- bands: the list of bands identifying which specific raster to read. These
275
- bands must match the bands of a stored raster.
276
- projection: the projection to get the raster's bounds in.
277
-
278
- Returns:
279
- the bounds of the raster in the projection.
280
- """
281
- item = self.get_item_by_name(item_name)
282
- geom = item.geometry.to_projection(projection)
283
- return (
284
- int(geom.shp.bounds[0]),
285
- int(geom.shp.bounds[1]),
286
- int(geom.shp.bounds[2]),
287
- int(geom.shp.bounds[3]),
288
- )
289
-
290
- def read_raster(
291
- self,
292
- layer_name: str,
293
- item_name: str,
294
- bands: list[str],
295
- projection: Projection,
296
- bounds: PixelBounds,
297
- resampling: Resampling = Resampling.bilinear,
298
- ) -> npt.NDArray[Any]:
299
- """Read raster data from the store.
300
-
301
- Args:
302
- layer_name: the layer name or alias.
303
- item_name: the item to read.
304
- bands: the list of bands identifying which specific raster to read. These
305
- bands must match the bands of a stored raster.
306
- projection: the projection to read in.
307
- bounds: the bounds to read.
308
- resampling: the resampling method to use in case reprojection is needed.
309
-
310
- Returns:
311
- the raster data
312
- """
313
- asset_key = self._get_asset_by_band(bands)
314
- item = self.get_item_by_name(item_name)
315
- asset_url = item.asset_urls[asset_key]
316
-
317
- # Construct the transform to use for the warped dataset.
318
- wanted_transform = affine.Affine(
319
- projection.x_resolution,
320
- 0,
321
- bounds[0] * projection.x_resolution,
322
- 0,
323
- projection.y_resolution,
324
- bounds[1] * projection.y_resolution,
325
- )
326
-
327
- # Read from the raster under the specified projection/bounds.
328
- with rasterio.open(asset_url) as src:
329
- with rasterio.vrt.WarpedVRT(
330
- src,
331
- crs=projection.crs,
332
- transform=wanted_transform,
333
- width=bounds[2] - bounds[0],
334
- height=bounds[3] - bounds[1],
335
- resampling=resampling,
336
- ) as vrt:
337
- raw_data = vrt.read()
338
-
339
- # We can return the data now if harmonization is not needed.
340
- if not self.harmonize or bands == self.ASSET_BANDS["visual"]:
341
- return raw_data
342
-
343
- # Otherwise we apply the harmonize_callback.
344
- item = self.get_item_by_name(item_name)
345
- harmonize_callback = self._get_harmonize_callback(item)
346
-
347
- if harmonize_callback is None:
348
- return raw_data
349
-
350
- array = harmonize_callback(raw_data)
351
- return array
352
-
353
- def materialize(
354
- self,
355
- window: Window,
356
- item_groups: list[list[SourceItem]],
357
- layer_name: str,
358
- layer_cfg: LayerConfig,
359
- ) -> None:
360
- """Materialize data for the window.
361
-
362
- Args:
363
- window: the window to materialize
364
- item_groups: the items from get_items
365
- layer_name: the name of this layer
366
- layer_cfg: the config of this layer
367
- """
368
- RasterMaterializer().materialize(
369
- TileStoreWithLayer(self, layer_name),
370
- window,
371
- layer_name,
372
- layer_cfg,
373
- item_groups,
374
- )
@@ -3,7 +3,6 @@
3
3
  import os
4
4
  import tempfile
5
5
  from datetime import UTC, datetime
6
- from typing import Any
7
6
 
8
7
  import cdsapi
9
8
  import netCDF4
@@ -160,9 +159,8 @@ class ERA5Land(DataSource):
160
159
 
161
160
  return all_groups
162
161
 
163
- def deserialize_item(self, serialized_item: Any) -> Item:
162
+ def deserialize_item(self, serialized_item: dict) -> Item:
164
163
  """Deserializes an item from JSON-decoded data."""
165
- assert isinstance(serialized_item, dict)
166
164
  return Item.deserialize(serialized_item)
167
165
 
168
166
  def _convert_nc_to_tif(self, nc_path: UPath, tif_path: UPath) -> None:
@@ -353,9 +353,8 @@ class Copernicus(DataSource):
353
353
  self.username = os.environ["COPERNICUS_USERNAME"]
354
354
  self.password = os.environ["COPERNICUS_PASSWORD"]
355
355
 
356
- def deserialize_item(self, serialized_item: Any) -> CopernicusItem:
356
+ def deserialize_item(self, serialized_item: dict) -> CopernicusItem:
357
357
  """Deserializes an item from JSON-decoded data."""
358
- assert isinstance(serialized_item, dict)
359
358
  return CopernicusItem.deserialize(serialized_item)
360
359
 
361
360
  def _get(self, path: str) -> dict[str, Any]:
@@ -76,7 +76,7 @@ class DataSource(Generic[ItemType]):
76
76
  """
77
77
  raise NotImplementedError
78
78
 
79
- def deserialize_item(self, serialized_item: Any) -> ItemType:
79
+ def deserialize_item(self, serialized_item: dict) -> ItemType:
80
80
  """Deserializes an item from JSON-decoded data."""
81
81
  raise NotImplementedError
82
82