rslearn 0.0.26__py3-none-any.whl → 0.0.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. rslearn/data_sources/__init__.py +2 -0
  2. rslearn/data_sources/aws_landsat.py +44 -161
  3. rslearn/data_sources/aws_open_data.py +2 -4
  4. rslearn/data_sources/aws_sentinel1.py +1 -3
  5. rslearn/data_sources/aws_sentinel2_element84.py +54 -165
  6. rslearn/data_sources/climate_data_store.py +1 -3
  7. rslearn/data_sources/copernicus.py +1 -2
  8. rslearn/data_sources/data_source.py +1 -1
  9. rslearn/data_sources/direct_materialize_data_source.py +336 -0
  10. rslearn/data_sources/earthdaily.py +52 -155
  11. rslearn/data_sources/earthdatahub.py +425 -0
  12. rslearn/data_sources/eurocrops.py +1 -2
  13. rslearn/data_sources/gcp_public_data.py +1 -2
  14. rslearn/data_sources/google_earth_engine.py +1 -2
  15. rslearn/data_sources/hf_srtm.py +595 -0
  16. rslearn/data_sources/local_files.py +1 -1
  17. rslearn/data_sources/openstreetmap.py +1 -1
  18. rslearn/data_sources/planet.py +1 -2
  19. rslearn/data_sources/planet_basemap.py +1 -2
  20. rslearn/data_sources/planetary_computer.py +183 -186
  21. rslearn/data_sources/soilgrids.py +3 -3
  22. rslearn/data_sources/stac.py +1 -2
  23. rslearn/data_sources/usda_cdl.py +1 -3
  24. rslearn/data_sources/usgs_landsat.py +7 -254
  25. rslearn/data_sources/worldcereal.py +1 -1
  26. rslearn/data_sources/worldcover.py +1 -1
  27. rslearn/data_sources/worldpop.py +1 -1
  28. rslearn/data_sources/xyz_tiles.py +5 -9
  29. rslearn/dataset/storage/file.py +16 -12
  30. rslearn/models/concatenate_features.py +6 -1
  31. rslearn/tile_stores/default.py +4 -2
  32. rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
  33. rslearn/train/data_module.py +36 -33
  34. rslearn/train/dataset.py +159 -68
  35. rslearn/train/lightning_module.py +60 -4
  36. rslearn/train/metrics.py +162 -0
  37. rslearn/train/model_context.py +3 -3
  38. rslearn/train/prediction_writer.py +69 -41
  39. rslearn/train/tasks/classification.py +14 -1
  40. rslearn/train/tasks/detection.py +5 -5
  41. rslearn/train/tasks/per_pixel_regression.py +19 -6
  42. rslearn/train/tasks/regression.py +19 -3
  43. rslearn/train/tasks/segmentation.py +17 -0
  44. rslearn/utils/__init__.py +2 -0
  45. rslearn/utils/fsspec.py +51 -1
  46. rslearn/utils/geometry.py +21 -0
  47. rslearn/utils/m2m_api.py +251 -0
  48. rslearn/utils/retry_session.py +43 -0
  49. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/METADATA +6 -3
  50. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/RECORD +55 -50
  51. rslearn/data_sources/earthdata_srtm.py +0 -282
  52. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/WHEEL +0 -0
  53. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/entry_points.txt +0 -0
  54. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/licenses/LICENSE +0 -0
  55. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/licenses/NOTICE +0 -0
  56. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/top_level.txt +0 -0
rslearn/data_sources/usgs_landsat.py
@@ -4,12 +4,9 @@
  """
  
  import io
- import json
  import os
  import shutil
  import tempfile
- import time
- import uuid
  from collections.abc import Generator
  from datetime import UTC, datetime, timedelta
  from typing import Any, BinaryIO
@@ -24,246 +21,7 @@ from rslearn.data_sources import DataSource, DataSourceContext, Item
  from rslearn.data_sources.utils import match_candidate_items_to_window
  from rslearn.tile_stores import TileStoreWithLayer
  from rslearn.utils import STGeometry
- 
- 
- class APIException(Exception):
-     """Exception raised for M2M API errors."""
- 
-     pass
- 
- 
- class M2MAPIClient:
-     """An API client for interacting with the USGS M2M API."""
- 
-     api_url = "https://m2m.cr.usgs.gov/api/api/json/stable/"
-     pagination_size = 1000
- 
-     def __init__(
-         self,
-         username: str,
-         password: str | None = None,
-         token: str | None = None,
-         timeout: timedelta = timedelta(seconds=120),
-     ) -> None:
-         """Initialize a new M2MAPIClient.
- 
-         Args:
-             username: the EROS username
-             password: the EROS password
-             token: the application token. One of password or token must be specified.
-             timeout: timeout for requests.
-         """
-         self.username = username
-         self.timeout = timeout
- 
-         if password is not None and token is not None:
-             raise ValueError("only one of password or token can be specified")
- 
-         if password is not None:
-             json_data = json.dumps({"username": self.username, "password": password})
-             response = requests.post(
-                 self.api_url + "login",
-                 data=json_data,
-                 timeout=self.timeout.total_seconds(),
-             )
- 
-         elif token is not None:
-             json_data = json.dumps({"username": username, "token": token})
-             response = requests.post(
-                 self.api_url + "login-token",
-                 data=json_data,
-                 timeout=self.timeout.total_seconds(),
-             )
- 
-         else:
-             raise ValueError("one of password or token must be specified")
- 
-         response.raise_for_status()
-         self.auth_token = response.json()["data"]
- 
-     def request(
-         self, endpoint: str, data: dict[str, Any] | None = None
-     ) -> dict[str, Any] | None:
-         """Make a request to the API.
- 
-         Args:
-             endpoint: the endpoint to call
-             data: POST data to pass
- 
-         Returns:
-             JSON response data if any
-         """
-         response = requests.post(
-             self.api_url + endpoint,
-             headers={"X-Auth-Token": self.auth_token},
-             data=json.dumps(data),
-             timeout=self.timeout.total_seconds(),
-         )
-         response.raise_for_status()
-         if response.text:
-             response_dict = response.json()
- 
-             if response_dict["errorMessage"]:
-                 raise APIException(response_dict["errorMessage"])
-             return response_dict
-         return None
- 
-     def close(self) -> None:
-         """Logout from the API."""
-         self.request("logout")
- 
-     def __enter__(self) -> "M2MAPIClient":
-         """Enter function to provide with semantics."""
-         return self
- 
-     def __exit__(self) -> None:
-         """Exit function to provide with semantics.
- 
-         Logs out the API.
-         """
-         self.close()
- 
-     def get_filters(self, dataset_name: str) -> list[dict[str, Any]]:
-         """Returns filters available for the given dataset.
- 
-         Args:
-             dataset_name: the dataset name e.g. landsat_ot_c2_l1
- 
-         Returns:
-             list of filter objects
-         """
-         response_dict = self.request("dataset-filters", {"datasetName": dataset_name})
-         if response_dict is None:
-             raise APIException("No response from API")
-         return response_dict["data"]
- 
-     def scene_search(
-         self,
-         dataset_name: str,
-         acquisition_time_range: tuple[datetime, datetime] | None = None,
-         cloud_cover_range: tuple[int, int] | None = None,
-         bbox: tuple[float, float, float, float] | None = None,
-         metadata_filter: dict[str, Any] | None = None,
-     ) -> list[dict[str, Any]]:
-         """Search for scenes matching the arguments.
- 
-         Args:
-             dataset_name: the dataset name e.g. landsat_ot_c2_l1
-             acquisition_time_range: optional filter on the acquisition time
-             cloud_cover_range: optional filter on the cloud cover
-             bbox: optional spatial filter
-             metadata_filter: optional metadata filter dict
-         """
-         base_data: dict[str, Any] = {"datasetName": dataset_name, "sceneFilter": {}}
-         if acquisition_time_range:
-             base_data["sceneFilter"]["acquisitionFilter"] = {
-                 "start": acquisition_time_range[0].isoformat(),
-                 "end": acquisition_time_range[1].isoformat(),
-             }
-         if cloud_cover_range:
-             base_data["sceneFilter"]["cloudCoverFilter"] = {
-                 "min": cloud_cover_range[0],
-                 "max": cloud_cover_range[1],
-                 "includeUnknown": False,
-             }
-         if bbox:
-             base_data["sceneFilter"]["spatialFilter"] = {
-                 "filterType": "mbr",
-                 "lowerLeft": {"longitude": bbox[0], "latitude": bbox[1]},
-                 "upperRight": {"longitude": bbox[2], "latitude": bbox[3]},
-             }
-         if metadata_filter:
-             base_data["sceneFilter"]["metadataFilter"] = metadata_filter
- 
-         starting_number = 1
-         results = []
-         while True:
-             cur_data = base_data.copy()
-             cur_data["startingNumber"] = starting_number
-             cur_data["maxResults"] = self.pagination_size
-             response_dict = self.request("scene-search", cur_data)
-             if response_dict is None:
-                 raise APIException("No response from API")
-             data = response_dict["data"]
-             results.extend(data["results"])
-             if data["recordsReturned"] < self.pagination_size:
-                 break
-             starting_number += self.pagination_size
- 
-         return results
- 
-     def get_scene_metadata(self, dataset_name: str, entity_id: str) -> dict[str, Any]:
-         """Get detailed metadata for a scene.
- 
-         Args:
-             dataset_name: the dataset name in which to search
-             entity_id: the entity ID of the scene
- 
-         Returns:
-             full scene metadata
-         """
-         response_dict = self.request(
-             "scene-metadata",
-             {
-                 "datasetName": dataset_name,
-                 "entityId": entity_id,
-                 "metadataType": "full",
-             },
-         )
-         if response_dict is None:
-             raise APIException("No response from API")
-         return response_dict["data"]
- 
-     def get_downloadable_products(
-         self, dataset_name: str, entity_id: str
-     ) -> list[dict[str, Any]]:
-         """Get the downloadable products for a given scene.
- 
-         Args:
-             dataset_name: the dataset name
-             entity_id: the entity ID of the scene
- 
-         Returns:
-             list of downloadable products
-         """
-         data = {"datasetName": dataset_name, "entityIds": [entity_id]}
-         response_dict = self.request("download-options", data)
-         if response_dict is None:
-             raise APIException("No response from API")
-         return response_dict["data"]
- 
-     def get_download_url(self, entity_id: str, product_id: str) -> str:
-         """Get the download URL for a given product.
- 
-         Args:
-             entity_id: the entity ID of the product
-             product_id: the product ID of the product
- 
-         Returns:
-             the download URL
-         """
-         label = str(uuid.uuid4())
-         data = {
-             "downloads": [
-                 {"label": label, "entityId": entity_id, "productId": product_id}
-             ]
-         }
-         response_dict = self.request("download-request", data)
-         if response_dict is None:
-             raise APIException("No response from API")
-         response = response_dict["data"]
-         while True:
-             response_dict = self.request("download-retrieve", {"label": label})
-             if response_dict is None:
-                 raise APIException("No response from API")
-             response = response_dict["data"]
-             if len(response["available"]) > 0:
-                 return response["available"][0]["url"]
-             if len(response["requested"]) == 0:
-                 raise Exception("Did not get download URL")
-             if response["requested"][0].get("url"):
-                 return response["requested"][0]["url"]
-             time.sleep(10)
+ from rslearn.utils.m2m_api import APIException, M2MAPIClient
  
  
  class LandsatOliTirsItem(Item):
@@ -314,30 +72,26 @@ class LandsatOliTirs(DataSource):
  
      def __init__(
          self,
-         username: str,
-         sort_by: str | None = None,
-         password: str | None = None,
+         username: str | None = None,
          token: str | None = None,
+         sort_by: str | None = None,
          timeout: timedelta = timedelta(seconds=10),
          context: DataSourceContext = DataSourceContext(),
      ):
          """Initialize a new LandsatOliTirs instance.
  
          Args:
-             username: EROS username
+             username: EROS username (see M2MAPIClient).
+             token: EROS application token (see M2MAPIClient).
              sort_by: can be "cloud_cover", default arbitrary order; only has effect for
                  SpaceMode.WITHIN.
-             password: EROS password (see M2MAPIClient).
-             token: EROS application token (see M2MAPIClient).
              timeout: timeout for requests.
              context: the data source context.
          """
          self.sort_by = sort_by
          self.timeout = timeout
  
-         self.client = M2MAPIClient(
-             username, password=password, token=token, timeout=timeout
-         )
+         self.client = M2MAPIClient(username=username, token=token, timeout=timeout)
  
      def _scene_metadata_to_item(self, result: dict[str, Any]) -> LandsatOliTirsItem:
          """Convert scene metadata from the API to a LandsatOliTirsItem."""
@@ -429,9 +183,8 @@ class LandsatOliTirs(DataSource):
          )
          return self._scene_metadata_to_item(scene_metadata)
  
-     def deserialize_item(self, serialized_item: Any) -> Item:
+     def deserialize_item(self, serialized_item: dict) -> Item:
          """Deserializes an item from JSON-decoded data."""
-         assert isinstance(serialized_item, dict)
          return LandsatOliTirsItem.deserialize(serialized_item)
  
      def _get_download_urls(self, item: Item) -> dict[str, tuple[str, str]]:
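
Taken together, the four hunks above move the M2M client out of usgs_landsat.py into the new rslearn/utils/m2m_api.py module and drop password-based login in favor of application tokens. A minimal usage sketch of the relocated client, assuming it keeps the method names visible in the removed code (the shipped version in m2m_api.py may differ in detail):

from datetime import UTC, datetime, timedelta

from rslearn.utils.m2m_api import APIException, M2MAPIClient

# Token-based login, mirroring the new LandsatOliTirs call site.
client = M2MAPIClient(
    username="my_eros_user",  # placeholder credentials
    token="my_app_token",
    timeout=timedelta(seconds=120),
)
try:
    # One month of scenes over a (lon_min, lat_min, lon_max, lat_max) box
    # with at most 20% cloud cover.
    results = client.scene_search(
        "landsat_ot_c2_l1",
        acquisition_time_range=(
            datetime(2024, 6, 1, tzinfo=UTC),
            datetime(2024, 7, 1, tzinfo=UTC),
        ),
        cloud_cover_range=(0, 20),
        bbox=(-122.5, 47.4, -122.2, 47.7),
    )
    for scene in results:
        # "entityId" is the scene identifier that the other client methods accept.
        metadata = client.get_scene_metadata("landsat_ot_c2_l1", scene["entityId"])
except APIException as e:
    print(f"M2M API error: {e}")
finally:
    client.close()  # logs out of the API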
rslearn/data_sources/worldcereal.py
@@ -291,7 +291,7 @@ class WorldCereal(LocalFiles):
              raise ValueError(f"No AEZ files found for {self.band}")
  
          super().__init__(
-             src_dir=tif_dir,
+             src_dir=tif_dir.absolute().as_uri(),
              raster_item_specs=item_specs,
              layer_type=LayerType.RASTER,
              context=context,
rslearn/data_sources/worldcover.py
@@ -75,7 +75,7 @@ class WorldCover(LocalFiles):
          tif_dir = self.download_worldcover_data(worldcover_upath)
  
          super().__init__(
-             src_dir=tif_dir,
+             src_dir=tif_dir.absolute().as_uri(),
              layer_type=LayerType.RASTER,
              context=context,
          )
rslearn/data_sources/worldpop.py
@@ -80,7 +80,7 @@ class WorldPop(LocalFiles):
          worldpop_upath.mkdir(parents=True, exist_ok=True)
          self.download_worldpop_data(worldpop_upath, timeout)
          super().__init__(
-             src_dir=worldpop_upath,
+             src_dir=worldpop_upath.absolute().as_uri(),
              layer_type=LayerType.RASTER,
              context=context,
          )
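
The three LocalFiles subclasses above (WorldCereal, WorldCover, WorldPop) now pass src_dir as a file:// URI string rather than a path object. For reference, here is what as_uri() produces, shown with plain pathlib (the diff itself calls the same methods on UPath objects):

from pathlib import Path

# as_uri() requires an absolute path, hence the .absolute() call in the diff.
print(Path("data/worldpop").absolute().as_uri())
# e.g. file:///home/user/data/worldpop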
rslearn/data_sources/xyz_tiles.py
@@ -19,7 +19,7 @@ from rslearn.config import LayerConfig, QueryConfig
  from rslearn.dataset import Window
  from rslearn.dataset.materialize import RasterMaterializer
  from rslearn.tile_stores import TileStore, TileStoreWithLayer
- from rslearn.utils import PixelBounds, Projection, STGeometry
+ from rslearn.utils import PixelBounds, Projection, STGeometry, get_global_raster_bounds
  from rslearn.utils.array import copy_spatial_array
  from rslearn.utils.raster_format import get_transform_from_projection_and_bounds
  
@@ -184,7 +184,7 @@ class XyzTiles(DataSource, TileStore):
              groups.append(cur_groups)
          return groups
  
-     def deserialize_item(self, serialized_item: Any) -> Item:
+     def deserialize_item(self, serialized_item: dict) -> Item:
          """Deserializes an item from JSON-decoded data."""
          return Item.deserialize(serialized_item)
  
@@ -278,13 +278,9 @@ class XyzTiles(DataSource, TileStore):
          Returns:
              the bounds of the raster in the projection.
          """
-         geom = STGeometry(self.projection, self.shp, None).to_projection(projection)
-         return (
-             int(geom.shp.bounds[0]),
-             int(geom.shp.bounds[1]),
-             int(geom.shp.bounds[2]),
-             int(geom.shp.bounds[3]),
-         )
+         # XyzTiles is a global data source, so we return global raster bounds based on
+         # the projection.
+         return get_global_raster_bounds(projection)
  
      def read_raster(
          self,
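
get_global_raster_bounds is newly exported from rslearn.utils (the implementation lands in rslearn/utils/geometry.py, +21 lines, and is not shown in this diff). Conceptually it replaces the inline computation removed above: project a world-spanning geometry into the target projection and take integer pixel bounds. A hypothetical sketch of that idea, not the shipped code; the +/-85.06 latitude clamp is an assumption to keep Mercator-like projections finite:

import shapely

from rslearn.const import WGS84_PROJECTION
from rslearn.utils import Projection, STGeometry

def global_raster_bounds_sketch(projection: Projection) -> tuple[int, int, int, int]:
    # Project a near-global WGS84 box into the target projection and take its
    # integer bounds, following the pattern of the removed code above.
    world = shapely.box(-180, -85.06, 180, 85.06)
    geom = STGeometry(WGS84_PROJECTION, world, None).to_projection(projection)
    b = geom.shp.bounds
    return (int(b[0]), int(b[1]), int(b[2]), int(b[3]))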
rslearn/dataset/storage/file.py
@@ -15,7 +15,7 @@ from rslearn.dataset.window import (
      get_window_layer_dir,
  )
  from rslearn.log_utils import get_logger
- from rslearn.utils.fsspec import open_atomic
+ from rslearn.utils.fsspec import iter_nonhidden_subdirs, open_atomic
  from rslearn.utils.mp import star_imap_unordered
  
  from .storage import WindowStorage, WindowStorageFactory
@@ -77,8 +77,8 @@ class FileWindowStorage(WindowStorage):
          window_dirs = []
          if not groups:
              groups = []
-             for p in (self.path / "windows").iterdir():
-                 groups.append(p.name)
+             for group_dir in iter_nonhidden_subdirs(self.path / "windows"):
+                 groups.append(group_dir.name)
          for group in groups:
              group_dir = self.path / "windows" / group
              if not group_dir.exists():
@@ -86,16 +86,20 @@ class FileWindowStorage(WindowStorage):
                      f"Skipping group directory {group_dir} since it does not exist"
                  )
                  continue
+             if not group_dir.is_dir():
+                 logger.warning(
+                     f"Skipping group path {group_dir} since it is not a directory"
+                 )
+                 continue
              if names:
-                 cur_names = names
+                 for window_name in names:
+                     window_dir = group_dir / window_name
+                     if not window_dir.is_dir():
+                         continue
+                     window_dirs.append(window_dir)
              else:
-                 cur_names = []
-                 for p in group_dir.iterdir():
-                     cur_names.append(p.name)
- 
-             for window_name in cur_names:
-                 window_dir = group_dir / window_name
-                 window_dirs.append(window_dir)
+                 for window_dir in iter_nonhidden_subdirs(group_dir):
+                     window_dirs.append(window_dir)
  
          if workers == 0:
              windows = [load_window(self, window_dir) for window_dir in window_dirs]
@@ -162,7 +166,7 @@ class FileWindowStorage(WindowStorage):
              return []
  
          completed_layers = []
-         for layer_dir in layers_directory.iterdir():
+         for layer_dir in iter_nonhidden_subdirs(layers_directory):
              layer_name, group_idx = get_layer_and_group_from_dir_name(layer_dir.name)
              if not self.is_layer_completed(group, name, layer_name, group_idx):
                  continue
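
The file.py hunks above (and the tile_stores/default.py hunks below) replace direct iterdir() calls with new helpers from rslearn.utils.fsspec that skip hidden entries. The real implementations are part of the +51 lines in fsspec.py and are not shown in this diff; a hypothetical re-creation of the assumed behavior:

from collections.abc import Generator

from upath import UPath

def iter_nonhidden_subdirs(path: UPath) -> Generator[UPath, None, None]:
    # Assumed semantics: yield child directories whose names do not start with ".".
    for p in path.iterdir():
        if not p.name.startswith(".") and p.is_dir():
            yield p

def iter_nonhidden_files(path: UPath) -> Generator[UPath, None, None]:
    # Assumed semantics: yield child files whose names do not start with ".".
    for p in path.iterdir():
        if not p.name.startswith(".") and p.is_file():
            yield p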
rslearn/models/concatenate_features.py
@@ -3,6 +3,7 @@
  from typing import Any
  
  import torch
+ from einops import rearrange
  
  from rslearn.train.model_context import ModelContext
  
@@ -79,7 +80,11 @@ class ConcatenateFeatures(IntermediateComponent):
          )
  
          add_data = torch.stack(
-             [input_data[self.key] for input_data in context.inputs], dim=0
+             [
+                 rearrange(input_data[self.key].image, "c t h w -> (c t) h w")
+                 for input_data in context.inputs
+             ],
+             dim=0,
          )
          add_features = self.conv_layers(add_data)
  
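
The stacked inputs are now (channel, time, height, width) image tensors, and rearrange folds the time axis into the channel axis before the conv layers. The einops pattern is equivalent to a reshape:

import torch
from einops import rearrange

x = torch.randn(4, 2, 32, 32)  # (c, t, h, w)
y = rearrange(x, "c t h w -> (c t) h w")
print(y.shape)  # torch.Size([8, 32, 32]), same as x.reshape(4 * 2, 32, 32)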
rslearn/tile_stores/default.py
@@ -15,6 +15,8 @@ from upath import UPath
  from rslearn.const import WGS84_PROJECTION
  from rslearn.utils.feature import Feature
  from rslearn.utils.fsspec import (
+     iter_nonhidden_files,
+     iter_nonhidden_subdirs,
      join_upath,
      open_atomic,
      open_rasterio_upath_reader,
@@ -129,7 +131,7 @@ class DefaultTileStore(TileStore):
              ValueError: if no file is found.
          """
          raster_dir = self._get_raster_dir(layer_name, item_name, bands)
-         for fname in raster_dir.iterdir():
+         for fname in iter_nonhidden_files(raster_dir):
              # Ignore completed sentinel files, bands files, as well as temporary files created by
              # open_atomic (in case this tile store is on local filesystem).
              if fname.name == COMPLETED_FNAME:
@@ -175,7 +177,7 @@ class DefaultTileStore(TileStore):
              return []
  
          bands: list[list[str]] = []
-         for raster_dir in item_dir.iterdir():
+         for raster_dir in iter_nonhidden_subdirs(item_dir):
              if not (raster_dir / BANDS_FNAME).exists():
                  # This is likely a legacy directory where the bands are only encoded in
                  # the directory name, so we have to rely on that.