rslearn 0.0.26__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. rslearn/data_sources/__init__.py +2 -0
  2. rslearn/data_sources/aws_landsat.py +44 -161
  3. rslearn/data_sources/aws_open_data.py +2 -4
  4. rslearn/data_sources/aws_sentinel1.py +1 -3
  5. rslearn/data_sources/aws_sentinel2_element84.py +54 -165
  6. rslearn/data_sources/climate_data_store.py +1 -3
  7. rslearn/data_sources/copernicus.py +1 -2
  8. rslearn/data_sources/data_source.py +1 -1
  9. rslearn/data_sources/direct_materialize_data_source.py +336 -0
  10. rslearn/data_sources/earthdaily.py +52 -155
  11. rslearn/data_sources/earthdatahub.py +425 -0
  12. rslearn/data_sources/eurocrops.py +1 -2
  13. rslearn/data_sources/gcp_public_data.py +1 -2
  14. rslearn/data_sources/google_earth_engine.py +1 -2
  15. rslearn/data_sources/hf_srtm.py +595 -0
  16. rslearn/data_sources/local_files.py +1 -1
  17. rslearn/data_sources/openstreetmap.py +1 -1
  18. rslearn/data_sources/planet.py +1 -2
  19. rslearn/data_sources/planet_basemap.py +1 -2
  20. rslearn/data_sources/planetary_computer.py +183 -186
  21. rslearn/data_sources/soilgrids.py +3 -3
  22. rslearn/data_sources/stac.py +1 -2
  23. rslearn/data_sources/usda_cdl.py +1 -3
  24. rslearn/data_sources/usgs_landsat.py +7 -254
  25. rslearn/data_sources/worldcereal.py +1 -1
  26. rslearn/data_sources/worldcover.py +1 -1
  27. rslearn/data_sources/worldpop.py +1 -1
  28. rslearn/data_sources/xyz_tiles.py +5 -9
  29. rslearn/models/concatenate_features.py +6 -1
  30. rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
  31. rslearn/train/data_module.py +27 -27
  32. rslearn/train/dataset.py +109 -62
  33. rslearn/train/lightning_module.py +1 -1
  34. rslearn/train/model_context.py +3 -3
  35. rslearn/train/prediction_writer.py +69 -41
  36. rslearn/train/tasks/classification.py +1 -1
  37. rslearn/train/tasks/detection.py +5 -5
  38. rslearn/train/tasks/regression.py +1 -1
  39. rslearn/utils/__init__.py +2 -0
  40. rslearn/utils/geometry.py +21 -0
  41. rslearn/utils/m2m_api.py +251 -0
  42. rslearn/utils/retry_session.py +43 -0
  43. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
  44. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/RECORD +49 -45
  45. rslearn/data_sources/earthdata_srtm.py +0 -282
  46. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
  47. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
  48. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
  49. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
  50. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
