rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/data_sources/usda_cdl.py
@@ -0,0 +1,195 @@
+"""Crop type data from the USDA Cropland Data Layer."""
+
+import os
+import tempfile
+import zipfile
+from datetime import UTC, datetime, timedelta
+from typing import Any
+
+import requests
+import requests.auth
+import shapely
+from upath import UPath
+
+from rslearn.config import QueryConfig
+from rslearn.const import WGS84_PROJECTION
+from rslearn.data_sources import DataSource, DataSourceContext, Item
+from rslearn.data_sources.utils import match_candidate_items_to_window
+from rslearn.log_utils import get_logger
+from rslearn.tile_stores import TileStoreWithLayer
+from rslearn.utils.geometry import STGeometry
+
+logger = get_logger(__name__)
+
+
+class CDL(DataSource):
+    """Data source for crop type data from the USDA Cropland Data Layer.
+
+    See https://www.nass.usda.gov/Research_and_Science/Cropland/SARS1a.php for details
+    about the data.
+
+    There is one GeoTIFF item per year from 2008. Each GeoTIFF spans the entire
+    continental US, and has a single band.
+    """
+
+    BASE_URL = (
+        "https://www.nass.usda.gov/Research_and_Science/Cropland/Release/datasets/"
+    )
+    ZIP_FILENAMES = {
+        2024: "2024_30m_cdls.zip",
+        2023: "2023_30m_cdls.zip",
+        2022: "2022_30m_cdls.zip",
+        2021: "2021_30m_cdls.zip",
+        2020: "2020_30m_cdls.zip",
+        2019: "2019_30m_cdls.zip",
+        2018: "2018_30m_cdls.zip",
+        2017: "2017_30m_cdls.zip",
+        2016: "2016_30m_cdls.zip",
+        2015: "2015_30m_cdls.zip",
+        2014: "2014_30m_cdls.zip",
+        2013: "2013_30m_cdls.zip",
+        2012: "2012_30m_cdls.zip",
+        2011: "2011_30m_cdls.zip",
+        2010: "2010_30m_cdls.zip",
+        2009: "2009_30m_cdls.zip",
+        2008: "2008_30m_cdls.zip",
+    }
+
+    # The bounds of each GeoTIFF in WGS84 coordinates, based on the 2023 map.
+    BOUNDS = shapely.box(-127.9, 23.0, -65.3, 48.3)
+
+    def __init__(
+        self,
+        timeout: timedelta = timedelta(seconds=10),
+        context: DataSourceContext = DataSourceContext(),
+    ):
+        """Initialize a new CDL instance.
+
+        Args:
+            timeout: timeout for requests.
+            context: the data source context.
+        """
+        self.timeout = timeout
+
+        # Get the band name from the layer config, which should have a single band set
+        # with a single band. If the layer config is not available in the context, we
+        # default to "cdl".
+        if context.layer_config is not None:
+            if len(context.layer_config.band_sets) != 1:
+                raise ValueError("expected a single band set")
+            if len(context.layer_config.band_sets[0].bands) != 1:
+                raise ValueError("expected band set to have a single band")
+            self.band_name = context.layer_config.band_sets[0].bands[0]
+        else:
+            self.band_name = "cdl"
+
+    def get_item_by_name(self, name: str) -> Item:
+        """Gets an item by name.
+
+        Args:
+            name: the name of the item to get. For CDL, the item name is the filename
+                of the zip file containing the per-year GeoTIFF.
+
+        Returns:
+            the Item object
+        """
+        year = int(name[0:4])
+        geometry = STGeometry(
+            WGS84_PROJECTION,
+            self.BOUNDS,
+            (
+                datetime(year, 1, 1, tzinfo=UTC),
+                datetime(year + 1, 1, 1, tzinfo=UTC),
+            ),
+        )
+        return Item(name, geometry)
+
+    def get_items(
+        self, geometries: list[STGeometry], query_config: QueryConfig
+    ) -> list[list[list[Item]]]:
+        """Get a list of items in the data source intersecting the given geometries.
+
+        Args:
+            geometries: the spatiotemporal geometries
+            query_config: the query configuration
+
+        Returns:
+            List of groups of items that should be retrieved for each geometry.
+        """
+        # First enumerate all items.
+        # Then we simply pass this to match_candidate_items_to_window.
+        items: list[Item] = []
+        for year, fname in self.ZIP_FILENAMES.items():
+            geometry = STGeometry(
+                WGS84_PROJECTION,
+                self.BOUNDS,
+                (
+                    datetime(year, 1, 1, tzinfo=UTC),
+                    datetime(year + 1, 1, 1, tzinfo=UTC),
+                ),
+            )
+            items.append(Item(fname, geometry))
+
+        groups = []
+        for geometry in geometries:
+            cur_groups = match_candidate_items_to_window(geometry, items, query_config)
+            groups.append(cur_groups)
+
+        return groups
+
+    def deserialize_item(self, serialized_item: Any) -> Item:
+        """Deserializes an item from JSON-decoded data."""
+        assert isinstance(serialized_item, dict)
+        return Item.deserialize(serialized_item)
+
+    def ingest(
+        self,
+        tile_store: TileStoreWithLayer,
+        items: list[Item],
+        geometries: list[list[STGeometry]],
+    ) -> None:
+        """Ingest items into the given tile store.
+
+        Args:
+            tile_store: the tile store to ingest into
+            items: the items to ingest
+            geometries: a list of geometries needed for each item
+        """
+        for item in items:
+            if tile_store.is_raster_ready(item.name, [self.band_name]):
+                continue
+
+            # Download the zip file.
+            url = self.BASE_URL + item.name
+            logger.debug(f"Downloading CDL GeoTIFF from {url}")
+            response = requests.get(
+                url, stream=True, timeout=self.timeout.total_seconds()
+            )
+            response.raise_for_status()
+
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                # Store it in temporary directory.
+                zip_fname = os.path.join(tmp_dir, "data.zip")
+                with open(zip_fname, "wb") as f:
+                    for chunk in response.iter_content(chunk_size=8192):
+                        f.write(chunk)
+
+                # Extract the .tif file.
+                logger.debug(f"Extracting GeoTIFF from {item.name}")
+                with zipfile.ZipFile(zip_fname) as zip_f:
+                    candidate_member_names = [
+                        member_name
+                        for member_name in zip_f.namelist()
+                        if member_name.endswith(".tif")
+                    ]
+                    if len(candidate_member_names) != 1:
+                        raise ValueError(
+                            f"expected CDL zip to have one .tif file but got {candidate_member_names}"
+                        )
+                    local_fname = zip_f.extract(candidate_member_names[0], path=tmp_dir)
+
+                # Now we can ingest it.
+                logger.debug(f"Ingesting data for {item.name}")
+                tile_store.write_raster_file(
+                    item.name, [self.band_name], UPath(local_fname)
+                )
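
A minimal usage sketch of the new CDL data source (hypothetical, not part of the diff): item names are the per-year zip filenames, and get_item_by_name() derives the year from the first four characters of the name.

    cdl = CDL()
    item = cdl.get_item_by_name("2023_30m_cdls.zip")
    # The item covers the continental-US bounds with a time range of
    # 2023-01-01 through 2024-01-01 (UTC).
    print(item.name, item.geometry.time_range)
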
rslearn/data_sources/usgs_landsat.py
@@ -1,29 +1,30 @@
-"""Data source for Landsat data from USGS M2M API."""
+"""Data source for Landsat data from USGS M2M API.
+
+# TODO: Handle the requests in a helper function for none checking
+"""
 
 import io
 import json
+import os
 import shutil
+import tempfile
 import time
 import uuid
 from collections.abc import Generator
-from datetime import datetime, timedelta, timezone
+from datetime import UTC, datetime, timedelta
 from typing import Any, BinaryIO
 
-import pytimeparse
-import rasterio
 import requests
 import shapely
 from upath import UPath
 
-from rslearn.config import LayerConfig, QueryConfig, RasterLayerConfig
+from rslearn.config import QueryConfig
 from rslearn.const import WGS84_PROJECTION
-from rslearn.data_sources import DataSource, Item
+from rslearn.data_sources import DataSource, DataSourceContext, Item
 from rslearn.data_sources.utils import match_candidate_items_to_window
-from rslearn.tile_stores import PrefixedTileStore, TileStore
+from rslearn.tile_stores import TileStoreWithLayer
 from rslearn.utils import STGeometry
 
-from .raster_source import get_needed_projections, ingest_raster
-
 
 class APIException(Exception):
     """Exception raised for M2M API errors."""
@@ -37,17 +38,46 @@ class M2MAPIClient:
     api_url = "https://m2m.cr.usgs.gov/api/api/json/stable/"
     pagination_size = 1000
 
-    def __init__(self, username, password):
+    def __init__(
+        self,
+        username: str,
+        password: str | None = None,
+        token: str | None = None,
+        timeout: timedelta = timedelta(seconds=120),
+    ) -> None:
         """Initialize a new M2MAPIClient.
 
         Args:
             username: the EROS username
             password: the EROS password
+            token: the application token. One of password or token must be specified.
+            timeout: timeout for requests.
         """
         self.username = username
-        self.password = password
-        json_data = json.dumps({"username": self.username, "password": self.password})
-        response = requests.post(self.api_url + "login", data=json_data)
+        self.timeout = timeout
+
+        if password is not None and token is not None:
+            raise ValueError("only one of password or token can be specified")
+
+        if password is not None:
+            json_data = json.dumps({"username": self.username, "password": password})
+            response = requests.post(
+                self.api_url + "login",
+                data=json_data,
+                timeout=self.timeout.total_seconds(),
+            )
+
+        elif token is not None:
+            json_data = json.dumps({"username": username, "token": token})
+            response = requests.post(
+                self.api_url + "login-token",
+                data=json_data,
+                timeout=self.timeout.total_seconds(),
+            )
+
+        else:
+            raise ValueError("one of password or token must be specified")
+
         response.raise_for_status()
         self.auth_token = response.json()["data"]
 
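
M2MAPIClient now accepts either an EROS password or an application token, and raises a ValueError if both or neither are given. A minimal sketch with placeholder credentials, using only methods visible in this diff:

    # Token login posts to the "login-token" endpoint; password login to "login".
    client = M2MAPIClient("my_eros_user", token="my_application_token")
    try:
        # Signature per this diff: get_downloadable_products(dataset_name, entity_id).
        # Both arguments below are placeholder values.
        products = client.get_downloadable_products(
            "landsat_ot_c2_l1", "LC81230322024123LGN00"
        )
    finally:
        client.close()  # logs out via the "logout" endpoint
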
@@ -67,24 +97,26 @@ class M2MAPIClient:
             self.api_url + endpoint,
             headers={"X-Auth-Token": self.auth_token},
             data=json.dumps(data),
+            timeout=self.timeout.total_seconds(),
         )
         response.raise_for_status()
         if response.text:
-            data = response.json()
-            if data["errorMessage"]:
-                raise APIException(data["errorMessage"])
-            return data
+            response_dict = response.json()
+
+            if response_dict["errorMessage"]:
+                raise APIException(response_dict["errorMessage"])
+            return response_dict
         return None
 
-    def close(self):
+    def close(self) -> None:
         """Logout from the API."""
         self.request("logout")
 
-    def __enter__(self):
+    def __enter__(self) -> "M2MAPIClient":
         """Enter function to provide with semantics."""
         return self
 
-    def __exit__(self):
+    def __exit__(self) -> None:
         """Exit function to provide with semantics.
 
         Logs out the API.
@@ -100,7 +132,10 @@ class M2MAPIClient:
         Returns:
             list of filter objects
         """
-        return self.request("dataset-filters", {"datasetName": dataset_name})["data"]
+        response_dict = self.request("dataset-filters", {"datasetName": dataset_name})
+        if response_dict is None:
+            raise APIException("No response from API")
+        return response_dict["data"]
 
     def scene_search(
         self,
@@ -119,7 +154,7 @@
             bbox: optional spatial filter
             metadata_filter: optional metadata filter dict
         """
-        base_data = {"datasetName": dataset_name, "sceneFilter": {}}
+        base_data: dict[str, Any] = {"datasetName": dataset_name, "sceneFilter": {}}
         if acquisition_time_range:
             base_data["sceneFilter"]["acquisitionFilter"] = {
                 "start": acquisition_time_range[0].isoformat(),
@@ -146,7 +181,10 @@
             cur_data = base_data.copy()
             cur_data["startingNumber"] = starting_number
             cur_data["maxResults"] = self.pagination_size
-            data = self.request("scene-search", cur_data)["data"]
+            response_dict = self.request("scene-search", cur_data)
+            if response_dict is None:
+                raise APIException("No response from API")
+            data = response_dict["data"]
             results.extend(data["results"])
             if data["recordsReturned"] < self.pagination_size:
                 break
@@ -164,14 +202,17 @@
         Returns:
             full scene metadata
         """
-        return self.request(
+        response_dict = self.request(
             "scene-metadata",
             {
                 "datasetName": dataset_name,
                 "entityId": entity_id,
                 "metadataType": "full",
             },
-        )["data"]
+        )
+        if response_dict is None:
+            raise APIException("No response from API")
+        return response_dict["data"]
 
     def get_downloadable_products(
         self, dataset_name: str, entity_id: str
@@ -186,7 +227,10 @@
             list of downloadable products
         """
         data = {"datasetName": dataset_name, "entityIds": [entity_id]}
-        return self.request("download-options", data)["data"]
+        response_dict = self.request("download-options", data)
+        if response_dict is None:
+            raise APIException("No response from API")
+        return response_dict["data"]
 
     def get_download_url(self, entity_id: str, product_id: str) -> str:
         """Get the download URL for a given product.
@@ -204,9 +248,15 @@
                 {"label": label, "entityId": entity_id, "productId": product_id}
             ]
         }
-        response = self.request("download-request", data)["data"]
+        response_dict = self.request("download-request", data)
+        if response_dict is None:
+            raise APIException("No response from API")
+        response = response_dict["data"]
         while True:
-            response = self.request("download-retrieve", {"label": label})["data"]
+            response_dict = self.request("download-retrieve", {"label": label})
+            if response_dict is None:
+                raise APIException("No response from API")
+            response = response_dict["data"]
             if len(response["available"]) > 0:
                 return response["available"][0]["url"]
             if len(response["requested"]) == 0:
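
The repeated "if response_dict is None" guards above are what the TODO in the new module docstring refers to. One possible shape for that helper, sketched as an assumption rather than released code:

    # Hypothetical M2MAPIClient method (not in this release): centralize the
    # None check so call sites reduce to a single return self._request_data(...).
    def _request_data(self, endpoint: str, data: dict[str, Any]) -> Any:
        response_dict = self.request(endpoint, data)
        if response_dict is None:
            raise APIException("No response from API")
        return response_dict["data"]
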
@@ -264,45 +314,29 @@ class LandsatOliTirs(DataSource):
 
     def __init__(
         self,
-        config: LayerConfig,
         username: str,
-        password: str,
-        max_time_delta: timedelta = timedelta(days=30),
         sort_by: str | None = None,
+        password: str | None = None,
+        token: str | None = None,
+        timeout: timedelta = timedelta(seconds=10),
+        context: DataSourceContext = DataSourceContext(),
     ):
         """Initialize a new LandsatOliTirs instance.
 
         Args:
-            config: the LayerConfig of the layer containing this data source
             username: EROS username
-            password: EROS password
-            max_time_delta: maximum time before a query start time or after a
-                query end time to look for products. This is required due to the large
-                number of available products, and defaults to 30 days.
             sort_by: can be "cloud_cover", default arbitrary order; only has effect for
                 SpaceMode.WITHIN.
+            password: EROS password (see M2MAPIClient).
+            token: EROS application token (see M2MAPIClient).
+            timeout: timeout for requests.
+            context: the data source context.
         """
-        self.config = config
-        self.max_time_delta = max_time_delta
         self.sort_by = sort_by
+        self.timeout = timeout
 
-        self.client = M2MAPIClient(username, password)
-
-    @staticmethod
-    def from_config(config: LayerConfig, ds_path: UPath) -> "LandsatOliTirs":
-        """Creates a new LandsatOliTirs instance from a configuration dictionary."""
-        assert isinstance(config, RasterLayerConfig)
-        d = config.data_source.config_dict
-        if "max_time_delta" in d:
-            max_time_delta = timedelta(seconds=pytimeparse.parse(d["max_time_delta"]))
-        else:
-            max_time_delta = timedelta(days=30)
-        return LandsatOliTirs(
-            config=config,
-            username=d["username"],
-            password=d["password"],
-            max_time_delta=max_time_delta,
-            sort_by=d.get("sort_by"),
+        self.client = M2MAPIClient(
+            username, password=password, token=token, timeout=timeout
         )
 
     def _scene_metadata_to_item(self, result: dict[str, Any]) -> LandsatOliTirsItem:
@@ -317,7 +351,7 @@ class LandsatOliTirs(DataSource):
             ts = datetime.strptime(metadata_dict["Start Time"], "%Y-%m-%d %H:%M:%S.%f")
         else:
             ts = datetime.strptime(metadata_dict["Start Time"], "%Y-%m-%d %H:%M:%S")
-        ts = ts.replace(tzinfo=timezone.utc)
+        ts = ts.replace(tzinfo=UTC)
 
         return LandsatOliTirsItem(
             name=result["displayId"],
@@ -328,7 +362,7 @@
 
     def get_items(
         self, geometries: list[STGeometry], query_config: QueryConfig
-    ) -> list[list[list[Item]]]:
+    ) -> list[list[list[LandsatOliTirsItem]]]:
         """Get a list of items in the data source intersecting the given geometries.
 
         Args:
@@ -400,7 +434,7 @@
         assert isinstance(serialized_item, dict)
         return LandsatOliTirsItem.deserialize(serialized_item)
 
-    def _get_download_urls(self, item: Item) -> dict[str, str]:
+    def _get_download_urls(self, item: Item) -> dict[str, tuple[str, str]]:
         """Gets the download URLs for each band.
 
         Args:
@@ -438,7 +472,9 @@
         download_urls = self._get_download_urls(item)
         for _, (display_id, download_url) in download_urls.items():
             buf = io.BytesIO()
-            with requests.get(download_url, stream=True) as r:
+            with requests.get(
+                download_url, stream=True, timeout=self.timeout.total_seconds()
+            ) as r:
                 r.raise_for_status()
                 shutil.copyfileobj(r.raw, buf)
                 buf.seek(0)
@@ -446,8 +482,8 @@
 
     def ingest(
         self,
-        tile_store: TileStore,
-        items: list[Item],
+        tile_store: TileStoreWithLayer,
+        items: list[LandsatOliTirsItem],
         geometries: list[list[STGeometry]],
     ) -> None:
         """Ingest items into the given tile store.
@@ -457,30 +493,26 @@
             items: the items to ingest
             geometries: a list of geometries needed for each item
         """
-        for item, cur_geometries in zip(items, geometries):
+        for item in items:
             download_urls = self._get_download_urls(item)
+
             for band in self.bands:
                 band_names = [band]
-                cur_tile_store = PrefixedTileStore(
-                    tile_store, (item.name, "_".join(band_names))
-                )
-                needed_projections = get_needed_projections(
-                    cur_tile_store, band_names, self.config.band_sets, cur_geometries
-                )
-                if not needed_projections:
+                if tile_store.is_raster_ready(item.name, band_names):
                     continue
 
-                buf = io.BytesIO()
-                with requests.get(download_urls[band][1], stream=True) as r:
-                    r.raise_for_status()
-                    shutil.copyfileobj(r.raw, buf)
-                    buf.seek(0)
-                with rasterio.open(buf) as raster:
-                    for projection in needed_projections:
-                        ingest_raster(
-                            tile_store=cur_tile_store,
-                            raster=raster,
-                            projection=projection,
-                            time_range=item.geometry.time_range,
-                            layer_config=self.config,
-                        )
+                with tempfile.TemporaryDirectory() as tmp_dir:
+                    local_filename = os.path.join(tmp_dir, "data.tif")
+
+                    with requests.get(
+                        download_urls[band][1],
+                        stream=True,
+                        timeout=self.timeout.total_seconds(),
+                    ) as r:
+                        r.raise_for_status()
+                        with open(local_filename, "wb") as f:
+                            shutil.copyfileobj(r.raw, f)
+
+                    tile_store.write_raster_file(
+                        item.name, band_names, UPath(local_filename)
+                    )
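
With from_config() removed, LandsatOliTirs is constructed directly under the new signature; a sketch with placeholder values:

    landsat = LandsatOliTirs(
        username="my_eros_user",
        token="my_application_token",  # exactly one of password/token, as in M2MAPIClient
        sort_by="cloud_cover",  # optional; only affects SpaceMode.WITHIN
    )

Note that ingest() now streams each band's GeoTIFF to a temporary file and hands it to TileStoreWithLayer.write_raster_file(), rather than reprojecting in-process via the removed raster_source helpers.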