rslearn 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. rslearn/config/dataset.py +22 -13
  2. rslearn/data_sources/__init__.py +8 -0
  3. rslearn/data_sources/aws_landsat.py +27 -18
  4. rslearn/data_sources/aws_open_data.py +41 -42
  5. rslearn/data_sources/copernicus.py +148 -2
  6. rslearn/data_sources/data_source.py +17 -10
  7. rslearn/data_sources/gcp_public_data.py +177 -100
  8. rslearn/data_sources/geotiff.py +1 -0
  9. rslearn/data_sources/google_earth_engine.py +17 -15
  10. rslearn/data_sources/local_files.py +59 -32
  11. rslearn/data_sources/openstreetmap.py +27 -23
  12. rslearn/data_sources/planet.py +10 -9
  13. rslearn/data_sources/planet_basemap.py +303 -0
  14. rslearn/data_sources/raster_source.py +23 -13
  15. rslearn/data_sources/usgs_landsat.py +56 -27
  16. rslearn/data_sources/utils.py +13 -6
  17. rslearn/data_sources/vector_source.py +1 -0
  18. rslearn/data_sources/xyz_tiles.py +8 -9
  19. rslearn/dataset/add_windows.py +1 -1
  20. rslearn/dataset/dataset.py +16 -5
  21. rslearn/dataset/manage.py +9 -4
  22. rslearn/dataset/materialize.py +26 -5
  23. rslearn/dataset/window.py +5 -0
  24. rslearn/log_utils.py +24 -0
  25. rslearn/main.py +123 -59
  26. rslearn/models/clip.py +62 -0
  27. rslearn/models/conv.py +56 -0
  28. rslearn/models/faster_rcnn.py +2 -19
  29. rslearn/models/fpn.py +1 -1
  30. rslearn/models/module_wrapper.py +43 -0
  31. rslearn/models/molmo.py +65 -0
  32. rslearn/models/multitask.py +1 -1
  33. rslearn/models/pooling_decoder.py +4 -2
  34. rslearn/models/satlaspretrain.py +4 -7
  35. rslearn/models/simple_time_series.py +61 -55
  36. rslearn/models/ssl4eo_s12.py +9 -9
  37. rslearn/models/swin.py +22 -21
  38. rslearn/models/unet.py +4 -2
  39. rslearn/models/upsample.py +35 -0
  40. rslearn/tile_stores/file.py +6 -3
  41. rslearn/tile_stores/tile_store.py +19 -7
  42. rslearn/train/callbacks/freeze_unfreeze.py +3 -3
  43. rslearn/train/data_module.py +5 -4
  44. rslearn/train/dataset.py +79 -36
  45. rslearn/train/lightning_module.py +15 -11
  46. rslearn/train/prediction_writer.py +22 -11
  47. rslearn/train/tasks/classification.py +9 -8
  48. rslearn/train/tasks/detection.py +94 -37
  49. rslearn/train/tasks/multi_task.py +1 -1
  50. rslearn/train/tasks/regression.py +8 -4
  51. rslearn/train/tasks/segmentation.py +23 -19
  52. rslearn/train/transforms/__init__.py +1 -1
  53. rslearn/train/transforms/concatenate.py +6 -2
  54. rslearn/train/transforms/crop.py +6 -2
  55. rslearn/train/transforms/flip.py +5 -1
  56. rslearn/train/transforms/normalize.py +9 -5
  57. rslearn/train/transforms/pad.py +1 -1
  58. rslearn/train/transforms/transform.py +3 -3
  59. rslearn/utils/__init__.py +4 -5
  60. rslearn/utils/array.py +2 -2
  61. rslearn/utils/feature.py +1 -1
  62. rslearn/utils/fsspec.py +70 -1
  63. rslearn/utils/geometry.py +155 -3
  64. rslearn/utils/grid_index.py +5 -5
  65. rslearn/utils/mp.py +4 -3
  66. rslearn/utils/raster_format.py +81 -73
  67. rslearn/utils/rtree_index.py +64 -17
  68. rslearn/utils/sqlite_index.py +7 -1
  69. rslearn/utils/utils.py +11 -3
  70. rslearn/utils/vector_format.py +113 -17
  71. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/METADATA +32 -27
  72. rslearn-0.0.2.dist-info/RECORD +94 -0
  73. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/WHEEL +1 -1
  74. rslearn/utils/mgrs.py +0 -24
  75. rslearn-0.0.1.dist-info/RECORD +0 -88
  76. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/LICENSE +0 -0
  77. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/entry_points.txt +0 -0
  78. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,303 @@
