rslearn 0.0.26__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. rslearn/data_sources/__init__.py +2 -0
  2. rslearn/data_sources/aws_landsat.py +44 -161
  3. rslearn/data_sources/aws_open_data.py +2 -4
  4. rslearn/data_sources/aws_sentinel1.py +1 -3
  5. rslearn/data_sources/aws_sentinel2_element84.py +54 -165
  6. rslearn/data_sources/climate_data_store.py +1 -3
  7. rslearn/data_sources/copernicus.py +1 -2
  8. rslearn/data_sources/data_source.py +1 -1
  9. rslearn/data_sources/direct_materialize_data_source.py +336 -0
  10. rslearn/data_sources/earthdaily.py +52 -155
  11. rslearn/data_sources/earthdatahub.py +425 -0
  12. rslearn/data_sources/eurocrops.py +1 -2
  13. rslearn/data_sources/gcp_public_data.py +1 -2
  14. rslearn/data_sources/google_earth_engine.py +1 -2
  15. rslearn/data_sources/hf_srtm.py +595 -0
  16. rslearn/data_sources/local_files.py +1 -1
  17. rslearn/data_sources/openstreetmap.py +1 -1
  18. rslearn/data_sources/planet.py +1 -2
  19. rslearn/data_sources/planet_basemap.py +1 -2
  20. rslearn/data_sources/planetary_computer.py +183 -186
  21. rslearn/data_sources/soilgrids.py +3 -3
  22. rslearn/data_sources/stac.py +1 -2
  23. rslearn/data_sources/usda_cdl.py +1 -3
  24. rslearn/data_sources/usgs_landsat.py +7 -254
  25. rslearn/data_sources/worldcereal.py +1 -1
  26. rslearn/data_sources/worldcover.py +1 -1
  27. rslearn/data_sources/worldpop.py +1 -1
  28. rslearn/data_sources/xyz_tiles.py +5 -9
  29. rslearn/models/concatenate_features.py +6 -1
  30. rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
  31. rslearn/train/data_module.py +27 -27
  32. rslearn/train/dataset.py +109 -62
  33. rslearn/train/lightning_module.py +1 -1
  34. rslearn/train/model_context.py +3 -3
  35. rslearn/train/prediction_writer.py +69 -41
  36. rslearn/train/tasks/classification.py +1 -1
  37. rslearn/train/tasks/detection.py +5 -5
  38. rslearn/train/tasks/regression.py +1 -1
  39. rslearn/utils/__init__.py +2 -0
  40. rslearn/utils/geometry.py +21 -0
  41. rslearn/utils/m2m_api.py +251 -0
  42. rslearn/utils/retry_session.py +43 -0
  43. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
  44. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/RECORD +49 -45
  45. rslearn/data_sources/earthdata_srtm.py +0 -282
  46. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
  47. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
  48. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
  49. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
  50. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