@@ -4,12 +4,9 @@
4
4
  """
5
5
 
6
6
  import io
7
- import json
8
7
  import os
9
8
  import shutil
10
9
  import tempfile
11
- import time
12
- import uuid
13
10
  from collections.abc import Generator
14
11
  from datetime import UTC, datetime, timedelta
15
12
  from typing import Any, BinaryIO
@@ -24,246 +21,7 @@ from rslearn.data_sources import DataSource, DataSourceContext, Item
24
21
  from rslearn.data_sources.utils import match_candidate_items_to_window
25
22
  from rslearn.tile_stores import TileStoreWithLayer
26
23
  from rslearn.utils import STGeometry
27
-
28
-
29
- class APIException(Exception):
30
- """Exception raised for M2M API errors."""
31
-
32
- pass
33
-
34
-
35
- class M2MAPIClient:
36
- """An API client for interacting with the USGS M2M API."""
37
-
38
- api_url = "https://m2m.cr.usgs.gov/api/api/json/stable/"
39
- pagination_size = 1000
40
-
41
- def __init__(
42
- self,
43
- username: str,
44
- password: str | None = None,
45
- token: str | None = None,
46
- timeout: timedelta = timedelta(seconds=120),
47
- ) -> None:
48
- """Initialize a new M2MAPIClient.
49
-
50
- Args:
51
- username: the EROS username
52
- password: the EROS password
53
- token: the application token. One of password or token must be specified.
54
- timeout: timeout for requests.
55
- """
56
- self.username = username
57
- self.timeout = timeout
58
-
59
- if password is not None and token is not None:
60
- raise ValueError("only one of password or token can be specified")
61
-
62
- if password is not None:
63
- json_data = json.dumps({"username": self.username, "password": password})
64
- response = requests.post(
65
- self.api_url + "login",
66
- data=json_data,
67
- timeout=self.timeout.total_seconds(),
68
- )
69
-
70
- elif token is not None:
71
- json_data = json.dumps({"username": username, "token": token})
72
- response = requests.post(
73
- self.api_url + "login-token",
74
- data=json_data,
75
- timeout=self.timeout.total_seconds(),
76
- )
77
-
78
- else:
79
- raise ValueError("one of password or token must be specified")
80
-
81
- response.raise_for_status()
82
- self.auth_token = response.json()["data"]
83
-
84
- def request(
85
- self, endpoint: str, data: dict[str, Any] | None = None
86
- ) -> dict[str, Any] | None:
87
- """Make a request to the API.
88
-
89
- Args:
90
- endpoint: the endpoint to call
91
- data: POST data to pass
92
-
93
- Returns:
94
- JSON response data if any
95
- """
96
- response = requests.post(
97
- self.api_url + endpoint,
98
- headers={"X-Auth-Token": self.auth_token},
99
- data=json.dumps(data),
100
- timeout=self.timeout.total_seconds(),
101
- )
102
- response.raise_for_status()
103
- if response.text:
104
- response_dict = response.json()
105
-
106
- if response_dict["errorMessage"]:
107
- raise APIException(response_dict["errorMessage"])
108
- return response_dict
109
- return None
110
-
111
- def close(self) -> None:
112
- """Logout from the API."""
113
- self.request("logout")
114
-
115
- def __enter__(self) -> "M2MAPIClient":
116
- """Enter function to provide with semantics."""
117
- return self
118
-
119
- def __exit__(self) -> None:
120
- """Exit function to provide with semantics.
121
-
122
- Logs out the API.
123
- """
124
- self.close()
125
-
126
- def get_filters(self, dataset_name: str) -> list[dict[str, Any]]:
127
- """Returns filters available for the given dataset.
128
-
129
- Args:
130
- dataset_name: the dataset name e.g. landsat_ot_c2_l1
131
-
132
- Returns:
133
- list of filter objects
134
- """
135
- response_dict = self.request("dataset-filters", {"datasetName": dataset_name})
136
- if response_dict is None:
137
- raise APIException("No response from API")
138
- return response_dict["data"]
139
-
140
- def scene_search(
141
- self,
142
- dataset_name: str,
143
- acquisition_time_range: tuple[datetime, datetime] | None = None,
144
- cloud_cover_range: tuple[int, int] | None = None,
145
- bbox: tuple[float, float, float, float] | None = None,
146
- metadata_filter: dict[str, Any] | None = None,
147
- ) -> list[dict[str, Any]]:
148
- """Search for scenes matching the arguments.
149
-
150
- Args:
151
- dataset_name: the dataset name e.g. landsat_ot_c2_l1
152
- acquisition_time_range: optional filter on the acquisition time
153
- cloud_cover_range: optional filter on the cloud cover
154
- bbox: optional spatial filter
155
- metadata_filter: optional metadata filter dict
156
- """
157
- base_data: dict[str, Any] = {"datasetName": dataset_name, "sceneFilter": {}}
158
- if acquisition_time_range:
159
- base_data["sceneFilter"]["acquisitionFilter"] = {
160
- "start": acquisition_time_range[0].isoformat(),
161
- "end": acquisition_time_range[1].isoformat(),
162
- }
163
- if cloud_cover_range:
164
- base_data["sceneFilter"]["cloudCoverFilter"] = {
165
- "min": cloud_cover_range[0],
166
- "max": cloud_cover_range[1],
167
- "includeUnknown": False,
168
- }
169
- if bbox:
170
- base_data["sceneFilter"]["spatialFilter"] = {
171
- "filterType": "mbr",
172
- "lowerLeft": {"longitude": bbox[0], "latitude": bbox[1]},
173
- "upperRight": {"longitude": bbox[2], "latitude": bbox[3]},
174
- }
175
- if metadata_filter:
176
- base_data["sceneFilter"]["metadataFilter"] = metadata_filter
177
-
178
- starting_number = 1
179
- results = []
180
- while True:
181
- cur_data = base_data.copy()
182
- cur_data["startingNumber"] = starting_number
183
- cur_data["maxResults"] = self.pagination_size
184
- response_dict = self.request("scene-search", cur_data)
185
- if response_dict is None:
186
- raise APIException("No response from API")
187
- data = response_dict["data"]
188
- results.extend(data["results"])
189
- if data["recordsReturned"] < self.pagination_size:
190
- break
191
- starting_number += self.pagination_size
192
-
193
- return results
194
-
195
- def get_scene_metadata(self, dataset_name: str, entity_id: str) -> dict[str, Any]:
196
- """Get detailed metadata for a scene.
197
-
198
- Args:
199
- dataset_name: the dataset name in which to search
200
- entity_id: the entity ID of the scene
201
-
202
- Returns:
203
- full scene metadata
204
- """
205
- response_dict = self.request(
206
- "scene-metadata",
207
- {
208
- "datasetName": dataset_name,
209
- "entityId": entity_id,
210
- "metadataType": "full",
211
- },
212
- )
213
- if response_dict is None:
214
- raise APIException("No response from API")
215
- return response_dict["data"]
216
-
217
- def get_downloadable_products(
218
- self, dataset_name: str, entity_id: str
219
- ) -> list[dict[str, Any]]:
220
- """Get the downloadable products for a given scene.
221
-
222
- Args:
223
- dataset_name: the dataset name
224
- entity_id: the entity ID of the scene
225
-
226
- Returns:
227
- list of downloadable products
228
- """
229
- data = {"datasetName": dataset_name, "entityIds": [entity_id]}
230
- response_dict = self.request("download-options", data)
231
- if response_dict is None:
232
- raise APIException("No response from API")
233
- return response_dict["data"]
234
-
235
- def get_download_url(self, entity_id: str, product_id: str) -> str:
236
- """Get the download URL for a given product.
237
-
238
- Args:
239
- entity_id: the entity ID of the product
240
- product_id: the product ID of the product
241
-
242
- Returns:
243
- the download URL
244
- """
245
- label = str(uuid.uuid4())
246
- data = {
247
- "downloads": [
248
- {"label": label, "entityId": entity_id, "productId": product_id}
249
- ]
250
- }
251
- response_dict = self.request("download-request", data)
252
- if response_dict is None:
253
- raise APIException("No response from API")
254
- response = response_dict["data"]
255
- while True:
256
- response_dict = self.request("download-retrieve", {"label": label})
257
- if response_dict is None:
258
- raise APIException("No response from API")
259
- response = response_dict["data"]
260
- if len(response["available"]) > 0:
261
- return response["available"][0]["url"]
262
- if len(response["requested"]) == 0:
263
- raise Exception("Did not get download URL")
264
- if response["requested"][0].get("url"):
265
- return response["requested"][0]["url"]
266
- time.sleep(10)
24
+ from rslearn.utils.m2m_api import APIException, M2MAPIClient
267
25
 
268
26
 
269
27
  class LandsatOliTirsItem(Item):
@@ -314,30 +72,26 @@ class LandsatOliTirs(DataSource):
314
72
 
315
73
  def __init__(
316
74
  self,
317
- username: str,
318
- sort_by: str | None = None,
319
- password: str | None = None,
75
+ username: str | None = None,
320
76
  token: str | None = None,
77
+ sort_by: str | None = None,
321
78
  timeout: timedelta = timedelta(seconds=10),
322
79
  context: DataSourceContext = DataSourceContext(),
323
80
  ):
324
81
  """Initialize a new LandsatOliTirs instance.
