rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/data_sources/copernicus.py
@@ -1,14 +1,52 @@
 """Data source for raster data in ESA Copernicus API."""
 
+import functools
+import io
+import json
+import os
+import pathlib
+import shutil
+import tempfile
+import urllib.request
 import xml.etree.ElementTree as ET
 from collections.abc import Callable
+from datetime import datetime
+from enum import Enum
+from typing import Any
+from urllib.parse import quote
+from zipfile import ZipFile
 
 import numpy as np
 import numpy.typing as npt
+import rasterio
+import requests
+import shapely
+from upath import UPath
+
+from rslearn.config import QueryConfig
+from rslearn.const import WGS84_PROJECTION
+from rslearn.data_sources.data_source import DataSource, DataSourceContext, Item
+from rslearn.data_sources.utils import match_candidate_items_to_window
+from rslearn.log_utils import get_logger
+from rslearn.tile_stores import TileStoreWithLayer
+from rslearn.utils.fsspec import open_atomic
+from rslearn.utils.geometry import (
+    FloatBounds,
+    STGeometry,
+    flatten_shape,
+    split_shape_at_antimeridian,
+)
+from rslearn.utils.grid_index import GridIndex
+from rslearn.utils.raster_format import get_raster_projection_and_bounds
+
+SENTINEL2_TILE_URL = "https://sentiwiki.copernicus.eu/__attachments/1692737/S2A_OPER_GIP_TILPAR_MPC__20151209T095117_V20150622T000000_21000101T000000_B00.zip"
+SENTINEL2_KML_NAMESPACE = "{http://www.opengis.net/kml/2.2}"
+
+logger = get_logger(__name__)
 
 
 def get_harmonize_callback(
-    tree: ET.ElementTree,
+    tree: "ET.ElementTree[ET.Element[str]] | ET.Element[str]",
 ) -> Callable[[npt.NDArray], npt.NDArray] | None:
     """Gets the harmonization callback based on the metadata XML.
 
@@ -23,20 +61,854 @@ def get_harmonize_callback(
         None if no callback is needed, or the callback to subtract the new offset
     """
     offset = None
-    for el in tree.iter("RADIO_ADD_OFFSET"):
-        value = int(el.text)
-        if offset is None:
-            offset = value
-            assert offset <= 0
-            # For now assert the offset is always -1000.
-            assert offset == -1000
-        else:
-            assert offset == value
+
+    # The metadata will use different tag for L1C / L2A.
+    # L1C: RADIO_ADD_OFFSET
+    # L2A: BOA_ADD_OFFSET
+    for potential_tag in ["RADIO_ADD_OFFSET", "BOA_ADD_OFFSET"]:
+        for el in tree.iter(potential_tag):
+            if el.text is None:
+                raise ValueError(f"text is missing in {el}")
+            value = int(el.text)
+            if offset is None:
+                offset = value
+                assert offset <= 0
+                # For now assert the offset is always -1000.
+                assert offset == -1000
+            else:
+                assert offset == value
 
     if offset is None or offset == 0:
         return None
 
-    def callback(array):
-        return np.clip(array, -offset, None) + offset
+    def callback(array: npt.NDArray) -> npt.NDArray:
+        # Subtract positive number instead of add negative number since only the former
+        # works with uint16 array.
+        assert array.shape[0] == 1 and array.dtype == np.uint16
+        return np.clip(array, -offset, None) - (-offset)  # type: ignore
 
     return callback
+
+
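The uint16 note in the callback above is easy to miss: a minimal sketch of what the harmonization does to raw digital numbers, assuming the usual -1000 offset from the metadata (the array values and shape here are illustrative):

    import numpy as np

    offset = -1000  # RADIO_ADD_OFFSET / BOA_ADD_OFFSET value parsed from the XML

    # Raw digital numbers, shape (band, height, width), dtype uint16.
    dn = np.array([[[0, 500, 1000, 9000]]], dtype=np.uint16)

    # Clip to 1000 first so the subtraction cannot wrap below zero, then subtract
    # the positive 1000; the result stays uint16: [[[0, 0, 0, 8000]]].
    harmonized = np.clip(dn, -offset, None) - (-offset)

    # By contrast, dn + offset would not stay uint16: NumPy either promotes the
    # result to a wider signed dtype (1.x) or raises OverflowError under NEP 50
    # (2.x), which is why the callback subtracts a positive number instead.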
+def get_sentinel2_tile_index() -> dict[str, list[FloatBounds]]:
+    """Get the Sentinel-2 tile index.
+
+    This is a map from tile name to a list of WGS84 bounds of the tile. A tile may have
+    multiple bounds if it crosses the antimeridian.
+    """
+    # Identify the Sentinel-2 tile names and bounds using the KML file.
+    # First, download the zip file and extract and parse the KML.
+    buf = io.BytesIO()
+    with urllib.request.urlopen(SENTINEL2_TILE_URL) as response:
+        shutil.copyfileobj(response, buf)
+    buf.seek(0)
+    with ZipFile(buf) as zipf:
+        member_names = zipf.namelist()
+        if len(member_names) != 1:
+            raise ValueError(
+                "Sentinel-2 tile zip file unexpectedly contains more than one file"
+            )
+
+        with zipf.open(member_names[0]) as memberf:
+            tree = ET.parse(memberf)
+
+    # Map from the tile name to a list of the longitude/latitude bounds.
+    tile_index: dict[str, list[FloatBounds]] = {}
+
+    # The KML is list of Placemark so iterate over those.
+    for placemark_node in tree.iter(SENTINEL2_KML_NAMESPACE + "Placemark"):
+        # The <name> node specifies the Sentinel-2 tile name.
+        name_node = placemark_node.find(SENTINEL2_KML_NAMESPACE + "name")
+        if name_node is None or name_node.text is None:
+            raise ValueError("Sentinel-2 KML has Placemark without valid name node")
+
+        tile_name = name_node.text
+
+        # There may be one or more <coordinates> nodes depending on whether it is a
+        # MultiGeometry. Some are polygons and some are points, but generally the
+        # points just seem to be the center of the tile. So we create one polygon for
+        # each coordinate list that is not a point, union them, and then split the
+        # union geometry over the antimeridian.
+        shapes = []
+        for coord_node in placemark_node.iter(SENTINEL2_KML_NAMESPACE + "coordinates"):
+            points = []
+            # It is list of space-separated coordinates like:
+            # 180,-73.0597374076,0 176.8646237862,-72.9914734628,0 ...
+            if coord_node.text is None:
+                raise ValueError("Sentinel-2 KML has coordinates node missing text")
+
+            point_strs = coord_node.text.strip().split()
+            for point_str in point_strs:
+                parts = point_str.split(",")
+                if len(parts) != 2 and len(parts) != 3:
+                    continue
+
+                lon = float(parts[0])
+                lat = float(parts[1])
+                points.append((lon, lat))
+
+            # At least three points to get a polygon.
+            if len(points) < 3:
+                continue
+
+            shapes.append(shapely.Polygon(points))
+
+        if len(shapes) == 0:
+            raise ValueError("Sentinel-2 KML has Placemark with no coordinates")
+
+        # Now we union the shapes and split them at the antimeridian. This avoids
+        # issues where the tile bounds go from -180 to 180 longitude and thus match
+        # with anything at the same latitude.
+        union_shp = shapely.unary_union(shapes)
+        split_shapes = flatten_shape(split_shape_at_antimeridian(union_shp))
+        bounds_list: list[FloatBounds] = []
+        for shp in split_shapes:
+            bounds_list.append(shp.bounds)
+        tile_index[tile_name] = bounds_list
+
+    return tile_index
+
+
+def _cache_sentinel2_tile_index(cache_dir: UPath) -> None:
+    """Cache the tiles from SENTINEL2_TILE_URL.
+
+    This way we just need to download it once.
+    """
+    json_fname = cache_dir / "tile_index.json"
+
+    if json_fname.exists():
+        return
+
+    logger.info(f"caching list of Sentinel-2 tiles to {json_fname}")
+    with open_atomic(json_fname, "w") as f:
+        json.dump(get_sentinel2_tile_index(), f)
+
+
+@functools.cache
+def load_sentinel2_tile_index(cache_dir: UPath) -> GridIndex:
+    """Load a GridIndex over Sentinel-2 tiles.
+
+    This function is cached so the GridIndex only needs to be constructed once (per
+    process).
+
+    Args:
+        cache_dir: the directory to cache the list of Sentinel-2 tiles.
+
+    Returns:
+        GridIndex over the tile names
+    """
+    _cache_sentinel2_tile_index(cache_dir)
+    json_fname = cache_dir / "tile_index.json"
+    with json_fname.open() as f:
+        json_data = json.load(f)
+
+    grid_index = GridIndex(0.5)
+    for tile_name, bounds_list in json_data.items():
+        for bounds in bounds_list:
+            grid_index.insert(bounds, tile_name)
+
+    return grid_index
+
+
+def get_sentinel2_tiles(geometry: STGeometry, cache_dir: UPath) -> list[str]:
+    """Get all Sentinel-2 tiles (like 01CCV) intersecting the given geometry.
+
+    Args:
+        geometry: the geometry to check.
+        cache_dir: directory to cache the tiles.
+
+    Returns:
+        list of Sentinel-2 tile names that intersect the geometry.
+    """
+    tile_index = load_sentinel2_tile_index(cache_dir)
+    wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
+    # If the shape is a collection, it could be cutting across antimeridian.
+    # So we query each component shape separately and collect the results to avoid
+    # issues.
+    # We assume the caller has already applied split_at_antimeridian.
+    results = set()
+    for shp in flatten_shape(wgs84_geometry.shp):
+        for result in tile_index.query(shp.bounds):
+            assert isinstance(result, str)
+            results.add(result)
+    return list(results)
+
+
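These helpers can be used on their own to answer which Sentinel-2 tiles cover a window. A minimal sketch, assuming a writable local cache directory (the bounding box and the printed tile name are illustrative):

    import shapely
    from upath import UPath

    from rslearn.const import WGS84_PROJECTION
    from rslearn.data_sources.copernicus import get_sentinel2_tiles
    from rslearn.utils.geometry import STGeometry

    # A small WGS84 box around Seattle; no time range is needed for tile lookup.
    aoi = STGeometry(WGS84_PROJECTION, shapely.box(-122.5, 47.4, -122.2, 47.7), None)

    # The first call downloads the ESA KML once and caches it as tile_index.json.
    print(get_sentinel2_tiles(aoi, UPath("/tmp/s2_tile_cache")))  # e.g. ['10TET']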
+class ApiError(Exception):
+    """An error from Copernicus API."""
+
+    pass
+
+
+class CopernicusItem(Item):
+    """An item in the Copernicus data source."""
+
+    def __init__(self, name: str, geometry: STGeometry, product_uuid: str) -> None:
+        """Create a new CopernicusItem.
+
+        Args:
+            name: the item name
+            geometry: the spatiotemporal item extent.
+            product_uuid: the product UUID from Copernicus API.
+        """
+        super().__init__(name, geometry)
+        self.product_uuid = product_uuid
+
+    def serialize(self) -> dict[str, Any]:
+        """Serializes the item to a JSON-encodable dictionary."""
+        d = super().serialize()
+        d["product_uuid"] = self.product_uuid
+        return d
+
+    @staticmethod
+    def deserialize(d: dict[str, Any]) -> "CopernicusItem":
+        """Deserializes an item from a JSON-decoded dictionary."""
+        item = super(CopernicusItem, CopernicusItem).deserialize(d)
+        return CopernicusItem(
+            name=item.name,
+            geometry=item.geometry,
+            product_uuid=d["product_uuid"],
+        )
+
+
+class Copernicus(DataSource):
+    """Scenes from the ESA Copernicus OData API.
+
+    See https://documentation.dataspace.copernicus.eu/APIs/OData.html for details about
+    the API and how to get an access token.
+    """
+
+    BASE_URL = "https://catalogue.dataspace.copernicus.eu/odata/v1"
+
+    # The key in response dictionary for the next URL in paginated response.
+    NEXT_LINK_KEY = "@odata.nextLink"
+
+    # Chunk size to use when streaming a download.
+    CHUNK_SIZE = 8192
+
+    # Expected date format in filter strings.
+    DATE_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
+
+    # URL to get access tokens.
+    TOKEN_URL = "https://identity.dataspace.copernicus.eu/auth/realms/CDSE/protocol/openid-connect/token"  # nosec
+
+    # We use a different URL for downloads. The BASE_URL would redirect to this URL but
+    # it makes it difficult since requests library drops the Authorization header after
+    # the redirect. So it is easier to access the DOWNLOAD_URL directly.
+    DOWNLOAD_URL = "https://download.dataspace.copernicus.eu/odata/v1"
+
+    def __init__(
+        self,
+        glob_to_bands: dict[str, list[str]],
+        access_token: str | None = None,
+        query_filter: str | None = None,
+        order_by: str | None = None,
+        sort_by: str | None = None,
+        sort_desc: bool = False,
+        timeout: float = 10,
+        context: DataSourceContext = DataSourceContext(),
+    ):
+        """Create a new Copernicus.
+
+        Args:
+            glob_to_bands: dictionary from a filename or glob string of an asset inside
+                the product zip file, to the list of bands that the asset contains.
+            access_token: API access token. See
+                https://documentation.dataspace.copernicus.eu/APIs/OData.html for how
+                to get a token. If not set, it is read from the environment variable
+                COPERNICUS_ACCESS_TOKEN. If that environment variable doesn't exist,
+                then we attempt to read the username/password from COPERNICUS_USERNAME
+                and COPERNICUS_PASSWORD (this is useful since access tokens are only
+                valid for an hour).
+            query_filter: filter string to include when searching for items. This will
+                be appended to other name, geographic, and sensing time filters where
+                applicable. For example, "Collection/Name eq 'SENTINEL-2'". See the API
+                documentation for more examples.
+            order_by: order by string to include when searching for items. For example,
+                "ContentDate/Start asc". See the API documentation for more examples.
+            sort_by: sort by the product attribute with this name. If set, attributes
+                will be expanded when listing products. Note that while order_by uses
+                the API to order products, the API provides limited options, and
+                sort_by instead is done after the API call.
+            sort_desc: for sort_by, sort in descending order instead of ascending
+                order.
+            timeout: timeout for requests.
+            context: the data source context.
+        """
+        self.glob_to_bands = glob_to_bands
+        self.query_filter = query_filter
+        self.order_by = order_by
+        self.sort_by = sort_by
+        self.sort_desc = sort_desc
+        self.timeout = timeout
+
+        self.username = None
+        self.password = None
+        self.access_token = access_token
+
+        if self.access_token is None:
+            if "COPERNICUS_ACCESS_TOKEN" in os.environ:
+                self.access_token = os.environ["COPERNICUS_ACCESS_TOKEN"]
+            else:
+                self.username = os.environ["COPERNICUS_USERNAME"]
+                self.password = os.environ["COPERNICUS_PASSWORD"]
+
+    def deserialize_item(self, serialized_item: Any) -> CopernicusItem:
+        """Deserializes an item from JSON-decoded data."""
+        assert isinstance(serialized_item, dict)
+        return CopernicusItem.deserialize(serialized_item)
+
+    def _get(self, path: str) -> dict[str, Any]:
+        """Get the API path and return JSON content."""
+        url = self.BASE_URL + path
+        logger.debug(f"GET {url}")
+        response = requests.get(url, timeout=self.timeout)
+        if response.status_code != 200:
+            content = str(response.content)
+            raise ApiError(
+                f"expected status code 200 but got {response.status_code} ({content})"
+            )
+        return response.json()
+
+    def _build_filter_string(self, base_filter: str) -> str:
+        """Build a filter string combining base_filter with user-provided filter.
+
+        Args:
+            base_filter: the base filter string that the caller wants to include.
+
+        Returns:
+            a filter string that combines base_filter with the optional user-provided
+            filter.
+        """
+        if self.query_filter is None:
+            return base_filter
+        else:
+            return f"{base_filter} and {self.query_filter}"
+
+    def _product_to_item(self, product: dict[str, Any]) -> CopernicusItem:
+        """Convert a product dictionary from API response to an Item.
+
+        Args:
+            product: the product dictionary that comes from an API response to the
+                /Products endpoint.
+
+        Returns:
+            corresponding Item.
+        """
+        name = product["Name"]
+        uuid = product["Id"]
+        shp = shapely.geometry.shape(product["GeoFootprint"])
+        time_range = (
+            datetime.fromisoformat(product["ContentDate"]["Start"]),
+            datetime.fromisoformat(product["ContentDate"]["End"]),
+        )
+        geom = STGeometry(WGS84_PROJECTION, shp, time_range)
+
+        return CopernicusItem(name, geom, uuid)
+
+    def _paginate(self, path: str) -> list[Any]:
+        """Iterate over pages of responses for the given path.
+
+        If the response includes "@odata.nextLink", then we continue to request until
+        it no longer has any nextLink. The values in response["value"] must be a list
+        and are concatenated.
+
+        Args:
+            path: the initial path to request. Additional requests will be made if a
+                nextLink appears in the response.
+
+        Returns:
+            the concatenated values across responses.
+        """
+        all_values = []
+
+        while True:
+            response = self._get(path)
+            all_values.extend(response["value"])
+            if self.NEXT_LINK_KEY not in response:
+                break
+
+            # Use the next link, but we only want the path not the base URL.
+            next_link = response[self.NEXT_LINK_KEY]
+            if not next_link.startswith(self.BASE_URL):
+                raise ValueError(
+                    f"got next link {next_link} but it does not start with the base URL {self.BASE_URL}"
+                )
+            path = next_link.split(self.BASE_URL)[1]
+
+        return all_values
+
+    def _get_product(
+        self, name: str, expand_attributes: bool = False
+    ) -> dict[str, Any]:
+        """Get the product dict from Copernicus API given scene name.
+
+        Args:
+            name: the scene name to get.
+            expand_attributes: whether to request API to provide the attributes of the
+                returned product.
+
+        Returns:
+            the decoded JSON product dict.
+        """
+        filter_string = self._build_filter_string(f"Name eq '{quote(name)}'")
+        path = f"/Products?$filter={filter_string}"
+        if expand_attributes:
+            path += "&$expand=Attributes"
+        response = self._get(path)
+        products = response["value"]
+        if len(products) != 1:
+            raise ValueError(
+                f"expected one product from {path} but got {len(products)}"
+            )
+        return products[0]
+
+    def get_item_by_name(self, name: str) -> CopernicusItem:
+        """Gets an item by name.
+
+        Args:
+            name: the name of the item to get
+
+        Returns:
+            the item object
+        """
+        product = self._get_product(name)
+        return self._product_to_item(product)
+
+    def get_items(
+        self, geometries: list[STGeometry], query_config: QueryConfig
+    ) -> list[list[list[CopernicusItem]]]:
+        """Get a list of items in the data source intersecting the given geometries.
+
+        Args:
+            geometries: the spatiotemporal geometries
+            query_config: the query configuration
+
+        Returns:
+            List of groups of items that should be retrieved for each geometry.
+        """
+        groups = []
+        for geometry in geometries:
+            # Perform a spatial + temporal search.
+            # We use EPSG:4326 (WGS84) for the spatial search; the API expects WKT in
+            # addition to the EPSG identifier.
+            wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
+            wgs84_wkt = wgs84_geometry.shp.wkt
+            filter_string = (
+                f"OData.CSC.Intersects(area=geography'SRID=4326;{wgs84_wkt}')"
+            )
+
+            if wgs84_geometry.time_range is not None:
+                start = wgs84_geometry.time_range[0].strftime(self.DATE_FORMAT)
+                end = wgs84_geometry.time_range[1].strftime(self.DATE_FORMAT)
+                filter_string += f" and ContentDate/Start gt {start}"
+                filter_string += f" and ContentDate/End lt {end}"
+
+            filter_string = self._build_filter_string(filter_string)
+            path = f"/Products?$filter={filter_string}&$top=1000"
+
+            if self.order_by is not None:
+                path += f"&$orderby={self.order_by}"
+            if self.sort_by is not None:
+                path += "&$expand=Attributes"
+
+            products = self._paginate(path)
+
+            if self.sort_by is not None:
+                # Define helper function that computes the sort value.
+                def get_attribute_value(product: dict[str, Any]) -> Any:
+                    attribute_by_name = {
+                        attribute["Name"]: attribute["Value"]
+                        for attribute in product["Attributes"]
+                    }
+                    return attribute_by_name[self.sort_by]
+
+                products.sort(
+                    key=get_attribute_value,
+                    reverse=self.sort_desc,
+                )
+
+            candidate_items = [self._product_to_item(product) for product in products]
+            cur_groups = match_candidate_items_to_window(
+                geometry, candidate_items, query_config
+            )
+            groups.append(cur_groups)
+
+        return groups
+
+    def _get_access_token(self) -> str:
+        """Get the access token to use for downloads.
+
+        If the username/password are set, we need to get the token from API.
+        """
+        if self.access_token is not None:
+            return self.access_token
+
+        response = requests.post(
+            self.TOKEN_URL,
+            data={
+                "grant_type": "password",
+                "username": self.username,
+                "password": self.password,
+                "client_id": "cdse-public",
+            },
+            timeout=self.timeout,
+        )
+        return response.json()["access_token"]
+
+    def _zip_member_glob(self, member_names: list[str], pattern: str) -> str:
+        """Pick the zip member name that matches the given pattern.
+
+        Args:
+            member_names: the list of names in the zip file.
+            pattern: the glob pattern to match.
+
+        Returns:
+            the member name matching the pattern.
+
+        Raises:
+            ValueError: if there is no matching member.
+        """
+        for name in member_names:
+            if pathlib.PurePosixPath(name).match(pattern):
+                return name
+        raise ValueError(f"no zip member matching {pattern}")
+
+    def _process_product_zip(
+        self, tile_store: TileStoreWithLayer, item: CopernicusItem, local_zip_fname: str
+    ) -> None:
+        """Ingest rasters in the specified product zip file.
+
+        Args:
+            tile_store: the tile store to ingest the rasters into.
+            item: the item to download and ingest.
+            local_zip_fname: the local filename where the product zip file has been
+                downloaded.
+        """
+        with ZipFile(local_zip_fname) as zipf:
+            member_names = zipf.namelist()
+
+            # Get each raster that is needed.
+            for glob_pattern, band_names in self.glob_to_bands.items():
+                if tile_store.is_raster_ready(item.name, band_names):
+                    continue
+
+                member_name = self._zip_member_glob(member_names, glob_pattern)
+
+                # Extract it to a temporary directory.
+                with tempfile.TemporaryDirectory() as tmp_dir:
+                    logger.debug(f"Extracting {member_name} for bands {band_names}")
+                    local_raster_fname = zipf.extract(member_name, path=tmp_dir)
+
+                    # Now we can ingest it.
+                    logger.debug(f"Ingesting the raster for bands {band_names}")
+                    tile_store.write_raster_file(
+                        item.name, band_names, UPath(local_raster_fname)
+                    )
+
+    def ingest(
+        self,
+        tile_store: TileStoreWithLayer,
+        items: list[CopernicusItem],
+        geometries: list[list[STGeometry]],
+    ) -> None:
+        """Ingest items into the given tile store.
+
+        Args:
+            tile_store: the tile store to ingest into
+            items: the items to ingest
+            geometries: a list of geometries needed for each item
+        """
+        for item in items:
+            # The product zip file is one big download, so we download it if any raster
+            # hasn't been ingested yet.
+            any_rasters_needed = False
+            for band_names in self.glob_to_bands.values():
+                if tile_store.is_raster_ready(item.name, band_names):
+                    continue
+                any_rasters_needed = True
+                break
+            if not any_rasters_needed:
+                continue
+
+            # Download the product zip file to temporary directory.
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                path = f"/Products({item.product_uuid})/$value"
+                logger.debug(
+                    f"Downloading product zip file from {self.DOWNLOAD_URL + path}"
+                )
+
+                access_token = self._get_access_token()
+                headers = {
+                    "Authorization": f"Bearer {access_token}",
+                }
+                response = requests.get(
+                    self.DOWNLOAD_URL + path,
+                    stream=True,
+                    headers=headers,
+                    timeout=self.timeout,
+                )
+                if response.status_code != 200:
+                    content = str(response.content)
+                    raise ApiError(
+                        f"expected status code 200 but got {response.status_code} ({content})"
+                    )
+
+                local_zip_fname = os.path.join(tmp_dir, "product.zip")
+                with open(local_zip_fname, "wb") as f:
+                    for chunk in response.iter_content(chunk_size=self.CHUNK_SIZE):
+                        f.write(chunk)
+
+                # Process each raster we need from the zip file.
+                self._process_product_zip(tile_store, item, local_zip_fname)
+
+
+class Sentinel2ProductType(str, Enum):
+    """The Sentinel-2 product type."""
+
+    L1C = "S2MSI1C"
+    L2A = "S2MSI2A"
+
+
+class Sentinel2(Copernicus):
+    """A data source for Sentinel-2 data from the Copernicus API."""
+
+    BANDS = {
+        "B01": ["B01"],
+        "B02": ["B02"],
+        "B03": ["B03"],
+        "B04": ["B04"],
+        "B05": ["B05"],
+        "B06": ["B06"],
+        "B07": ["B07"],
+        "B08": ["B08"],
+        "B09": ["B09"],
+        "B11": ["B11"],
+        "B12": ["B12"],
+        "B8A": ["B8A"],
+        "TCI": ["R", "G", "B"],
+        # L1C-only products.
+        "B10": ["B10"],
+        # L2A-only products.
+        "AOT": ["AOT"],
+        "WVP": ["WVP"],
+        "SCL": ["SCL"],
+    }
+
+    # Glob pattern for image files within the product zip file.
+    GLOB_PATTERNS = {
+        Sentinel2ProductType.L1C: {
+            "B01": "*/GRANULE/*/IMG_DATA/*_B01.jp2",
+            "B02": "*/GRANULE/*/IMG_DATA/*_B02.jp2",
+            "B03": "*/GRANULE/*/IMG_DATA/*_B03.jp2",
+            "B04": "*/GRANULE/*/IMG_DATA/*_B04.jp2",
+            "B05": "*/GRANULE/*/IMG_DATA/*_B05.jp2",
+            "B06": "*/GRANULE/*/IMG_DATA/*_B06.jp2",
+            "B07": "*/GRANULE/*/IMG_DATA/*_B07.jp2",
+            "B08": "*/GRANULE/*/IMG_DATA/*_B08.jp2",
+            "B8A": "*/GRANULE/*/IMG_DATA/*_B8A.jp2",
+            "B09": "*/GRANULE/*/IMG_DATA/*_B09.jp2",
+            "B10": "*/GRANULE/*/IMG_DATA/*_B10.jp2",
+            "B11": "*/GRANULE/*/IMG_DATA/*_B11.jp2",
+            "B12": "*/GRANULE/*/IMG_DATA/*_B12.jp2",
+            "TCI": "*/GRANULE/*/IMG_DATA/*_TCI.jp2",
+        },
+        Sentinel2ProductType.L2A: {
+            # In L2A, products are grouped by resolution.
+            # They are downsampled at lower resolutions too, so here we specify to just
+            # use the highest resolution one.
+            "B01": "*/GRANULE/*/IMG_DATA/R20m/*_B01_20m.jp2",
+            "B02": "*/GRANULE/*/IMG_DATA/R10m/*_B02_10m.jp2",
+            "B03": "*/GRANULE/*/IMG_DATA/R10m/*_B03_10m.jp2",
+            "B04": "*/GRANULE/*/IMG_DATA/R10m/*_B04_10m.jp2",
+            "B05": "*/GRANULE/*/IMG_DATA/R20m/*_B05_20m.jp2",
+            "B06": "*/GRANULE/*/IMG_DATA/R20m/*_B06_20m.jp2",
+            "B07": "*/GRANULE/*/IMG_DATA/R20m/*_B07_20m.jp2",
+            "B08": "*/GRANULE/*/IMG_DATA/R10m/*_B08_10m.jp2",
+            "B8A": "*/GRANULE/*/IMG_DATA/R20m/*_B8A_20m.jp2",
+            "B09": "*/GRANULE/*/IMG_DATA/R60m/*_B09_60m.jp2",
+            "B11": "*/GRANULE/*/IMG_DATA/R20m/*_B11_20m.jp2",
+            "B12": "*/GRANULE/*/IMG_DATA/R20m/*_B12_20m.jp2",
+            "TCI": "*/GRANULE/*/IMG_DATA/R10m/*_TCI_10m.jp2",
+            "AOT": "*/GRANULE/*/IMG_DATA/R10m/*_AOT_10m.jp2",
+            "WVP": "*/GRANULE/*/IMG_DATA/R10m/*_WVP_10m.jp2",
+            "SCL": "*/GRANULE/*/IMG_DATA/R20m/*_SCL_20m.jp2",
+        },
+    }
+
+    # Pattern of XML file within the product zip file.
+    METADATA_PATTERN = "*/MTD_MSIL*.xml"
+
+    def __init__(
+        self,
+        product_type: Sentinel2ProductType,
+        harmonize: bool = False,
+        assets: list[str] | None = None,
+        context: DataSourceContext = DataSourceContext(),
+        **kwargs: Any,
+    ):
+        """Create a new Sentinel2.
+
+        Args:
+            product_type: desired product type, L1C or L2A.
+            harmonize: harmonize pixel values across different processing baselines,
+                see https://developers.google.com/earth-engine/datasets/catalog/COPERNICUS_S2_SR_HARMONIZED
+            assets: the assets to download, or None to download all assets. This is
+                only used if the layer config is not in the context.
+            context: the data source context.
+            kwargs: additional arguments to pass to Copernicus.
+        """
+        # Create glob to bands map.
+        # If the context is provided, we limit to needed assets based on the configured
+        # band sets.
+        if context.layer_config is not None:
+            needed_assets = []
+            for asset_key, asset_bands in Sentinel2.BANDS.items():
+                # See if the bands provided by this asset intersect with the bands in
+                # at least one configured band set.
+                for band_set in context.layer_config.band_sets:
+                    if not set(band_set.bands).intersection(set(asset_bands)):
+                        continue
+                    needed_assets.append(asset_key)
+                    break
+        elif assets is not None:
+            needed_assets = assets
+        else:
+            needed_assets = list(Sentinel2.BANDS.keys())
+
+        glob_to_bands = {}
+        for asset_key in needed_assets:
+            band_names = self.BANDS[asset_key]
+            glob_pattern = self.GLOB_PATTERNS[product_type][asset_key]
+            glob_to_bands[glob_pattern] = band_names
+
+        # Create query filter based on the product type.
+        query_filter = f"Attributes/OData.CSC.StringAttribute/any(att:att/Name eq 'productType' and att/OData.CSC.StringAttribute/Value eq '{quote(product_type.value)}')"
+
+        super().__init__(
+            context=context,
+            glob_to_bands=glob_to_bands,
+            query_filter=query_filter,
+            **kwargs,
+        )
+        self.harmonize = harmonize
+
+    # Override to support harmonization step.
+    def _process_product_zip(
+        self, tile_store: TileStoreWithLayer, item: CopernicusItem, local_zip_fname: str
+    ) -> None:
+        """Ingest rasters in the specified product zip file.
+
+        Args:
+            tile_store: the tile store to ingest the rasters into.
+            item: the item to download and ingest.
+            local_zip_fname: the local filename where the product zip file has been
+                downloaded.
+        """
+        with ZipFile(local_zip_fname) as zipf:
+            member_names = zipf.namelist()
+
+            harmonize_callback = None
+            if self.harmonize:
+                # Need to check the product XML to see what the callback should be.
+                # It's in the zip file.
+                member_name = self._zip_member_glob(member_names, self.METADATA_PATTERN)
+                with zipf.open(member_name) as f:
+                    xml_data = ET.parse(f)
+                harmonize_callback = get_harmonize_callback(xml_data)
+
+            # Get each raster that is needed.
+            for glob_pattern, band_names in self.glob_to_bands.items():
+                if tile_store.is_raster_ready(item.name, band_names):
+                    continue
+
+                member_name = self._zip_member_glob(member_names, glob_pattern)
+
+                # Extract it to a temporary directory.
+                with tempfile.TemporaryDirectory() as tmp_dir:
+                    logger.debug(f"Extracting {member_name} for bands {band_names}")
+                    local_raster_fname = zipf.extract(member_name, path=tmp_dir)
+
+                    logger.debug(f"Ingesting the raster for bands {band_names}")
+
+                    if harmonize_callback is None or band_names == ["R", "G", "B"]:
+                        # No callback -- we can just ingest the file directly.
+                        # Or it is TCI product which is not impacted by the harmonization issue.
+                        tile_store.write_raster_file(
+                            item.name, band_names, UPath(local_raster_fname)
+                        )
+
+                    else:
+                        # In this case we need to read the array, convert the pixel
+                        # values, and pass modified array directly to the TileStore.
+                        with rasterio.open(local_raster_fname) as src:
+                            array = src.read()
+                            projection, bounds = get_raster_projection_and_bounds(src)
+                        array = harmonize_callback(array)
+                        tile_store.write_raster(
+                            item.name, band_names, projection, bounds, array
+                        )
+
+
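Before the Sentinel-1 variant below, it helps to see how the pieces compose: the Sentinel2 constructor supplies the productType attribute filter, _build_filter_string appends it to the spatiotemporal filter assembled in get_items, and _paginate follows @odata.nextLink until the listing is exhausted. An illustrative sketch of the resulting search path (the WKT placeholder and dates are made up):

    # Roughly what get_items sends for a Sentinel-2 L2A window; illustrative only.
    filter_string = (
        "OData.CSC.Intersects(area=geography'SRID=4326;POLYGON ((...))')"
        " and ContentDate/Start gt 2024-06-01T00:00:00.000000Z"
        " and ContentDate/End lt 2024-07-01T00:00:00.000000Z"
        " and Attributes/OData.CSC.StringAttribute/any(att:att/Name eq 'productType'"
        " and att/OData.CSC.StringAttribute/Value eq 'S2MSI2A')"
    )
    path = f"/Products?$filter={filter_string}&$top=1000"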
+class Sentinel1ProductType(str, Enum):
+    """The Sentinel-1 product type."""
+
+    IW_GRDH = "IW_GRDH_1S"
+
+
+class Sentinel1Polarisation(str, Enum):
+    """The Sentinel-1 polarisation."""
+
+    VV_VH = "VV&VH"
+
+
+class Sentinel1OrbitDirection(str, Enum):
+    """The Sentinel-1 orbit direction."""
+
+    ASCENDING = "ASCENDING"
+    DESCENDING = "DESCENDING"
+
+
+class Sentinel1(Copernicus):
+    """A data source for Sentinel-1 data from the Copernicus API."""
+
+    GLOB_TO_BANDS = {
+        Sentinel1Polarisation.VV_VH: {
+            "*/measurement/*-vh-*.tiff": ["vh"],
+            "*/measurement/*-vv-*.tiff": ["vv"],
+        }
+    }
+
+    # Pattern of XML file within the product zip file.
+    METADATA_PATTERN = "*/MTD_MSIL*.xml"
+
+    def __init__(
+        self,
+        product_type: Sentinel1ProductType,
+        polarisation: Sentinel1Polarisation,
+        orbit_direction: Sentinel1OrbitDirection | None = None,
+        context: DataSourceContext = DataSourceContext(),
+        **kwargs: Any,
+    ):
+        """Create a new Sentinel1.
+
+        Args:
+            product_type: desired product type.
+            polarisation: desired polarisation(s).
+            orbit_direction: optional orbit direction to filter by.
+            context: the data source context.
+            kwargs: additional arguments to pass to Copernicus.
+        """
+        # Create query filter based on the product type.
+        query_filter = (
+            f"Attributes/OData.CSC.StringAttribute/any(att:att/Name eq 'productType' and att/OData.CSC.StringAttribute/Value eq '{quote(product_type.value)}')"
+            + f" and Attributes/OData.CSC.StringAttribute/any(att:att/Name eq 'polarisationChannels' and att/OData.CSC.StringAttribute/Value eq '{quote(polarisation.value)}')"
+        )
+        if orbit_direction:
+            query_filter += f" and Attributes/OData.CSC.StringAttribute/any(att:att/Name eq 'orbitDirection' and att/OData.CSC.StringAttribute/Value eq '{quote(orbit_direction.value)}')"
+
+        super().__init__(
+            glob_to_bands=self.GLOB_TO_BANDS[polarisation],
+            query_filter=query_filter,
+            **kwargs,
+        )
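Taken together, a minimal sketch of exercising the new Sentinel-2 source directly, assuming COPERNICUS_USERNAME / COPERNICUS_PASSWORD (or COPERNICUS_ACCESS_TOKEN) are set in the environment; the product name is a placeholder:

    from rslearn.data_sources.copernicus import Sentinel2, Sentinel2ProductType

    # Download only the 10 m RGB bands; harmonize=True rescales pixel values
    # across processing baselines using the offset parsed from the product XML.
    ds = Sentinel2(
        product_type=Sentinel2ProductType.L2A,
        harmonize=True,
        assets=["B02", "B03", "B04"],
    )
    item = ds.get_item_by_name("S2B_MSIL2A_...")  # full ESA product name goes here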