rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/data_sources/data_source.py
@@ -1,11 +1,13 @@
  """Base classes for rslearn data sources."""

  from collections.abc import Generator
- from typing import Any, BinaryIO
+ from typing import Any, BinaryIO, Generic, TypeVar
+
+ from upath import UPath

  from rslearn.config import LayerConfig, QueryConfig
  from rslearn.dataset import Window
- from rslearn.tile_stores import TileStore
+ from rslearn.tile_stores import TileStoreWithLayer
  from rslearn.utils import STGeometry


@@ -51,7 +53,10 @@ class Item:
          return hash(self.name)


- class DataSource:
+ ItemType = TypeVar("ItemType", bound="Item")
+
+
+ class DataSource(Generic[ItemType]):
      """A set of raster or vector files that can be retrieved.

      Data sources should support at least one of ingest and materialize.
@@ -59,7 +64,7 @@ class DataSource:

      def get_items(
          self, geometries: list[STGeometry], query_config: QueryConfig
-     ) -> list[list[list[Item]]]:
+     ) -> list[list[list[ItemType]]]:
          """Get a list of items in the data source intersecting the given geometries.

          Args:
@@ -71,14 +76,14 @@ class DataSource:
          """
          raise NotImplementedError

-     def deserialize_item(self, serialized_item: Any) -> Item:
+     def deserialize_item(self, serialized_item: Any) -> ItemType:
          """Deserializes an item from JSON-decoded data."""
          raise NotImplementedError

      def ingest(
          self,
-         tile_store: TileStore,
-         items: list[Item],
+         tile_store: TileStoreWithLayer,
+         items: list[ItemType],
          geometries: list[list[STGeometry]],
      ) -> None:
          """Ingest items into the given tile store.
@@ -93,7 +98,7 @@ class DataSource:
      def materialize(
          self,
          window: Window,
-         item_groups: list[list[Item]],
+         item_groups: list[list[ItemType]],
          layer_name: str,
          layer_cfg: LayerConfig,
      ) -> None:
@@ -108,17 +113,43 @@ class DataSource:
          raise NotImplementedError


- class ItemLookupDataSource(DataSource):
+ class ItemLookupDataSource(DataSource[ItemType]):
      """A data source that can look up items by name."""

-     def get_item_by_name(self, name: str) -> Item:
+     def get_item_by_name(self, name: str) -> ItemType:
          """Gets an item by name."""
          raise NotImplementedError


- class RetrieveItemDataSource(DataSource):
+ class RetrieveItemDataSource(DataSource[ItemType]):
      """A data source that can retrieve items in their raw format."""

-     def retrieve_item(self, item: Item) -> Generator[tuple[str, BinaryIO], None, None]:
+     def retrieve_item(
+         self, item: ItemType
+     ) -> Generator[tuple[str, BinaryIO], None, None]:
          """Retrieves the rasters corresponding to an item as file streams."""
          raise NotImplementedError
+
+
+ class DataSourceContext:
+     """This context is passed to every data source.
+
+     When initializing data sources within rslearn, we always set the ds_path and
+     layer_config. However, for convenience (for users directly initializing the data
+     sources externally), each data source should allow for initialization when one or
+     both are missing.
+     """
+
+     def __init__(
+         self, ds_path: UPath | None = None, layer_config: LayerConfig | None = None
+     ):
+         """Create a new DataSourceContext.
+
+         Args:
+             ds_path: the path of the underlying dataset.
+             layer_config: the LayerConfig for the layer that the data source is for.
+         """
+         # We don't use dataclass here because otherwise jsonargparse will ignore our
+         # custom serializer/deserializer defined in rslearn.utils.jsonargparse.
+         self.ds_path = ds_path
+         self.layer_config = layer_config
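
The DataSource base class is now generic over its item type and receives an optional DataSourceContext. Below is a minimal sketch of how a downstream data source might be written against this interface; the MyItem and MySource names and the placeholder get_items/ingest bodies are hypothetical, not part of the package.

# Sketch of a custom data source against the generic interface above.
# MyItem and MySource are hypothetical; only the base-class signatures come from the diff.
from rslearn.config import QueryConfig
from rslearn.data_sources import DataSource, DataSourceContext, Item
from rslearn.tile_stores import TileStoreWithLayer
from rslearn.utils import STGeometry


class MyItem(Item):
    """Hypothetical item type satisfying the ItemType bound."""


class MySource(DataSource[MyItem]):
    """Hypothetical data source parameterized by MyItem."""

    def __init__(self, context: DataSourceContext = DataSourceContext()):
        # ds_path and layer_config may be None when constructed outside rslearn.
        self.ds_path = context.ds_path
        self.layer_config = context.layer_config

    def get_items(
        self, geometries: list[STGeometry], query_config: QueryConfig
    ) -> list[list[list[MyItem]]]:
        # One list of item groups per requested geometry (empty placeholder here).
        return [[] for _ in geometries]

    def ingest(
        self,
        tile_store: TileStoreWithLayer,
        items: list[MyItem],
        geometries: list[list[STGeometry]],
    ) -> None:
        # Write each item's data into the tile store; omitted in this sketch.
        raise NotImplementedError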
rslearn/data_sources/earthdaily.py
@@ -0,0 +1,484 @@
+ """Data on EarthDaily."""
+
+ import json
+ import os
+ import tempfile
+ from datetime import timedelta
+ from typing import Any, Literal
+
+ import affine
+ import numpy.typing as npt
+ import pystac
+ import pystac_client
+ import rasterio
+ import requests
+ import shapely
+ from earthdaily import EDSClient, EDSConfig
+ from rasterio.enums import Resampling
+ from upath import UPath
+
+ from rslearn.config import LayerConfig, QueryConfig
+ from rslearn.const import WGS84_PROJECTION
+ from rslearn.data_sources import DataSource, DataSourceContext, Item
+ from rslearn.data_sources.utils import match_candidate_items_to_window
+ from rslearn.dataset import Window
+ from rslearn.dataset.materialize import RasterMaterializer
+ from rslearn.log_utils import get_logger
+ from rslearn.tile_stores import TileStore, TileStoreWithLayer
+ from rslearn.utils.fsspec import join_upath
+ from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
+
+ logger = get_logger(__name__)
+
+
+ class EarthDailyItem(Item):
+     """An item in the EarthDaily data source."""
+
+     def __init__(self, name: str, geometry: STGeometry, asset_urls: dict[str, str]):
+         """Creates a new EarthDailyItem.
+
+         Args:
+             name: unique name of the item
+             geometry: the spatial and temporal extent of the item
+             asset_urls: map from asset key to the asset URL.
+         """
+         super().__init__(name, geometry)
+         self.asset_urls = asset_urls
+
+     def serialize(self) -> dict[str, Any]:
+         """Serializes the item to a JSON-encodable dictionary."""
+         d = super().serialize()
+         d["asset_urls"] = self.asset_urls
+         return d
+
+     @staticmethod
+     def deserialize(d: dict[str, Any]) -> "EarthDailyItem":
+         """Deserializes an item from a JSON-decoded dictionary."""
+         item = super(EarthDailyItem, EarthDailyItem).deserialize(d)
+         return EarthDailyItem(
+             name=item.name,
+             geometry=item.geometry,
+             asset_urls=d["asset_urls"],
+         )
+
+
+ class EarthDaily(DataSource, TileStore):
+     """A data source for EarthDaily data.
+
+     This requires the following environment variables to be set:
+     - EDS_CLIENT_ID
+     - EDS_SECRET
+     - EDS_AUTH_URL
+     - EDS_API_URL
+     """
+
+     def __init__(
+         self,
+         collection_name: str,
+         asset_bands: dict[str, list[str]],
+         query: dict[str, Any] | None = None,
+         sort_by: str | None = None,
+         sort_ascending: bool = True,
+         timeout: timedelta = timedelta(seconds=10),
+         skip_items_missing_assets: bool = False,
+         cache_dir: str | None = None,
+         max_retries: int = 3,
+         retry_backoff_factor: float = 5.0,
+         service_name: Literal["platform"] = "platform",
+         context: DataSourceContext = DataSourceContext(),
+     ):
+         """Initialize a new EarthDaily instance.
+
+         Args:
+             collection_name: the STAC collection name on EarthDaily.
+             asset_bands: assets to ingest, mapping from asset name to the list of bands
+                 in that asset.
+             query: optional query argument to STAC searches.
+             sort_by: sort by this property in the STAC items.
+             sort_ascending: whether to sort ascending (or descending).
+             timeout: timeout for API requests.
+             skip_items_missing_assets: skip STAC items that are missing any of the
+                 assets in asset_bands during get_items.
+             cache_dir: optional directory to cache items by name, including asset URLs.
+                 If not set, there will be no cache and instead STAC requests will be
+                 needed each time.
+             max_retries: the maximum number of retry attempts for HTTP requests that fail
+                 due to transient errors (e.g., 429, 500, 502, 503, 504 status codes).
+             retry_backoff_factor: backoff factor for exponential retry delays between HTTP
+                 request attempts. The delay between retries is calculated using the formula:
+                 `(retry_backoff_factor * (2 ** (retry_count - 1)))` seconds.
+             service_name: the service name, only "platform" is supported, the other
+                 services "legacy" and "internal" are not supported.
+             context: the data source context.
+         """
+         self.collection_name = collection_name
+         self.asset_bands = asset_bands
+         self.query = query
+         self.sort_by = sort_by
+         self.sort_ascending = sort_ascending
+         self.timeout = timeout
+         self.skip_items_missing_assets = skip_items_missing_assets
+         self.max_retries = max_retries
+         self.retry_backoff_factor = retry_backoff_factor
+         self.service_name = service_name
+
+         if cache_dir is not None:
+             # Use dataset path as root if provided.
+             if context.ds_path is not None:
+                 self.cache_dir = join_upath(context.ds_path, cache_dir)
+             else:
+                 self.cache_dir = UPath(cache_dir)
+
+             self.cache_dir.mkdir(parents=True, exist_ok=True)
+         else:
+             self.cache_dir = None
+
+         self.eds_client: EDSClient | None = None
+         self.client: pystac_client.Client | None = None
+         self.collection: pystac_client.CollectionClient | None = None
+
+     def _load_client(
+         self,
+     ) -> tuple[EDSClient, pystac_client.Client, pystac_client.CollectionClient]:
+         """Lazily load EDS client.
+
+         We don't load it when creating the data source because it takes time and caller
+         may not be calling get_items. Additionally, loading it during the get_items
+         call enables leveraging the retry loop functionality in
+         prepare_dataset_windows.
+         """
+         if self.eds_client is not None:
+             return self.eds_client, self.client, self.collection
+
+         self.eds_client = EDSClient(
+             EDSConfig(
+                 max_retries=self.max_retries,
+                 retry_backoff_factor=self.retry_backoff_factor,
+             )
+         )
+
+         if self.service_name == "platform":
+             self.client = self.eds_client.platform.pystac_client
+             self.collection = self.client.get_collection(self.collection_name)
+         else:
+             raise ValueError(f"Invalid service name: {self.service_name}")
+
+         return self.eds_client, self.client, self.collection
+
+     def _stac_item_to_item(self, stac_item: pystac.Item) -> EarthDailyItem:
+         shp = shapely.geometry.shape(stac_item.geometry)
+
+         metadata = stac_item.common_metadata
+         if metadata.start_datetime is not None and metadata.end_datetime is not None:
+             time_range = (
+                 metadata.start_datetime,
+                 metadata.end_datetime,
+             )
+         elif stac_item.datetime is not None:
+             time_range = (stac_item.datetime, stac_item.datetime)
+         else:
+             raise ValueError(
+                 f"item {stac_item.id} unexpectedly missing start_datetime, end_datetime, and datetime"
+             )
+
+         geom = STGeometry(WGS84_PROJECTION, shp, time_range)
+         asset_urls = {
+             asset_key: asset_obj.extra_fields["alternate"]["download"]["href"]
+             for asset_key, asset_obj in stac_item.assets.items()
+             if "alternate" in asset_obj.extra_fields
+             and "download" in asset_obj.extra_fields["alternate"]
+             and "href" in asset_obj.extra_fields["alternate"]["download"]
+         }
+         return EarthDailyItem(stac_item.id, geom, asset_urls)
+
+     def get_item_by_name(self, name: str) -> EarthDailyItem:
+         """Gets an item by name.
+
+         Args:
+             name: the name of the item to get
+
+         Returns:
+             the item object
+         """
+         # If cache_dir is set, we cache the item. First here we check if it is already
+         # in the cache.
+         cache_fname: UPath | None = None
+         if self.cache_dir:
+             cache_fname = self.cache_dir / f"{name}.json"
+         if cache_fname is not None and cache_fname.exists():
+             with cache_fname.open() as f:
+                 return EarthDailyItem.deserialize(json.load(f))
+
+         # No cache or not in cache, so we need to make the STAC request.
+         _, _, collection = self._load_client()
+         stac_item = collection.get_item(name)
+         item = self._stac_item_to_item(stac_item)
+
+         # Finally we cache it if cache_dir is set.
+         if cache_fname is not None:
+             with cache_fname.open("w") as f:
+                 json.dump(item.serialize(), f)
+
+         return item
+
+     def get_items(
+         self, geometries: list[STGeometry], query_config: QueryConfig
+     ) -> list[list[list[EarthDailyItem]]]:
+         """Get a list of items in the data source intersecting the given geometries.
+
+         Args:
+             geometries: the spatiotemporal geometries
+             query_config: the query configuration
+         """
+         _, client, _ = self._load_client()
+
+         groups = []
+         for geometry in geometries:
+             # Get potentially relevant items from the collection by performing one search
+             # for each requested geometry.
+             wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
+             logger.debug("performing STAC search for geometry %s", wgs84_geometry)
+             result = client.search(
+                 collections=[self.collection_name],
+                 intersects=shapely.to_geojson(wgs84_geometry.shp),
+                 datetime=wgs84_geometry.time_range,
+                 query=self.query,
+             )
+             stac_items = [item for item in result.item_collection()]
+             logger.debug("STAC search yielded %d items", len(stac_items))
+
+             if self.skip_items_missing_assets:
+                 # Filter out items that are missing any of the assets in self.asset_bands.
+                 good_stac_items = []
+                 for stac_item in stac_items:
+                     good = True
+                     for asset_key in self.asset_bands.keys():
+                         if asset_key in stac_item.assets:
+                             continue
+                         good = False
+                         break
+                     if good:
+                         good_stac_items.append(stac_item)
+                 logger.debug(
+                     "skip_items_missing_assets filter from %d to %d items",
+                     len(stac_items),
+                     len(good_stac_items),
+                 )
+                 stac_items = good_stac_items
+
+             if self.sort_by is not None:
+                 stac_items.sort(
+                     key=lambda stac_item: stac_item.properties[self.sort_by],
+                     reverse=not self.sort_ascending,
+                 )
+
+             candidate_items = [
+                 # The only way to get the asset URLs is to get the item by name.
+                 self.get_item_by_name(stac_item.id)
+                 for stac_item in stac_items
+             ]
+
+             cur_groups = match_candidate_items_to_window(
+                 geometry, candidate_items, query_config
+             )
+             groups.append(cur_groups)
+
+         return groups
+
+     def deserialize_item(self, serialized_item: Any) -> EarthDailyItem:
+         """Deserializes an item from JSON-decoded data."""
+         assert isinstance(serialized_item, dict)
+         return EarthDailyItem.deserialize(serialized_item)
+
+     def ingest(
+         self,
+         tile_store: TileStoreWithLayer,
+         items: list[EarthDailyItem],
+         geometries: list[list[STGeometry]],
+     ) -> None:
+         """Ingest items into the given tile store.
+
+         Args:
+             tile_store: the tile store to ingest into
+             items: the items to ingest
+             geometries: a list of geometries needed for each item
+         """
+         for item in items:
+             for asset_key, band_names in self.asset_bands.items():
+                 if asset_key not in item.asset_urls:
+                     continue
+                 if tile_store.is_raster_ready(item.name, band_names):
+                     continue
+
+                 asset_url = item.asset_urls[asset_key]
+                 with tempfile.TemporaryDirectory() as tmp_dir:
+                     local_fname = os.path.join(tmp_dir, f"{asset_key}.tif")
+                     logger.debug(
+                         "EarthDaily download item %s asset %s to %s",
+                         item.name,
+                         asset_key,
+                         local_fname,
+                     )
+                     with requests.get(
+                         asset_url, stream=True, timeout=self.timeout.total_seconds()
+                     ) as r:
+                         r.raise_for_status()
+                         with open(local_fname, "wb") as f:
+                             for chunk in r.iter_content(chunk_size=8192):
+                                 f.write(chunk)
+
+                     logger.debug(
+                         "EarthDaily ingest item %s asset %s",
+                         item.name,
+                         asset_key,
+                     )
+                     tile_store.write_raster_file(
+                         item.name, band_names, UPath(local_fname)
+                     )
+
+                     logger.debug(
+                         "EarthDaily done ingesting item %s asset %s",
+                         item.name,
+                         asset_key,
+                     )
+
+     def is_raster_ready(
+         self, layer_name: str, item_name: str, bands: list[str]
+     ) -> bool:
+         """Checks if this raster has been written to the store.
+
+         Args:
+             layer_name: the layer name or alias.
+             item_name: the item.
+             bands: the list of bands identifying which specific raster to read.
+
+         Returns:
+             whether there is a raster in the store matching the source, item, and
+             bands.
+         """
+         # Always ready since we wrap accesses to EarthDaily.
+         return True
+
+     def get_raster_bands(self, layer_name: str, item_name: str) -> list[list[str]]:
+         """Get the sets of bands that have been stored for the specified item.
+
+         Args:
+             layer_name: the layer name or alias.
+             item_name: the item.
+         """
+         if self.skip_items_missing_assets:
+             # In this case we can assume that the item has all of the assets.
+             return list(self.asset_bands.values())
+
+         # Otherwise we have to lookup the STAC item to see which assets it has.
+         # Here we use get_item_by_name since it handles caching.
+         item = self.get_item_by_name(item_name)
+         all_bands = []
+         for asset_key, band_names in self.asset_bands.items():
+             if asset_key not in item.asset_urls:
+                 continue
+             all_bands.append(band_names)
+         return all_bands
+
+     def _get_asset_by_band(self, bands: list[str]) -> str:
+         """Get the name of the asset based on the band names."""
+         for asset_key, asset_bands in self.asset_bands.items():
+             if bands == asset_bands:
+                 return asset_key
+
+         raise ValueError(f"no raster with bands {bands}")
+
+     def get_raster_bounds(
+         self, layer_name: str, item_name: str, bands: list[str], projection: Projection
+     ) -> PixelBounds:
+         """Get the bounds of the raster in the specified projection.
+
+         Args:
+             layer_name: the layer name or alias.
+             item_name: the item to check.
+             bands: the list of bands identifying which specific raster to read. These
+                 bands must match the bands of a stored raster.
+             projection: the projection to get the raster's bounds in.
+
+         Returns:
+             the bounds of the raster in the projection.
+         """
+         item = self.get_item_by_name(item_name)
+         geom = item.geometry.to_projection(projection)
+         return (
+             int(geom.shp.bounds[0]),
+             int(geom.shp.bounds[1]),
+             int(geom.shp.bounds[2]),
+             int(geom.shp.bounds[3]),
+         )
+
+     def read_raster(
+         self,
+         layer_name: str,
+         item_name: str,
+         bands: list[str],
+         projection: Projection,
+         bounds: PixelBounds,
+         resampling: Resampling = Resampling.bilinear,
+     ) -> npt.NDArray[Any]:
+         """Read raster data from the store.
+
+         Args:
+             layer_name: the layer name or alias.
+             item_name: the item to read.
+             bands: the list of bands identifying which specific raster to read. These
+                 bands must match the bands of a stored raster.
+             projection: the projection to read in.
+             bounds: the bounds to read.
+             resampling: the resampling method to use in case reprojection is needed.
+
+         Returns:
+             the raster data
+         """
+         asset_key = self._get_asset_by_band(bands)
+         item = self.get_item_by_name(item_name)
+         asset_url = item.asset_urls[asset_key]
+
+         # Construct the transform to use for the warped dataset.
+         wanted_transform = affine.Affine(
+             projection.x_resolution,
+             0,
+             bounds[0] * projection.x_resolution,
+             0,
+             projection.y_resolution,
+             bounds[1] * projection.y_resolution,
+         )
+
+         with rasterio.open(asset_url) as src:
+             with rasterio.vrt.WarpedVRT(
+                 src,
+                 crs=projection.crs,
+                 transform=wanted_transform,
+                 width=bounds[2] - bounds[0],
+                 height=bounds[3] - bounds[1],
+                 resampling=resampling,
+             ) as vrt:
+                 return vrt.read()
+
+     def materialize(
+         self,
+         window: Window,
+         item_groups: list[list[Item]],
+         layer_name: str,
+         layer_cfg: LayerConfig,
+     ) -> None:
+         """Materialize data for the window.
+
+         Args:
+             window: the window to materialize
+             item_groups: the items from get_items
+             layer_name: the name of this layer
+             layer_cfg: the config of this layer
+         """
+         RasterMaterializer().materialize(
+             TileStoreWithLayer(self, layer_name),
+             window,
+             layer_name,
+             layer_cfg,
+             item_groups,
+         )
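
For reference, a hedged usage sketch of the new EarthDaily data source based on the constructor shown above; the collection name and asset-to-band mapping are placeholder values, not ones shipped with the package. With the documented defaults, the retry delay formula `(retry_backoff_factor * (2 ** (retry_count - 1)))` with retry_backoff_factor=5.0 gives waits of 5, 10, and 20 seconds across the three default retries.

# Usage sketch only: the collection and band names below are hypothetical placeholders.
# Requires EDS_CLIENT_ID, EDS_SECRET, EDS_AUTH_URL, and EDS_API_URL in the environment.
from rslearn.data_sources import DataSourceContext
from rslearn.data_sources.earthdaily import EarthDaily

data_source = EarthDaily(
    collection_name="example-collection",           # hypothetical STAC collection
    asset_bands={"red": ["B04"], "nir": ["B08"]},    # hypothetical asset -> bands mapping
    skip_items_missing_assets=True,
    cache_dir="cache/earthdaily",                    # joined to ds_path when the context provides one
    context=DataSourceContext(),                     # ds_path and layer_config are optional
)
# The EDS client is loaded lazily, so construction is cheap; STAC searches happen in get_items.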