rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166):
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,275 @@
1
+ """Data source for Planet Labs Basemaps API."""
2
+
3
+ import os
4
+ import tempfile
5
+ from datetime import datetime
6
+ from typing import Any
7
+
8
+ import requests
9
+ import shapely
10
+ from upath import UPath
11
+
12
+ from rslearn.config import QueryConfig
13
+ from rslearn.const import WGS84_PROJECTION
14
+ from rslearn.data_sources import DataSource, DataSourceContext, Item
15
+ from rslearn.data_sources.utils import match_candidate_items_to_window
16
+ from rslearn.log_utils import get_logger
17
+ from rslearn.tile_stores import TileStoreWithLayer
18
+ from rslearn.utils import STGeometry
19
+
20
+ logger = get_logger(__name__)
21
+
22
+
23
class PlanetItem(Item):
    """Basemaps API item identifying one quad within one mosaic."""

    def __init__(self, name: str, geometry: STGeometry, mosaic_id: str, quad_id: str):
        """Construct a PlanetItem.

        Args:
            name: the item name (built from the mosaic ID and quad ID).
            geometry: the spatiotemporal geometry of the quad.
            mosaic_id: ID of the mosaic in the Basemaps API.
            quad_id: ID of the quad within that mosaic.
        """
        super().__init__(name, geometry)
        self.mosaic_id = mosaic_id
        self.quad_id = quad_id

    def serialize(self) -> dict:
        """Return a JSON-encodable dict of the base Item fields plus mosaic/quad IDs."""
        serialized = super().serialize()
        serialized.update(mosaic_id=self.mosaic_id, quad_id=self.quad_id)
        return serialized

    @staticmethod
    def deserialize(d: dict) -> Item:
        """Rebuild a PlanetItem from a dict produced by serialize()."""
        base = super(PlanetItem, PlanetItem).deserialize(d)
        mosaic_id, quad_id = d["mosaic_id"], d["quad_id"]
        return PlanetItem(base.name, base.geometry, mosaic_id, quad_id)
56
+
57
+
58
class ApiError(Exception):
    """Raised when a Planet Labs API request returns an error response."""
