rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/data_sources/planet_basemap.py
@@ -0,0 +1,275 @@
+"""Data source for Planet Labs Basemaps API."""
+
+import os
+import tempfile
+from datetime import datetime
+from typing import Any
+
+import requests
+import shapely
+from upath import UPath
+
+from rslearn.config import QueryConfig
+from rslearn.const import WGS84_PROJECTION
+from rslearn.data_sources import DataSource, DataSourceContext, Item
+from rslearn.data_sources.utils import match_candidate_items_to_window
+from rslearn.log_utils import get_logger
+from rslearn.tile_stores import TileStoreWithLayer
+from rslearn.utils import STGeometry
+
+logger = get_logger(__name__)
+
+
+class PlanetItem(Item):
+    """An item referencing a particular mosaic and quad in Basemaps API."""
+
+    def __init__(self, name: str, geometry: STGeometry, mosaic_id: str, quad_id: str):
+        """Create a new PlanetItem.
+
+        Args:
+            name: the item name (combination of mosaic and quad ID).
+            geometry: the geometry associated with this quad.
+            mosaic_id: the mosaic ID in API
+            quad_id: the quad ID in API
+        """
+        super().__init__(name, geometry)
+        self.mosaic_id = mosaic_id
+        self.quad_id = quad_id
+
+    def serialize(self) -> dict:
+        """Serializes the item to a JSON-encodable dictionary."""
+        d = super().serialize()
+        d["mosaic_id"] = self.mosaic_id
+        d["quad_id"] = self.quad_id
+        return d
+
+    @staticmethod
+    def deserialize(d: dict) -> Item:
+        """Deserializes an item from a JSON-decoded dictionary."""
+        item = super(PlanetItem, PlanetItem).deserialize(d)
+        return PlanetItem(
+            name=item.name,
+            geometry=item.geometry,
+            mosaic_id=d["mosaic_id"],
+            quad_id=d["quad_id"],
+        )
+
+
+class ApiError(Exception):
+    """An error from Planet Labs API."""
+
+    pass
+
+
+class PlanetBasemap(DataSource):
+    """A data source for Planet Labs Basemaps API."""
+
+    api_url = "https://api.planet.com/basemaps/v1/"
+
+    def __init__(
+        self,
+        series_id: str,
+        bands: list[str],
+        api_key: str | None = None,
+        context: DataSourceContext = DataSourceContext(),
+    ):
+        """Initialize a new Planet instance.
+
+        Args:
+            series_id: the series of mosaics to use.
+            bands: list of band names to use.
+            api_key: optional Planet API key (it can also be provided via PL_API_KEY
+                environment variable).
+            context: the data source context
+        """
+        self.series_id = series_id
+        self.bands = bands
+
+        self.session = requests.Session()
+        if api_key is None:
+            api_key = os.environ["PL_API_KEY"]
+        self.session.auth = (api_key, "")
+
+        # Lazily load mosaics.
+        self.mosaics: dict | None = None
+
+    def _load_mosaics(self) -> dict[str, STGeometry]:
+        """Lazily load mosaics in the configured series_id from Planet API.
+
+        We don't load it when creating the data source because it takes time and caller
+        may not be calling get_items. Additionally, loading it during the get_items
+        call enables leveraging the retry loop functionality in
+        prepare_dataset_windows.
+        """
+        if self.mosaics is not None:
+            return self.mosaics
+
+        self.mosaics = {}
+        for mosaic_dict in self._api_get_paginate(
+            path=f"series/{self.series_id}/mosaics", list_key="mosaics"
+        ):
+            shp = shapely.box(*mosaic_dict["bbox"])
+            time_range = (
+                datetime.fromisoformat(mosaic_dict["first_acquired"]),
+                datetime.fromisoformat(mosaic_dict["last_acquired"]),
+            )
+            geom = STGeometry(WGS84_PROJECTION, shp, time_range)
+            self.mosaics[mosaic_dict["id"]] = geom
+
+        return self.mosaics
+
+    def _api_get(
+        self,
+        path: str | None = None,
+        url: str | None = None,
+        query_args: dict[str, str] | None = None,
+    ) -> list[Any] | dict[str, Any]:
+        """Perform a GET request on the API.
+
+        Args:
+            path: the path to GET, like "series".
+            url: the full URL to GET. Only one of path or url should be set.
+            query_args: optional params to include with the request.
+
+        Returns:
+            the JSON response data.
+
+        Raises:
+            ApiError: if the API returned an error response.
+        """
+        if path is None and url is None:
+            raise ValueError("Only one of path or url should be set")
+        if query_args:
+            kwargs = dict(params=query_args)
+        else:
+            kwargs = {}
+
+        if path:
+            url = self.api_url + path
+        if url is None:
+            raise ValueError("url is required")
+        response = self.session.get(url, **kwargs)  # type: ignore
+
+        if response.status_code != 200:
+            raise ApiError(
+                f"{url}: got status code {response.status_code}: {response.text}"
+            )
+
+        return response.json()
+
+    def _api_get_paginate(
+        self, path: str, list_key: str, query_args: dict[str, str] | None = None
+    ) -> list:
+        """Get all items in a paginated response.
+
+        Args:
+            path: the path to GET.
+            list_key: the key in the response containing the list that should be
+                concatenated across all available pages.
+            query_args: optional params to include with the requests.
+
+        Returns:
+            the concatenated list of items.
+
+        Raises:
+            ApiError if the API returned an error response.
+        """
+        next_url = self.api_url + path
+        items = []
+        while True:
+            json_data = self._api_get(url=next_url, query_args=query_args)
+            if not isinstance(json_data, dict):
+                logger.warning(f"Expected dict, got {type(json_data)}")
+                continue
+            items += json_data[list_key]
+
+            if "_next" in json_data["_links"]:
+                next_url = json_data["_links"]["_next"]
+            else:
+                return items
+
+    def get_items(
+        self, geometries: list[STGeometry], query_config: QueryConfig
+    ) -> list[list[list[PlanetItem]]]:
+        """Get a list of items in the data source intersecting the given geometries.
+
+        Args:
+            geometries: the spatiotemporal geometries
+            query_config: the query configuration
+
+        Returns:
+            List of groups of items that should be retrieved for each geometry.
+        """
+        mosaics = self._load_mosaics()
+
+        groups = []
+        for geometry in geometries:
+            geom_bbox = geometry.to_projection(WGS84_PROJECTION).shp.bounds
+            geom_bbox_str = ",".join([str(value) for value in geom_bbox])
+
+            # Find the relevant mosaics that the geometry intersects.
+            # For each relevant mosaic, identify the intersecting quads.
+            items = []
+            for mosaic_id, mosaic_geom in mosaics.items():
+                if not geometry.intersects(mosaic_geom):
+                    continue
+                logger.info(f"found mosaic {mosaic_geom} for geom {geometry}")
+                # List all quads that intersect the current geometry's
+                # longitude/latitude bbox in this mosaic.
+                for quad_dict in self._api_get_paginate(
+                    path=f"mosaics/{mosaic_id}/quads",
+                    list_key="items",
+                    query_args={"bbox": geom_bbox_str},
+                ):
+                    logger.info(f"found quad {quad_dict}")
+                    shp = shapely.box(*quad_dict["bbox"])
+                    geom = STGeometry(WGS84_PROJECTION, shp, mosaic_geom.time_range)
+                    quad_id = quad_dict["id"]
+                    items.append(
+                        PlanetItem(f"{mosaic_id}_{quad_id}", geom, mosaic_id, quad_id)
+                    )
+            logger.info(f"found {len(items)} items for geom {geometry}")
+            cur_groups = match_candidate_items_to_window(geometry, items, query_config)
+            groups.append(cur_groups)
+
+        return groups
+
+    def deserialize_item(self, serialized_item: Any) -> Item:
+        """Deserializes an item from JSON-decoded data."""
+        assert isinstance(serialized_item, dict)
+        return PlanetItem.deserialize(serialized_item)
+
+    def ingest(
+        self,
+        tile_store: TileStoreWithLayer,
+        items: list[Item],
+        geometries: list[list[STGeometry]],
+    ) -> None:
+        """Ingest items into the given tile store.
+
+        Args:
+            tile_store: the tile store to ingest into
+            items: the items to ingest
+            geometries: a list of geometries needed for each item
+        """
+        for item in items:
+            if tile_store.is_raster_ready(item.name, self.bands):
+                continue
+
+            assert isinstance(item, PlanetItem)
+            download_url = (
+                self.api_url + f"mosaics/{item.mosaic_id}/quads/{item.quad_id}/full"
+            )
+            response = self.session.get(download_url, allow_redirects=True, stream=True)
+            if response.status_code != 200:
+                raise ApiError(
+                    f"{download_url}: got status code {response.status_code}: {response.text}"
+                )
+
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                local_fname = os.path.join(tmp_dir, "temp.tif")
+                with open(local_fname, "wb") as f:
+                    for chunk in response.iter_content(chunk_size=8192):
+                        f.write(chunk)
+
+                tile_store.write_raster_file(item.name, self.bands, UPath(local_fname))