1
+ """Data source for Planet Labs Basemaps API."""
2
+
3
+ import os
4
+ import tempfile
5
+ from datetime import datetime
6
+ from typing import Any
7
+
8
+ import rasterio
9
+ import requests
10
+ import shapely
11
+ from upath import UPath
12
+
13
+ from rslearn.config import LayerConfig, QueryConfig, RasterLayerConfig
14
+ from rslearn.const import WGS84_PROJECTION
15
+ from rslearn.data_sources import DataSource, Item
16
+ from rslearn.data_sources.utils import match_candidate_items_to_window
17
+ from rslearn.log_utils import get_logger
18
+ from rslearn.tile_stores import PrefixedTileStore, TileStore
19
+ from rslearn.utils import STGeometry
20
+
21
+ from .raster_source import get_needed_projections, ingest_raster
22
+
23
+ logger = get_logger(__name__)
24
+
25
+
26
class PlanetItem(Item):
    """An item referencing a particular mosaic and quad in Basemaps API."""

    def __init__(self, name: str, geometry: STGeometry, mosaic_id: str, quad_id: str):
        """Create a new PlanetItem.

        Args:
            name: the item name (combination of mosaic and quad ID).
            geometry: the geometry associated with this quad.
            mosaic_id: the mosaic ID in API
            quad_id: the quad ID in API
        """
        super().__init__(name, geometry)
        self.mosaic_id = mosaic_id
        self.quad_id = quad_id

    def serialize(self) -> dict:
        """Serializes the item to a JSON-encodable dictionary."""
        serialized = super().serialize()
        serialized.update(mosaic_id=self.mosaic_id, quad_id=self.quad_id)
        return serialized

    @staticmethod
    def deserialize(d: dict) -> Item:
        """Deserializes an item from a JSON-decoded dictionary."""
        # Let the base class recover the name and geometry, then attach the
        # mosaic/quad identifiers stored by serialize().
        base_item = super(PlanetItem, PlanetItem).deserialize(d)
        return PlanetItem(
            base_item.name, base_item.geometry, d["mosaic_id"], d["quad_id"]
        )
59
+
60
+
61
class ApiError(Exception):
    """An error from Planet Labs API."""