62
+
63
+
64
class PlanetBasemap(DataSource):
    """A data source for Planet Labs Basemaps API.

    Items correspond to (mosaic, quad) pairs: the mosaics of the configured series
    are listed, and for each mosaic intersecting a query geometry the intersecting
    quads become items. Each item is ingested as one raster file.
    """

    api_url = "https://api.planet.com/basemaps/v1/"

    # Timeout (seconds) passed to requests for connecting and for waiting between
    # bytes of a response. Without a timeout a stalled connection hangs forever.
    request_timeout: float = 60

    def __init__(
        self,
        series_id: str,
        bands: list[str],
        api_key: str | None = None,
        context: DataSourceContext = DataSourceContext(),
    ):
        """Initialize a new PlanetBasemap instance.

        Args:
            series_id: the series of mosaics to use.
            bands: list of band names to use.
            api_key: optional Planet API key (it can also be provided via PL_API_KEY
                environment variable).
            context: the data source context (not used by this data source).

        Raises:
            KeyError: if api_key is not given and PL_API_KEY is not set.
        """
        self.series_id = series_id
        self.bands = bands

        self.session = requests.Session()
        if api_key is None:
            try:
                api_key = os.environ["PL_API_KEY"]
            except KeyError:
                raise KeyError(
                    "api_key was not provided and the PL_API_KEY environment "
                    "variable is not set"
                ) from None
        # HTTP basic auth: the API key is the username and the password is empty.
        self.session.auth = (api_key, "")

        # Lazily load mosaics (mosaic ID -> geometry); see _load_mosaics.
        self.mosaics: dict[str, STGeometry] | None = None

    @staticmethod
    def _parse_timestamp(value: str) -> datetime:
        """Parse an ISO 8601 timestamp returned by the API.

        A trailing "Z" is rewritten to "+00:00" because datetime.fromisoformat only
        accepts the "Z" suffix starting with Python 3.11.
        """
        if value.endswith("Z"):
            value = value[:-1] + "+00:00"
        return datetime.fromisoformat(value)

    def _load_mosaics(self) -> dict[str, STGeometry]:
        """Lazily load mosaics in the configured series_id from Planet API.

        We don't load it when creating the data source because it takes time and caller
        may not be calling get_items. Additionally, loading it during the get_items
        call enables leveraging the retry loop functionality in
        prepare_dataset_windows.

        Returns:
            mapping from mosaic ID to its WGS84 bounding box and acquisition range.
        """
        if self.mosaics is not None:
            return self.mosaics

        # Build into a local dict and only cache it after everything loaded, so a
        # failure partway through does not leave an empty/partial cache that later
        # (retried) calls would silently reuse.
        mosaics: dict[str, STGeometry] = {}
        for mosaic_dict in self._api_get_paginate(
            path=f"series/{self.series_id}/mosaics", list_key="mosaics"
        ):
            shp = shapely.box(*mosaic_dict["bbox"])
            time_range = (
                self._parse_timestamp(mosaic_dict["first_acquired"]),
                self._parse_timestamp(mosaic_dict["last_acquired"]),
            )
            mosaics[mosaic_dict["id"]] = STGeometry(WGS84_PROJECTION, shp, time_range)

        self.mosaics = mosaics
        return mosaics

    def _api_get(
        self,
        path: str | None = None,
        url: str | None = None,
        query_args: dict[str, str] | None = None,
    ) -> list[Any] | dict[str, Any]:
        """Perform a GET request on the API.

        Args:
            path: the path to GET (relative to api_url), like "series".
            url: the full URL to GET. Exactly one of path or url must be set.
            query_args: optional params to include with the request.

        Returns:
            the JSON response data.

        Raises:
            ValueError: if neither or both of path and url are set.
            ApiError: if the API returned an error response.
        """
        if path is not None:
            if url is not None:
                raise ValueError("Exactly one of path or url must be set")
            url = self.api_url + path
        elif url is None:
            raise ValueError("Exactly one of path or url must be set")

        # params=None (or {}) adds no query string, matching the no-args case.
        response = self.session.get(
            url, params=query_args, timeout=self.request_timeout
        )

        if response.status_code != 200:
            raise ApiError(
                f"{url}: got status code {response.status_code}: {response.text}"
            )

        return response.json()

    def _api_get_paginate(
        self, path: str, list_key: str, query_args: dict[str, str] | None = None
    ) -> list:
        """Get all items in a paginated response.

        Pages are followed via the "_next" entry of each response's "_links" object
        until a page has no such entry.

        Args:
            path: the path to GET.
            list_key: the key in the response containing the list that should be
                concatenated across all available pages.
            query_args: optional params to include with the requests.

        Returns:
            the concatenated list of items.

        Raises:
            ApiError: if the API returned an error response or a page that is not
                a JSON object.
        """
        next_url: str | None = self.api_url + path
        items: list = []
        while next_url:
            json_data = self._api_get(url=next_url, query_args=query_args)
            if not isinstance(json_data, dict):
                # Fail rather than retry: re-requesting the same URL would loop
                # forever.
                raise ApiError(
                    f"{next_url}: expected a JSON object, got "
                    f"{type(json_data).__name__}"
                )
            items.extend(json_data[list_key])
            next_url = json_data.get("_links", {}).get("_next")
        return items

    def get_items(
        self, geometries: list[STGeometry], query_config: QueryConfig
    ) -> list[list[list[PlanetItem]]]:
        """Get a list of items in the data source intersecting the given geometries.

        Args:
            geometries: the spatiotemporal geometries
            query_config: the query configuration

        Returns:
            List of groups of items that should be retrieved for each geometry.
        """
        mosaics = self._load_mosaics()

        groups = []
        for geometry in geometries:
            geom_bbox = geometry.to_projection(WGS84_PROJECTION).shp.bounds
            geom_bbox_str = ",".join(str(value) for value in geom_bbox)

            # Find the relevant mosaics that the geometry intersects.
            # For each relevant mosaic, identify the intersecting quads.
            items = []
            for mosaic_id, mosaic_geom in mosaics.items():
                if not geometry.intersects(mosaic_geom):
                    continue
                logger.info(f"found mosaic {mosaic_geom} for geom {geometry}")
                # List all quads that intersect the current geometry's
                # longitude/latitude bbox in this mosaic.
                for quad_dict in self._api_get_paginate(
                    path=f"mosaics/{mosaic_id}/quads",
                    list_key="items",
                    query_args={"bbox": geom_bbox_str},
                ):
                    logger.info(f"found quad {quad_dict}")
                    shp = shapely.box(*quad_dict["bbox"])
                    # Quads inherit the time range of their mosaic.
                    geom = STGeometry(WGS84_PROJECTION, shp, mosaic_geom.time_range)
                    quad_id = quad_dict["id"]
                    items.append(
                        PlanetItem(f"{mosaic_id}_{quad_id}", geom, mosaic_id, quad_id)
                    )
            logger.info(f"found {len(items)} items for geom {geometry}")
            cur_groups = match_candidate_items_to_window(geometry, items, query_config)
            groups.append(cur_groups)

        return groups

    def deserialize_item(self, serialized_item: Any) -> Item:
        """Deserializes an item from JSON-decoded data."""
        assert isinstance(serialized_item, dict)
        return PlanetItem.deserialize(serialized_item)

    def ingest(
        self,
        tile_store: TileStoreWithLayer,
        items: list[Item],
        geometries: list[list[STGeometry]],
    ) -> None:
        """Ingest items into the given tile store.

        Each quad is downloaded to a temporary GeoTIFF and then written to the tile
        store; items whose raster is already present are skipped.

        Args:
            tile_store: the tile store to ingest into
            items: the items to ingest
            geometries: a list of geometries needed for each item (unused; the
                full quad is always downloaded)

        Raises:
            ApiError: if the download request returns a non-200 status.
        """
        for item in items:
            if tile_store.is_raster_ready(item.name, self.bands):
                continue

            assert isinstance(item, PlanetItem)
            download_url = (
                self.api_url + f"mosaics/{item.mosaic_id}/quads/{item.quad_id}/full"
            )

            with tempfile.TemporaryDirectory() as tmp_dir:
                local_fname = os.path.join(tmp_dir, "temp.tif")
                # Using the response as a context manager releases the streamed
                # connection even if we raise or the download fails midway.
                with self.session.get(
                    download_url,
                    allow_redirects=True,
                    stream=True,
                    timeout=self.request_timeout,
                ) as response:
                    if response.status_code != 200:
                        raise ApiError(
                            f"{download_url}: got status code "
                            f"{response.status_code}: {response.text}"
                        )
                    with open(local_fname, "wb") as f:
                        for chunk in response.iter_content(chunk_size=8192):
                            f.write(chunk)

                tile_store.write_raster_file(item.name, self.bands, UPath(local_fname))