@@ -1,282 +0,0 @@
1
- """Elevation data from the Shuttle Radar Topography Mission via NASA Earthdata."""
2
-
3
- import math
4
- import os
5
- import tempfile
6
- import zipfile
7
- from datetime import timedelta
8
- from typing import Any
9
-
10
- import requests
11
- import requests.auth
12
- import shapely
13
- from upath import UPath
14
-
15
- from rslearn.config import QueryConfig, SpaceMode
16
- from rslearn.const import WGS84_PROJECTION
17
- from rslearn.data_sources import DataSource, DataSourceContext, Item
18
- from rslearn.log_utils import get_logger
19
- from rslearn.tile_stores import TileStoreWithLayer
20
- from rslearn.utils.geometry import STGeometry
21
-
22
- logger = get_logger(__name__)
23
-
24
-
25
class SRTM(DataSource):
    """Data source for SRTM elevation data using NASA Earthdata credentials.

    See https://e4ftl01.cr.usgs.gov/MEASURES/SRTMGL1.003/ and
    https://dwtkns.com/srtm30m/ for details about the data.

    The data is split into 1x1-degree tiles, where the filename ends with e.g.
    S28W055.SRTMGL1.hgt.zip (so only the first seven characters change). The first
    seven characters are laid out as ``[N|S]dd[E|W]ddd`` and name the lower-left
    (minimum latitude / minimum longitude) corner of the tile.

    These URLs can only be accessed with a NASA Earthdata username and password.

    The zip file contains a single hgt file which can be read by rasterio. It has a
    single 16-bit signed integer band indicating the elevation.

    Items from this data source do not come with a time range. The band name will match
    that specified in the band set, which should have a single band.
    """

    BASE_URL = "https://e4ftl01.cr.usgs.gov/MEASURES/SRTMGL1.003/2000.02.11/"
    FILENAME_SUFFIX = ".SRTMGL1.hgt.zip"

    def __init__(
        self,
        username: str | None = None,
        password: str | None = None,
        timeout: timedelta = timedelta(seconds=10),
        context: DataSourceContext = DataSourceContext(),
    ):
        """Initialize a new SRTM instance.

        Args:
            username: NASA Earthdata account username. If not set, it is read from the
                NASA_EARTHDATA_USERNAME environment variable.
            password: NASA Earthdata account password. If not set, it is read from the
                NASA_EARTHDATA_PASSWORD environment variable.
            timeout: timeout for requests.
            context: the data source context.

        Raises:
            ValueError: if the layer config in the context does not have exactly one
                band set containing exactly one band.
            KeyError: if a credential is not passed and the corresponding environment
                variable is not set.
        """
        # Get band name from context if possible, falling back to "srtm".
        if context.layer_config is not None:
            if len(context.layer_config.band_sets) != 1:
                raise ValueError("expected a single band set")
            if len(context.layer_config.band_sets[0].bands) != 1:
                raise ValueError("expected band set to have a single band")
            self.band_name = context.layer_config.band_sets[0].bands[0]
        else:
            self.band_name = "srtm"

        self.timeout = timeout

        if username is None:
            username = os.environ["NASA_EARTHDATA_USERNAME"]
        self.username = username

        if password is None:
            password = os.environ["NASA_EARTHDATA_PASSWORD"]
        self.password = password

        # One session is shared across downloads so that the cookies set by the
        # Earthdata login redirect are reused for subsequent requests.
        self.session = requests.Session()

    @staticmethod
    def _cell_geometry(lon_min: int, lat_min: int) -> STGeometry:
        """Build the WGS84 geometry (no time range) of one 1x1-degree grid cell.

        Args:
            lon_min: the starting longitude integer of the grid cell.
            lat_min: the starting latitude integer of the grid cell.

        Returns:
            the STGeometry covering [lon_min, lon_min + 1] x [lat_min, lat_min + 1].
        """
        return STGeometry(
            WGS84_PROJECTION,
            shapely.box(lon_min, lat_min, lon_min + 1, lat_min + 1),
            None,
        )

    def get_item_by_name(self, name: str) -> Item:
        """Gets an item by name.

        Args:
            name: the name of the item to get. For SRTM, the item name is the filename
                of the zip file containing the hgt file.

        Returns:
            the Item object

        Raises:
            ValueError: if the name does not end with FILENAME_SUFFIX or its first
                seven characters are not a valid tile ID like S28W055.
        """
        if not name.endswith(self.FILENAME_SUFFIX):
            raise ValueError(
                f"expected item name to end with {self.FILENAME_SUFFIX}, but got {name}"
            )
        # Parse the first seven characters, e.g. S28W055.
        # We do this to reconstruct the geometry of the item.
        # Layout: index 0 is N/S, 1-2 are latitude digits, index 3 is E/W and 4-6 are
        # longitude digits. (The suffix check above guarantees these indices exist.)
        lat_sign = name[0]
        lat_digits = name[1:3]
        lon_sign = name[3]
        lon_digits = name[4:7]
        if not (lat_digits.isdigit() and lon_digits.isdigit()):
            raise ValueError(f"invalid item name {name}")
        lat_degrees = int(lat_digits)
        lon_degrees = int(lon_digits)

        if lat_sign == "N":
            lat_min = lat_degrees
        elif lat_sign == "S":
            lat_min = -lat_degrees
        else:
            raise ValueError(f"invalid item name {name}")

        if lon_sign == "E":
            lon_min = lon_degrees
        elif lon_sign == "W":
            lon_min = -lon_degrees
        else:
            raise ValueError(f"invalid item name {name}")

        return Item(name, self._cell_geometry(lon_min, lat_min))

    def _lon_lat_to_item(self, lon_min: int, lat_min: int) -> Item:
        """Get an item based on the 1x1 longitude/latitude grid.

        Args:
            lon_min: the starting longitude integer of the grid cell.
            lat_min: the starting latitude integer of the grid cell.

        Returns:
            the Item object.
        """
        # Construct the filename for this grid cell.
        # The item name is just the filename.
        if lon_min < 0:
            lon_part = f"W{-lon_min:03d}"
        else:
            lon_part = f"E{lon_min:03d}"
        if lat_min < 0:
            lat_part = f"S{-lat_min:02d}"
        else:
            lat_part = f"N{lat_min:02d}"
        fname = lat_part + lon_part + self.FILENAME_SUFFIX

        # We also need the geometry for the item.
        return Item(fname, self._cell_geometry(lon_min, lat_min))

    def get_items(
        self, geometries: list[STGeometry], query_config: QueryConfig
    ) -> list[list[list[Item]]]:
        """Get a list of items in the data source intersecting the given geometries.

        Args:
            geometries: the spatiotemporal geometries
            query_config: the query configuration

        Returns:
            List of groups of items that should be retrieved for each geometry.

        Raises:
            ValueError: if the query configuration is not mosaic with max_matches=1.
        """
        # We only support mosaic here, other query modes don't really make sense.
        if query_config.space_mode != SpaceMode.MOSAIC or query_config.max_matches != 1:
            raise ValueError(
                "expected mosaic with max_matches=1 for the query configuration"
            )

        groups = []
        for geometry in geometries:
            # We iterate over each 1x1 cell that this geometry intersects and include
            # the corresponding item in this item group.
            # Since it is a mosaic with one match, there will just be one item group
            # for each geometry.
            wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
            min_lon, min_lat, max_lon, max_lat = wgs84_geometry.shp.bounds

            # lon_start/lat_start are the lower range of the first cell. The end of
            # each range is forced to be at least one cell past the start so that a
            # degenerate geometry lying exactly on an integer coordinate (where
            # floor == ceil) still selects the cell it sits on instead of nothing.
            lon_start = math.floor(min_lon)
            lat_start = math.floor(min_lat)
            lon_end = max(math.ceil(max_lon), lon_start + 1)
            lat_end = max(math.ceil(max_lat), lat_start + 1)

            items = [
                self._lon_lat_to_item(lon_min, lat_min)
                for lon_min in range(lon_start, lon_end)
                for lat_min in range(lat_start, lat_end)
            ]

            logger.debug(f"Got {len(items)} items (grid cells) for geometry")
            groups.append([items])

        return groups

    def deserialize_item(self, serialized_item: Any) -> Item:
        """Deserializes an item from JSON-decoded data."""
        assert isinstance(serialized_item, dict)
        return Item.deserialize(serialized_item)

    def _open_download(self, url: str) -> requests.Response:
        """Open a streamed response for the URL, authenticating if redirected.

        We first attempt to access the URL directly, which works if we have already
        authenticated. If not, we get redirected to a login endpoint where we need to
        use basic authentication; the endpoint will redirect us back to the original
        URL.

        Args:
            url: the URL of the zip file to download.

        Returns:
            the streamed response. The caller is responsible for closing it.
        """
        # Try to access directly.
        response = self.session.get(
            url,
            stream=True,
            timeout=self.timeout.total_seconds(),
            allow_redirects=False,
        )
        if response.status_code != 302:
            return response

        # Encountered redirect, so access the redirect URL instead. This time we
        # follow redirects since it will take us back to the original URL.
        redirect_url = response.headers["Location"]
        # The body of the redirect response is never read, so release its connection
        # before issuing the authenticated request.
        response.close()
        logger.debug(f"Following redirect to {redirect_url}")
        auth = requests.auth.HTTPBasicAuth(self.username, self.password)
        return self.session.get(
            redirect_url,
            stream=True,
            timeout=self.timeout.total_seconds(),
            auth=auth,
        )

    def ingest(
        self,
        tile_store: TileStoreWithLayer,
        items: list[Item],
        geometries: list[list[STGeometry]],
    ) -> None:
        """Ingest items into the given tile store.

        Items that are already present in the tile store are skipped, as are items
        for which the server responds with 404 (no data for that grid cell).

        Args:
            tile_store: the tile store to ingest into
            items: the items to ingest
            geometries: a list of geometries needed for each item
        """
        for item in items:
            if tile_store.is_raster_ready(item.name, [self.band_name]):
                continue

            # Download the item.
            url = self.BASE_URL + item.name
            logger.debug(f"Downloading SRTM data for {item.name} from {url}")

            # The response is streamed, so it must be closed explicitly on every path
            # (skip, error, success) to return the connection to the pool.
            with self._open_download(url) as response:
                if response.status_code == 404:
                    # Some grid cells don't exist so this isn't a big issue.
                    logger.warning(
                        f"Skipping item {item.name} because there is no data at that cell"
                    )
                    continue
                response.raise_for_status()

                with tempfile.TemporaryDirectory() as tmp_dir:
                    # Store it in temporary directory.
                    zip_fname = os.path.join(tmp_dir, "data.zip")
                    with open(zip_fname, "wb") as f:
                        for chunk in response.iter_content(chunk_size=8192):
                            f.write(chunk)

                    # Extract the .hgt file.
                    logger.debug(f"Extracting data for {item.name}")
                    with zipfile.ZipFile(zip_fname) as zip_f:
                        member_names = zip_f.namelist()
                        if len(member_names) != 1:
                            raise ValueError(
                                f"expected SRTM zip to have one member but got {member_names}"
                            )
                        local_fname = zip_f.extract(member_names[0], path=tmp_dir)

                    # Now we can ingest it.
                    logger.debug(f"Ingesting data for {item.name}")
                    tile_store.write_raster_file(
                        item.name, [self.band_name], UPath(local_fname)
                    )