65
+
66
+
67
class PlanetBasemap(DataSource):
    """A data source for Planet Labs Basemaps API."""

    api_url = "https://api.planet.com/basemaps/v1/"

    # Request timeout in seconds. Set very high to start, matching the other
    # data sources in this package that pass timeout=self.TIMEOUT.
    TIMEOUT = 1000000

    def __init__(
        self,
        config: RasterLayerConfig,
        series_id: str,
        bands: list[str],
        api_key: str | None = None,
    ):
        """Initialize a new Planet instance.

        Args:
            config: the LayerConfig of the layer containing this data source
            series_id: the series of mosaics to use.
            bands: list of band names to use.
            api_key: optional Planet API key (it can also be provided via PL_API_KEY
                environment variable).
        """
        self.config = config
        self.bands = bands

        self.session = requests.Session()
        if api_key is None:
            api_key = os.environ["PL_API_KEY"]
        self.session.auth = (api_key, "")

        # List the mosaics in the series up front, recording each mosaic's
        # WGS84 bounding box and acquisition time range so get_items can match
        # window geometries against them without extra API calls.
        self.mosaics = {}
        for mosaic_dict in self._api_get_paginate(
            path=f"series/{series_id}/mosaics", list_key="mosaics"
        ):
            shp = shapely.box(*mosaic_dict["bbox"])
            time_range = (
                datetime.fromisoformat(mosaic_dict["first_acquired"]),
                datetime.fromisoformat(mosaic_dict["last_acquired"]),
            )
            geom = STGeometry(WGS84_PROJECTION, shp, time_range)
            self.mosaics[mosaic_dict["id"]] = geom

    @staticmethod
    def from_config(config: LayerConfig, ds_path: UPath) -> "PlanetBasemap":
        """Creates a new PlanetBasemap instance from a configuration dictionary."""
        assert isinstance(config, RasterLayerConfig)
        if config.data_source is None:
            raise ValueError("data_source is required")
        d = config.data_source.config_dict
        kwargs = dict(
            config=config,
            series_id=d["series_id"],
            bands=d["bands"],
        )
        optional_keys = [
            "api_key",
        ]
        for optional_key in optional_keys:
            if optional_key in d:
                kwargs[optional_key] = d[optional_key]
        return PlanetBasemap(**kwargs)

    def _api_get(
        self,
        path: str | None = None,
        url: str | None = None,
        query_args: dict[str, str] | None = None,
    ) -> list[Any] | dict[str, Any]:
        """Perform a GET request on the API.

        Args:
            path: the path to GET, like "series".
            url: the full URL to GET. Exactly one of path or url should be set.
            query_args: optional params to include with the request.

        Returns:
            the JSON response data.

        Raises:
            ValueError: if not exactly one of path and url is set.
            ApiError: if the API returned an error response.
        """
        # Enforce that exactly one of path/url is provided. (The previous
        # check only caught the both-None case, and with a message that
        # described the opposite condition.)
        if (path is None) == (url is None):
            raise ValueError("exactly one of path or url must be set")
        if path is not None:
            url = self.api_url + path

        kwargs: dict[str, Any] = {"timeout": self.TIMEOUT}
        if query_args:
            kwargs["params"] = query_args
        response = self.session.get(url, **kwargs)  # type: ignore

        if response.status_code != 200:
            raise ApiError(
                f"{url}: got status code {response.status_code}: {response.text}"
            )
        return response.json()

    def _api_get_paginate(
        self, path: str, list_key: str, query_args: dict[str, str] | None = None
    ) -> list:
        """Get all items in a paginated response.

        Args:
            path: the path to GET.
            list_key: the key in the response containing the list that should be
                concatenated across all available pages.
            query_args: optional params to include with the requests.

        Returns:
            the concatenated list of items.

        Raises:
            ApiError: if the API returned an error response or a response that
                is not a JSON object.
        """
        next_url = self.api_url + path
        items: list = []
        while True:
            json_data = self._api_get(url=next_url, query_args=query_args)
            if not isinstance(json_data, dict):
                # Re-requesting the same URL would loop forever, so a non-dict
                # response is a hard error rather than something to skip.
                raise ApiError(f"{next_url}: expected dict, got {type(json_data)}")
            items += json_data[list_key]

            if "_next" in json_data["_links"]:
                next_url = json_data["_links"]["_next"]
            else:
                return items

    def get_items(
        self, geometries: list[STGeometry], query_config: QueryConfig
    ) -> list[list[list[PlanetItem]]]:
        """Get a list of items in the data source intersecting the given geometries.

        Args:
            geometries: the spatiotemporal geometries
            query_config: the query configuration

        Returns:
            List of groups of items that should be retrieved for each geometry.
        """
        groups = []
        for geometry in geometries:
            geom_bbox = geometry.to_projection(WGS84_PROJECTION).shp.bounds
            geom_bbox_str = ",".join([str(value) for value in geom_bbox])

            # Find the relevant mosaics that the geometry intersects.
            # For each relevant mosaic, identify the intersecting quads.
            items = []
            for mosaic_id, mosaic_geom in self.mosaics.items():
                if not geometry.intersects(mosaic_geom):
                    continue
                logger.info(f"found mosaic {mosaic_geom} for geom {geometry}")
                # List all quads that intersect the current geometry's
                # longitude/latitude bbox in this mosaic.
                for quad_dict in self._api_get_paginate(
                    path=f"mosaics/{mosaic_id}/quads",
                    list_key="items",
                    query_args={"bbox": geom_bbox_str},
                ):
                    logger.info(f"found quad {quad_dict}")
                    shp = shapely.box(*quad_dict["bbox"])
                    geom = STGeometry(WGS84_PROJECTION, shp, mosaic_geom.time_range)
                    quad_id = quad_dict["id"]
                    items.append(
                        PlanetItem(f"{mosaic_id}_{quad_id}", geom, mosaic_id, quad_id)
                    )
            logger.info(f"found {len(items)} items for geom {geometry}")
            cur_groups = match_candidate_items_to_window(geometry, items, query_config)
            groups.append(cur_groups)

        return groups

    def deserialize_item(self, serialized_item: Any) -> Item:
        """Deserializes an item from JSON-decoded data."""
        assert isinstance(serialized_item, dict)
        return PlanetItem.deserialize(serialized_item)

    def ingest(
        self,
        tile_store: TileStore,
        items: list[Item],
        geometries: list[list[STGeometry]],
    ) -> None:
        """Ingest items into the given tile store.

        Args:
            tile_store: the tile store to ingest into
            items: the items to ingest
            geometries: a list of geometries needed for each item
        """
        for item, cur_geometries in zip(items, geometries):
            # Determine which projections still need ingesting before making
            # any network requests; skip the item entirely if none do.
            band_names = self.bands
            cur_tile_store = PrefixedTileStore(
                tile_store, (item.name, "_".join(band_names))
            )
            needed_projections = get_needed_projections(
                cur_tile_store, band_names, self.config.band_sets, cur_geometries
            )
            if not needed_projections:
                continue

            assert isinstance(item, PlanetItem)
            download_url = (
                self.api_url + f"mosaics/{item.mosaic_id}/quads/{item.quad_id}/full"
            )
            response = self.session.get(
                download_url, allow_redirects=True, stream=True, timeout=self.TIMEOUT
            )
            if response.status_code != 200:
                raise ApiError(
                    f"{download_url}: got status code {response.status_code}: {response.text}"
                )

            # Stream the quad GeoTIFF to a temporary file, then ingest it for
            # each needed projection. (A single temporary directory suffices;
            # previously an unused outer one was created and then shadowed.)
            with tempfile.TemporaryDirectory() as tmp_dir:
                local_fname = os.path.join(tmp_dir, "temp.tif")
                with open(local_fname, "wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)

                with rasterio.open(local_fname) as raster:
                    for projection in needed_projections:
                        ingest_raster(
                            tile_store=cur_tile_store,
                            raster=raster,
                            projection=projection,
                            time_range=item.geometry.time_range,
                            layer_config=self.config,
                        )
@@ -15,10 +15,13 @@ from rasterio.crs import CRS
15
15
  from rslearn.config import BandSetConfig, RasterFormatConfig, RasterLayerConfig
16
16
  from rslearn.const import TILE_SIZE
17
17
  from rslearn.dataset import Window
18
+ from rslearn.log_utils import get_logger
18
19
  from rslearn.tile_stores import LayerMetadata, TileStore
19
20
  from rslearn.utils import Projection, STGeometry
20
21
  from rslearn.utils.raster_format import load_raster_format
21
22
 
23
+ logger = get_logger(__name__)
24
+
22
25
 
23
26
  class ArrayWithTransform:
24
27
  """Stores an array along with the transform associated with the array."""
@@ -70,7 +73,7 @@ class ArrayWithTransform:
70
73
  """
71
74
  return self.array
72
75
 
73
- def close(self):
76
+ def close(self) -> None:
74
77
  """This is to mimic the rasterio.DatasetReader API.
75
78
 
76
79
  The close function is a no-op.
@@ -144,23 +147,25 @@ def get_needed_projections(
144
147
  list of Projection objects for which the item has not been ingested yet
145
148
  """
146
149
  # Identify which band set configs are relevant to this raster.
147
- raster_bands = set(raster_bands)
148
- relevant_band_sets = []
150
+ raster_bands_set = set(raster_bands)
151
+ relevant_band_set_list = []
149
152
  for band_set in band_sets:
150
153
  is_match = False
154
+ if band_set.bands is None:
155
+ continue
151
156
  for band in band_set.bands:
152
- if band not in raster_bands:
157
+ if band not in raster_bands_set:
153
158
  continue
154
159
  is_match = True
155
160
  break
156
161
  if not is_match:
157
162
  continue
158
- relevant_band_sets.append(band_set)
163
+ relevant_band_set_list.append(band_set)
159
164
 
160
165
  all_projections = {geometry.projection for geometry in geometries}
161
166
  needed_projections = []
162
167
  for projection in all_projections:
163
- for band_set in relevant_band_sets:
168
+ for band_set in relevant_band_set_list:
164
169
  final_projection, _ = band_set.get_final_projection_and_bounds(
165
170
  projection, None
166
171
  )
@@ -216,16 +221,17 @@ def ingest_raster(
216
221
  else:
217
222
  # Compute the suggested target transform.
218
223
  # rasterio negates the y resolution itself so here we have to negate it.
224
+ raster_bounds: rasterio.coords.BoundingBox = raster.bounds
219
225
  (dst_transform, dst_width, dst_height) = (
220
226
  rasterio.warp.calculate_default_transform(
221
227
  # Source info.
222
228
  src_crs=raster.crs,
223
229
  width=raster.width,
224
230
  height=raster.height,
225
- left=raster.bounds[0],
226
- bottom=raster.bounds[1],
227
- right=raster.bounds[2],
228
- top=raster.bounds[3],
231
+ left=raster_bounds.left,
232
+ bottom=raster_bounds.bottom,
233
+ right=raster_bounds.right,
234
+ top=raster_bounds.top,
229
235
  # Destination info.
230
236
  dst_crs=projection.crs,
231
237
  resolution=(projection.x_resolution, -projection.y_resolution),
@@ -258,7 +264,7 @@ def materialize_raster(
258
264
  window: Window,
259
265
  layer_name: str,
260
266
  band_cfg: BandSetConfig,
261
- ):
267
+ ) -> None:
262
268
  """Materialize a given raster for a window.
263
269
 
264
270
  Currently it is only supported for materializing one band set.
@@ -272,7 +278,8 @@ def materialize_raster(
272
278
  window_projection, window_bounds = band_cfg.get_final_projection_and_bounds(
273
279
  window.projection, window.bounds
274
280
  )
275
-
281
+ if window_bounds is None:
282
+ raise ValueError(f"No windowbounds specified for {layer_name}")
276
283
  # Re-project to just extract the window.
277
284
  array = raster.read()
278
285
  window_width = window_bounds[2] - window_bounds[0]
@@ -297,7 +304,10 @@ def materialize_raster(
297
304
  dst_transform=dst_transform,
298
305
  resampling=rasterio.enums.Resampling.bilinear,
299
306
  )
300
-
307
+ if band_cfg.bands is None or band_cfg.format is None:
308
+ raise ValueError(
309
+ f"No bands or format specified for {layer_name} materialization"
310
+ )
301
311
  # Write the array to layer directory.
302
312
  layer_dir = window.path / "layers" / layer_name
303
313
  out_dir = layer_dir / "_".join(band_cfg.bands)
@@ -1,4 +1,7 @@
1
- """Data source for Landsat data from USGS M2M API."""
1
+ """Data source for Landsat data from USGS M2M API.
2
+
3
+ # TODO: Handle the requests in a helper function for none checking
4
+ """
2
5
 
3
6
  import io
4
7
  import json
@@ -15,7 +18,7 @@ import requests
15
18
  import shapely
16
19
  from upath import UPath
17
20
 
18
- from rslearn.config import LayerConfig, QueryConfig, RasterLayerConfig
21
+ from rslearn.config import QueryConfig, RasterLayerConfig
19
22
  from rslearn.const import WGS84_PROJECTION
20
23
  from rslearn.data_sources import DataSource, Item
21
24
  from rslearn.data_sources.utils import match_candidate_items_to_window
@@ -36,8 +39,9 @@ class M2MAPIClient:
36
39
 
37
40
  api_url = "https://m2m.cr.usgs.gov/api/api/json/stable/"
38
41
  pagination_size = 1000
42
+ TIMEOUT = 1000000 # Set very high to start
39
43
 
40
- def __init__(self, username, password):
44
+ def __init__(self, username: str, password: str) -> None:
41
45
  """Initialize a new M2MAPIClient.
42
46
 
43
47
  Args:
@@ -47,7 +51,9 @@ class M2MAPIClient:
47
51
  self.username = username
48
52
  self.password = password
49
53
  json_data = json.dumps({"username": self.username, "password": self.password})
50
- response = requests.post(self.api_url + "login", data=json_data)
54
+ response = requests.post(
55
+ self.api_url + "login", data=json_data, timeout=self.TIMEOUT
56
+ )
51
57
  response.raise_for_status()
52
58
  self.auth_token = response.json()["data"]
53
59
 
@@ -67,24 +73,26 @@ class M2MAPIClient:
67
73
  self.api_url + endpoint,
68
74
  headers={"X-Auth-Token": self.auth_token},
69
75
  data=json.dumps(data),
76
+ timeout=self.TIMEOUT,
70
77
  )
71
78
  response.raise_for_status()
72
79
  if response.text:
73
- data = response.json()
74
- if data["errorMessage"]:
75
- raise APIException(data["errorMessage"])
76
- return data
80
+ response_dict = response.json()
81
+
82
+ if response_dict["errorMessage"]:
83
+ raise APIException(response_dict["errorMessage"])
84
+ return response_dict
77
85
  return None
78
86
 
79
- def close(self):
87
+ def close(self) -> None:
80
88
  """Logout from the API."""
81
89
  self.request("logout")
82
90
 
83
- def __enter__(self):
91
+ def __enter__(self) -> "M2MAPIClient":
84
92
  """Enter function to provide with semantics."""
85
93
  return self
86
94
 
87
- def __exit__(self):
95
+ def __exit__(self) -> None:
88
96
  """Exit function to provide with semantics.
89
97
 
90
98
  Logs out the API.
@@ -100,7 +108,10 @@ class M2MAPIClient:
100
108
  Returns:
101
109
  list of filter objects
102
110
  """
103
- return self.request("dataset-filters", {"datasetName": dataset_name})["data"]
111
+ response_dict = self.request("dataset-filters", {"datasetName": dataset_name})
112
+ if response_dict is None:
113
+ raise APIException("No response from API")
114
+ return response_dict["data"]
104
115
 
105
116
  def scene_search(
106
117
  self,
@@ -119,7 +130,7 @@ class M2MAPIClient:
119
130
  bbox: optional spatial filter
120
131
  metadata_filter: optional metadata filter dict
121
132
  """
122
- base_data = {"datasetName": dataset_name, "sceneFilter": {}}
133
+ base_data: dict[str, Any] = {"datasetName": dataset_name, "sceneFilter": {}}
123
134
  if acquisition_time_range:
124
135
  base_data["sceneFilter"]["acquisitionFilter"] = {
125
136
  "start": acquisition_time_range[0].isoformat(),
@@ -146,7 +157,10 @@ class M2MAPIClient:
146
157
  cur_data = base_data.copy()
147
158
  cur_data["startingNumber"] = starting_number
148
159
  cur_data["maxResults"] = self.pagination_size
149
- data = self.request("scene-search", cur_data)["data"]
160
+ response_dict = self.request("scene-search", cur_data)
161
+ if response_dict is None:
162
+ raise APIException("No response from API")
163
+ data = response_dict["data"]
150
164
  results.extend(data["results"])
151
165
  if data["recordsReturned"] < self.pagination_size:
152
166
  break
@@ -164,14 +178,17 @@ class M2MAPIClient:
164
178
  Returns:
165
179
  full scene metadata
166
180
  """
167
- return self.request(
181
+ response_dict = self.request(
168
182
  "scene-metadata",
169
183
  {
170
184
  "datasetName": dataset_name,
171
185
  "entityId": entity_id,
172
186
  "metadataType": "full",
173
187
  },
174
- )["data"]
188
+ )
189
+ if response_dict is None:
190
+ raise APIException("No response from API")
191
+ return response_dict["data"]
175
192
 
176
193
  def get_downloadable_products(
177
194
  self, dataset_name: str, entity_id: str
@@ -186,7 +203,10 @@ class M2MAPIClient:
186
203
  list of downloadable products
187
204
  """
188
205
  data = {"datasetName": dataset_name, "entityIds": [entity_id]}
189
- return self.request("download-options", data)["data"]
206
+ response_dict = self.request("download-options", data)
207
+ if response_dict is None:
208
+ raise APIException("No response from API")
209
+ return response_dict["data"]
190
210
 
191
211
  def get_download_url(self, entity_id: str, product_id: str) -> str:
192
212
  """Get the download URL for a given product.
@@ -204,9 +224,15 @@ class M2MAPIClient:
204
224
  {"label": label, "entityId": entity_id, "productId": product_id}
205
225
  ]
206
226
  }
207
- response = self.request("download-request", data)["data"]
227
+ response_dict = self.request("download-request", data)
228
+ if response_dict is None:
229
+ raise APIException("No response from API")
230
+ response = response_dict["data"]
208
231
  while True:
209
- response = self.request("download-retrieve", {"label": label})["data"]
232
+ response_dict = self.request("download-retrieve", {"label": label})
233
+ if response_dict is None:
234
+ raise APIException("No response from API")
235
+ response = response_dict["data"]
210
236
  if len(response["available"]) > 0:
211
237
  return response["available"][0]["url"]
212
238
  if len(response["requested"]) == 0:
@@ -264,7 +290,7 @@ class LandsatOliTirs(DataSource):
264
290
 
265
291
  def __init__(
266
292
  self,
267
- config: LayerConfig,
293
+ config: RasterLayerConfig,
268
294
  username: str,
269
295
  password: str,
270
296
  max_time_delta: timedelta = timedelta(days=30),
@@ -289,9 +315,10 @@ class LandsatOliTirs(DataSource):
289
315
  self.client = M2MAPIClient(username, password)
290
316
 
291
317
  @staticmethod
292
- def from_config(config: LayerConfig, ds_path: UPath) -> "LandsatOliTirs":
318
+ def from_config(config: RasterLayerConfig, ds_path: UPath) -> "LandsatOliTirs":
293
319
  """Creates a new LandsatOliTirs instance from a configuration dictionary."""
294
- assert isinstance(config, RasterLayerConfig)
320
+ if config.data_source is None:
321
+ raise ValueError("data_source is required")
295
322
  d = config.data_source.config_dict
296
323
  if "max_time_delta" in d:
297
324
  max_time_delta = timedelta(seconds=pytimeparse.parse(d["max_time_delta"]))
@@ -328,7 +355,7 @@ class LandsatOliTirs(DataSource):
328
355
 
329
356
  def get_items(
330
357
  self, geometries: list[STGeometry], query_config: QueryConfig
331
- ) -> list[list[list[Item]]]:
358
+ ) -> list[list[list[LandsatOliTirsItem]]]:
332
359
  """Get a list of items in the data source intersecting the given geometries.
333
360
 
334
361
  Args:
@@ -400,7 +427,7 @@ class LandsatOliTirs(DataSource):
400
427
  assert isinstance(serialized_item, dict)
401
428
  return LandsatOliTirsItem.deserialize(serialized_item)
402
429
 
403
- def _get_download_urls(self, item: Item) -> dict[str, str]:
430
+ def _get_download_urls(self, item: Item) -> dict[str, tuple[str, str]]:
404
431
  """Gets the download URLs for each band.
405
432
 
406
433
  Args:
@@ -438,7 +465,7 @@ class LandsatOliTirs(DataSource):
438
465
  download_urls = self._get_download_urls(item)
439
466
  for _, (display_id, download_url) in download_urls.items():
440
467
  buf = io.BytesIO()
441
- with requests.get(download_url, stream=True) as r:
468
+ with requests.get(download_url, stream=True, timeout=self.TIMEOUT) as r:
442
469
  r.raise_for_status()
443
470
  shutil.copyfileobj(r.raw, buf)
444
471
  buf.seek(0)
@@ -447,7 +474,7 @@ class LandsatOliTirs(DataSource):
447
474
  def ingest(
448
475
  self,
449
476
  tile_store: TileStore,
450
- items: list[Item],
477
+ items: list[LandsatOliTirsItem],
451
478
  geometries: list[list[STGeometry]],
452
479
  ) -> None:
453
480
  """Ingest items into the given tile store.
@@ -471,7 +498,9 @@ class LandsatOliTirs(DataSource):
471
498
  continue
472
499
 
473
500
  buf = io.BytesIO()
474
- with requests.get(download_urls[band][1], stream=True) as r:
501
+ with requests.get(
502
+ download_urls[band][1], stream=True, timeout=self.TIMEOUT
503
+ ) as r:
475
504
  r.raise_for_status()
476
505
  shutil.copyfileobj(r.raw, buf)
477
506
  buf.seek(0)
@@ -1,5 +1,7 @@
1
1
  """Utilities shared by data sources."""
2
2
 
3
+ from typing import TypeVar
4
+
3
5
  from rslearn.config import QueryConfig, SpaceMode, TimeMode
4
6
  from rslearn.data_sources import Item
5
7
  from rslearn.utils import STGeometry, shp_intersects
@@ -11,10 +13,12 @@ MOSAIC_REMAINDER_EPSILON = 0.01
11
13
  """Fraction of original geometry area below which mosaic is considered to contain the
12
14
  entire geometry."""
13
15
 
16
+ ItemType = TypeVar("ItemType", bound=Item)
17
+
14
18
 
15
19
  def match_candidate_items_to_window(
16
- geometry: STGeometry, items: list[Item], query_config: QueryConfig
17
- ) -> list[list[Item]]:
20
+ geometry: STGeometry, items: list[ItemType], query_config: QueryConfig
21
+ ) -> list[list[ItemType]]:
18
22
  """Match candidate items to a window based on the query configuration.
19
23
 
20
24
  Candidate items should be collected that intersect with the window's spatial
@@ -45,17 +49,20 @@ def match_candidate_items_to_window(
45
49
  items = [
46
50
  item
47
51
  for item in items
48
- if not item.time_range or item.range[1] <= geometry.time_range[0]
52
+ if not item.geometry.time_range
53
+ or item.geometry.time_range[1] <= geometry.time_range[0]
49
54
  ]
50
55
  elif query_config.time_mode == TimeMode.AFTER:
51
56
  items = [
52
57
  item
53
58
  for item in items
54
- if not item.time_range
55
- or item.time_range[0] >= geometry.time_range[1]
59
+ if not item.geometry.time_range
60
+ or item.geometry.time_range[0] >= geometry.time_range[1]
56
61
  ]
57
62
  items.sort(
58
- key=lambda item: geometry.distance_to_time_range(item.time_range)
63
+ key=lambda item: geometry.distance_to_time_range(
64
+ item.geometry.time_range
65
+ )
59
66
  )
60
67
 
61
68
  # Now apply space mode.
@@ -0,0 +1 @@
1
+ """Placeholder for a vector data source."""