325
82
 
326
83
  Args:
327
- username: EROS username
84
+ username: EROS username (see M2MAPIClient).
85
+ token: EROS application token (see M2MAPIClient).
328
86
  sort_by: can be "cloud_cover", default arbitrary order; only has effect for
329
87
  SpaceMode.WITHIN.
330
- password: EROS password (see M2MAPIClient).
331
- token: EROS application token (see M2MAPIClient).
332
88
  timeout: timeout for requests.
333
89
  context: the data source context.
334
90
  """
335
91
  self.sort_by = sort_by
336
92
  self.timeout = timeout
337
93
 
338
- self.client = M2MAPIClient(
339
- username, password=password, token=token, timeout=timeout
340
- )
94
+ self.client = M2MAPIClient(username=username, token=token, timeout=timeout)
341
95
 
342
96
  def _scene_metadata_to_item(self, result: dict[str, Any]) -> LandsatOliTirsItem:
343
97
  """Convert scene metadata from the API to a LandsatOliTirsItem."""
@@ -429,9 +183,8 @@ class LandsatOliTirs(DataSource):
429
183
  )
430
184
  return self._scene_metadata_to_item(scene_metadata)
431
185
 
432
- def deserialize_item(self, serialized_item: Any) -> Item:
186
+ def deserialize_item(self, serialized_item: dict) -> Item:
433
187
  """Deserializes an item from JSON-decoded data."""
434
- assert isinstance(serialized_item, dict)
435
188
  return LandsatOliTirsItem.deserialize(serialized_item)
436
189
 
437
190
  def _get_download_urls(self, item: Item) -> dict[str, tuple[str, str]]:
@@ -291,7 +291,7 @@ class WorldCereal(LocalFiles):
291
291
  raise ValueError(f"No AEZ files found for {self.band}")
292
292
 
293
293
  super().__init__(
294
- src_dir=tif_dir,
294
+ src_dir=tif_dir.absolute().as_uri(),
295
295
  raster_item_specs=item_specs,
296
296
  layer_type=LayerType.RASTER,
297
297
  context=context,
@@ -75,7 +75,7 @@ class WorldCover(LocalFiles):
75
75
  tif_dir = self.download_worldcover_data(worldcover_upath)
76
76
 
77
77
  super().__init__(
78
- src_dir=tif_dir,
78
+ src_dir=tif_dir.absolute().as_uri(),
79
79
  layer_type=LayerType.RASTER,
80
80
  context=context,
81
81
  )
@@ -80,7 +80,7 @@ class WorldPop(LocalFiles):
80
80
  worldpop_upath.mkdir(parents=True, exist_ok=True)
81
81
  self.download_worldpop_data(worldpop_upath, timeout)
82
82
  super().__init__(
83
- src_dir=worldpop_upath,
83
+ src_dir=worldpop_upath.absolute().as_uri(),
84
84
  layer_type=LayerType.RASTER,
85
85
  context=context,
86
86
  )
@@ -19,7 +19,7 @@ from rslearn.config import LayerConfig, QueryConfig
19
19
  from rslearn.dataset import Window
20
20
  from rslearn.dataset.materialize import RasterMaterializer
21
21
  from rslearn.tile_stores import TileStore, TileStoreWithLayer
22
- from rslearn.utils import PixelBounds, Projection, STGeometry
22
+ from rslearn.utils import PixelBounds, Projection, STGeometry, get_global_raster_bounds
23
23
  from rslearn.utils.array import copy_spatial_array
24
24
  from rslearn.utils.raster_format import get_transform_from_projection_and_bounds
25
25
 
@@ -184,7 +184,7 @@ class XyzTiles(DataSource, TileStore):
184
184
  groups.append(cur_groups)
185
185
  return groups
186
186
 
187
- def deserialize_item(self, serialized_item: Any) -> Item:
187
+ def deserialize_item(self, serialized_item: dict) -> Item:
188
188
  """Deserializes an item from JSON-decoded data."""
189
189
  return Item.deserialize(serialized_item)
190
190
 
@@ -278,13 +278,9 @@ class XyzTiles(DataSource, TileStore):
278
278
  Returns:
279
279
  the bounds of the raster in the projection.
280
280
  """
281
- geom = STGeometry(self.projection, self.shp, None).to_projection(projection)
282
- return (
283
- int(geom.shp.bounds[0]),
284
- int(geom.shp.bounds[1]),
285
- int(geom.shp.bounds[2]),
286
- int(geom.shp.bounds[3]),
287
- )
281
+ # XyzTiles is a global data source, so we return global raster bounds based on
282
+ # the projection.
283
+ return get_global_raster_bounds(projection)
288
284
 
289
285
  def read_raster(
290
286
  self,
@@ -3,6 +3,7 @@
3
3
  from typing import Any
4
4
 
5
5
  import torch
6
+ from einops import rearrange
6
7
 
7
8
  from rslearn.train.model_context import ModelContext
8
9
 
@@ -79,7 +80,11 @@ class ConcatenateFeatures(IntermediateComponent):
79
80
  )
80
81
 
81
82
  add_data = torch.stack(
82
- [input_data[self.key] for input_data in context.inputs], dim=0
83
+ [
84
+ rearrange(input_data[self.key].image, "c t h w -> (c t) h w")
85
+ for input_data in context.inputs
86
+ ],
87
+ dim=0,
83
88
  )
84
89
  add_features = self.conv_layers(add_data)
85
90