rslearn 0.0.26__py3-none-any.whl → 0.0.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/data_sources/__init__.py +2 -0
- rslearn/data_sources/aws_landsat.py +44 -161
- rslearn/data_sources/aws_open_data.py +2 -4
- rslearn/data_sources/aws_sentinel1.py +1 -3
- rslearn/data_sources/aws_sentinel2_element84.py +54 -165
- rslearn/data_sources/climate_data_store.py +1 -3
- rslearn/data_sources/copernicus.py +1 -2
- rslearn/data_sources/data_source.py +1 -1
- rslearn/data_sources/direct_materialize_data_source.py +336 -0
- rslearn/data_sources/earthdaily.py +52 -155
- rslearn/data_sources/earthdatahub.py +425 -0
- rslearn/data_sources/eurocrops.py +1 -2
- rslearn/data_sources/gcp_public_data.py +1 -2
- rslearn/data_sources/google_earth_engine.py +1 -2
- rslearn/data_sources/hf_srtm.py +595 -0
- rslearn/data_sources/local_files.py +1 -1
- rslearn/data_sources/openstreetmap.py +1 -1
- rslearn/data_sources/planet.py +1 -2
- rslearn/data_sources/planet_basemap.py +1 -2
- rslearn/data_sources/planetary_computer.py +183 -186
- rslearn/data_sources/soilgrids.py +3 -3
- rslearn/data_sources/stac.py +1 -2
- rslearn/data_sources/usda_cdl.py +1 -3
- rslearn/data_sources/usgs_landsat.py +7 -254
- rslearn/data_sources/worldcereal.py +1 -1
- rslearn/data_sources/worldcover.py +1 -1
- rslearn/data_sources/worldpop.py +1 -1
- rslearn/data_sources/xyz_tiles.py +5 -9
- rslearn/dataset/storage/file.py +16 -12
- rslearn/models/concatenate_features.py +6 -1
- rslearn/tile_stores/default.py +4 -2
- rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
- rslearn/train/data_module.py +36 -33
- rslearn/train/dataset.py +159 -68
- rslearn/train/lightning_module.py +60 -4
- rslearn/train/metrics.py +162 -0
- rslearn/train/model_context.py +3 -3
- rslearn/train/prediction_writer.py +69 -41
- rslearn/train/tasks/classification.py +14 -1
- rslearn/train/tasks/detection.py +5 -5
- rslearn/train/tasks/per_pixel_regression.py +19 -6
- rslearn/train/tasks/regression.py +19 -3
- rslearn/train/tasks/segmentation.py +17 -0
- rslearn/utils/__init__.py +2 -0
- rslearn/utils/fsspec.py +51 -1
- rslearn/utils/geometry.py +21 -0
- rslearn/utils/m2m_api.py +251 -0
- rslearn/utils/retry_session.py +43 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/METADATA +6 -3
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/RECORD +55 -50
- rslearn/data_sources/earthdata_srtm.py +0 -282
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/WHEEL +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/top_level.txt +0 -0
|
@@ -4,12 +4,9 @@
|
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
6
|
import io
|
|
7
|
-
import json
|
|
8
7
|
import os
|
|
9
8
|
import shutil
|
|
10
9
|
import tempfile
|
|
11
|
-
import time
|
|
12
|
-
import uuid
|
|
13
10
|
from collections.abc import Generator
|
|
14
11
|
from datetime import UTC, datetime, timedelta
|
|
15
12
|
from typing import Any, BinaryIO
|
|
@@ -24,246 +21,7 @@ from rslearn.data_sources import DataSource, DataSourceContext, Item
|
|
|
24
21
|
from rslearn.data_sources.utils import match_candidate_items_to_window
|
|
25
22
|
from rslearn.tile_stores import TileStoreWithLayer
|
|
26
23
|
from rslearn.utils import STGeometry
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
class APIException(Exception):
|
|
30
|
-
"""Exception raised for M2M API errors."""
|
|
31
|
-
|
|
32
|
-
pass
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
class M2MAPIClient:
|
|
36
|
-
"""An API client for interacting with the USGS M2M API."""
|
|
37
|
-
|
|
38
|
-
api_url = "https://m2m.cr.usgs.gov/api/api/json/stable/"
|
|
39
|
-
pagination_size = 1000
|
|
40
|
-
|
|
41
|
-
def __init__(
|
|
42
|
-
self,
|
|
43
|
-
username: str,
|
|
44
|
-
password: str | None = None,
|
|
45
|
-
token: str | None = None,
|
|
46
|
-
timeout: timedelta = timedelta(seconds=120),
|
|
47
|
-
) -> None:
|
|
48
|
-
"""Initialize a new M2MAPIClient.
|
|
49
|
-
|
|
50
|
-
Args:
|
|
51
|
-
username: the EROS username
|
|
52
|
-
password: the EROS password
|
|
53
|
-
token: the application token. One of password or token must be specified.
|
|
54
|
-
timeout: timeout for requests.
|
|
55
|
-
"""
|
|
56
|
-
self.username = username
|
|
57
|
-
self.timeout = timeout
|
|
58
|
-
|
|
59
|
-
if password is not None and token is not None:
|
|
60
|
-
raise ValueError("only one of password or token can be specified")
|
|
61
|
-
|
|
62
|
-
if password is not None:
|
|
63
|
-
json_data = json.dumps({"username": self.username, "password": password})
|
|
64
|
-
response = requests.post(
|
|
65
|
-
self.api_url + "login",
|
|
66
|
-
data=json_data,
|
|
67
|
-
timeout=self.timeout.total_seconds(),
|
|
68
|
-
)
|
|
69
|
-
|
|
70
|
-
elif token is not None:
|
|
71
|
-
json_data = json.dumps({"username": username, "token": token})
|
|
72
|
-
response = requests.post(
|
|
73
|
-
self.api_url + "login-token",
|
|
74
|
-
data=json_data,
|
|
75
|
-
timeout=self.timeout.total_seconds(),
|
|
76
|
-
)
|
|
77
|
-
|
|
78
|
-
else:
|
|
79
|
-
raise ValueError("one of password or token must be specified")
|
|
80
|
-
|
|
81
|
-
response.raise_for_status()
|
|
82
|
-
self.auth_token = response.json()["data"]
|
|
83
|
-
|
|
84
|
-
def request(
|
|
85
|
-
self, endpoint: str, data: dict[str, Any] | None = None
|
|
86
|
-
) -> dict[str, Any] | None:
|
|
87
|
-
"""Make a request to the API.
|
|
88
|
-
|
|
89
|
-
Args:
|
|
90
|
-
endpoint: the endpoint to call
|
|
91
|
-
data: POST data to pass
|
|
92
|
-
|
|
93
|
-
Returns:
|
|
94
|
-
JSON response data if any
|
|
95
|
-
"""
|
|
96
|
-
response = requests.post(
|
|
97
|
-
self.api_url + endpoint,
|
|
98
|
-
headers={"X-Auth-Token": self.auth_token},
|
|
99
|
-
data=json.dumps(data),
|
|
100
|
-
timeout=self.timeout.total_seconds(),
|
|
101
|
-
)
|
|
102
|
-
response.raise_for_status()
|
|
103
|
-
if response.text:
|
|
104
|
-
response_dict = response.json()
|
|
105
|
-
|
|
106
|
-
if response_dict["errorMessage"]:
|
|
107
|
-
raise APIException(response_dict["errorMessage"])
|
|
108
|
-
return response_dict
|
|
109
|
-
return None
|
|
110
|
-
|
|
111
|
-
def close(self) -> None:
|
|
112
|
-
"""Logout from the API."""
|
|
113
|
-
self.request("logout")
|
|
114
|
-
|
|
115
|
-
def __enter__(self) -> "M2MAPIClient":
|
|
116
|
-
"""Enter function to provide with semantics."""
|
|
117
|
-
return self
|
|
118
|
-
|
|
119
|
-
def __exit__(self) -> None:
|
|
120
|
-
"""Exit function to provide with semantics.
|
|
121
|
-
|
|
122
|
-
Logs out the API.
|
|
123
|
-
"""
|
|
124
|
-
self.close()
|
|
125
|
-
|
|
126
|
-
def get_filters(self, dataset_name: str) -> list[dict[str, Any]]:
|
|
127
|
-
"""Returns filters available for the given dataset.
|
|
128
|
-
|
|
129
|
-
Args:
|
|
130
|
-
dataset_name: the dataset name e.g. landsat_ot_c2_l1
|
|
131
|
-
|
|
132
|
-
Returns:
|
|
133
|
-
list of filter objects
|
|
134
|
-
"""
|
|
135
|
-
response_dict = self.request("dataset-filters", {"datasetName": dataset_name})
|
|
136
|
-
if response_dict is None:
|
|
137
|
-
raise APIException("No response from API")
|
|
138
|
-
return response_dict["data"]
|
|
139
|
-
|
|
140
|
-
def scene_search(
|
|
141
|
-
self,
|
|
142
|
-
dataset_name: str,
|
|
143
|
-
acquisition_time_range: tuple[datetime, datetime] | None = None,
|
|
144
|
-
cloud_cover_range: tuple[int, int] | None = None,
|
|
145
|
-
bbox: tuple[float, float, float, float] | None = None,
|
|
146
|
-
metadata_filter: dict[str, Any] | None = None,
|
|
147
|
-
) -> list[dict[str, Any]]:
|
|
148
|
-
"""Search for scenes matching the arguments.
|
|
149
|
-
|
|
150
|
-
Args:
|
|
151
|
-
dataset_name: the dataset name e.g. landsat_ot_c2_l1
|
|
152
|
-
acquisition_time_range: optional filter on the acquisition time
|
|
153
|
-
cloud_cover_range: optional filter on the cloud cover
|
|
154
|
-
bbox: optional spatial filter
|
|
155
|
-
metadata_filter: optional metadata filter dict
|
|
156
|
-
"""
|
|
157
|
-
base_data: dict[str, Any] = {"datasetName": dataset_name, "sceneFilter": {}}
|
|
158
|
-
if acquisition_time_range:
|
|
159
|
-
base_data["sceneFilter"]["acquisitionFilter"] = {
|
|
160
|
-
"start": acquisition_time_range[0].isoformat(),
|
|
161
|
-
"end": acquisition_time_range[1].isoformat(),
|
|
162
|
-
}
|
|
163
|
-
if cloud_cover_range:
|
|
164
|
-
base_data["sceneFilter"]["cloudCoverFilter"] = {
|
|
165
|
-
"min": cloud_cover_range[0],
|
|
166
|
-
"max": cloud_cover_range[1],
|
|
167
|
-
"includeUnknown": False,
|
|
168
|
-
}
|
|
169
|
-
if bbox:
|
|
170
|
-
base_data["sceneFilter"]["spatialFilter"] = {
|
|
171
|
-
"filterType": "mbr",
|
|
172
|
-
"lowerLeft": {"longitude": bbox[0], "latitude": bbox[1]},
|
|
173
|
-
"upperRight": {"longitude": bbox[2], "latitude": bbox[3]},
|
|
174
|
-
}
|
|
175
|
-
if metadata_filter:
|
|
176
|
-
base_data["sceneFilter"]["metadataFilter"] = metadata_filter
|
|
177
|
-
|
|
178
|
-
starting_number = 1
|
|
179
|
-
results = []
|
|
180
|
-
while True:
|
|
181
|
-
cur_data = base_data.copy()
|
|
182
|
-
cur_data["startingNumber"] = starting_number
|
|
183
|
-
cur_data["maxResults"] = self.pagination_size
|
|
184
|
-
response_dict = self.request("scene-search", cur_data)
|
|
185
|
-
if response_dict is None:
|
|
186
|
-
raise APIException("No response from API")
|
|
187
|
-
data = response_dict["data"]
|
|
188
|
-
results.extend(data["results"])
|
|
189
|
-
if data["recordsReturned"] < self.pagination_size:
|
|
190
|
-
break
|
|
191
|
-
starting_number += self.pagination_size
|
|
192
|
-
|
|
193
|
-
return results
|
|
194
|
-
|
|
195
|
-
def get_scene_metadata(self, dataset_name: str, entity_id: str) -> dict[str, Any]:
|
|
196
|
-
"""Get detailed metadata for a scene.
|
|
197
|
-
|
|
198
|
-
Args:
|
|
199
|
-
dataset_name: the dataset name in which to search
|
|
200
|
-
entity_id: the entity ID of the scene
|
|
201
|
-
|
|
202
|
-
Returns:
|
|
203
|
-
full scene metadata
|
|
204
|
-
"""
|
|
205
|
-
response_dict = self.request(
|
|
206
|
-
"scene-metadata",
|
|
207
|
-
{
|
|
208
|
-
"datasetName": dataset_name,
|
|
209
|
-
"entityId": entity_id,
|
|
210
|
-
"metadataType": "full",
|
|
211
|
-
},
|
|
212
|
-
)
|
|
213
|
-
if response_dict is None:
|
|
214
|
-
raise APIException("No response from API")
|
|
215
|
-
return response_dict["data"]
|
|
216
|
-
|
|
217
|
-
def get_downloadable_products(
|
|
218
|
-
self, dataset_name: str, entity_id: str
|
|
219
|
-
) -> list[dict[str, Any]]:
|
|
220
|
-
"""Get the downloadable products for a given scene.
|
|
221
|
-
|
|
222
|
-
Args:
|
|
223
|
-
dataset_name: the dataset name
|
|
224
|
-
entity_id: the entity ID of the scene
|
|
225
|
-
|
|
226
|
-
Returns:
|
|
227
|
-
list of downloadable products
|
|
228
|
-
"""
|
|
229
|
-
data = {"datasetName": dataset_name, "entityIds": [entity_id]}
|
|
230
|
-
response_dict = self.request("download-options", data)
|
|
231
|
-
if response_dict is None:
|
|
232
|
-
raise APIException("No response from API")
|
|
233
|
-
return response_dict["data"]
|
|
234
|
-
|
|
235
|
-
def get_download_url(self, entity_id: str, product_id: str) -> str:
|
|
236
|
-
"""Get the download URL for a given product.
|
|
237
|
-
|
|
238
|
-
Args:
|
|
239
|
-
entity_id: the entity ID of the product
|
|
240
|
-
product_id: the product ID of the product
|
|
241
|
-
|
|
242
|
-
Returns:
|
|
243
|
-
the download URL
|
|
244
|
-
"""
|
|
245
|
-
label = str(uuid.uuid4())
|
|
246
|
-
data = {
|
|
247
|
-
"downloads": [
|
|
248
|
-
{"label": label, "entityId": entity_id, "productId": product_id}
|
|
249
|
-
]
|
|
250
|
-
}
|
|
251
|
-
response_dict = self.request("download-request", data)
|
|
252
|
-
if response_dict is None:
|
|
253
|
-
raise APIException("No response from API")
|
|
254
|
-
response = response_dict["data"]
|
|
255
|
-
while True:
|
|
256
|
-
response_dict = self.request("download-retrieve", {"label": label})
|
|
257
|
-
if response_dict is None:
|
|
258
|
-
raise APIException("No response from API")
|
|
259
|
-
response = response_dict["data"]
|
|
260
|
-
if len(response["available"]) > 0:
|
|
261
|
-
return response["available"][0]["url"]
|
|
262
|
-
if len(response["requested"]) == 0:
|
|
263
|
-
raise Exception("Did not get download URL")
|
|
264
|
-
if response["requested"][0].get("url"):
|
|
265
|
-
return response["requested"][0]["url"]
|
|
266
|
-
time.sleep(10)
|
|
24
|
+
from rslearn.utils.m2m_api import APIException, M2MAPIClient
|
|
267
25
|
|
|
268
26
|
|
|
269
27
|
class LandsatOliTirsItem(Item):
|
|
@@ -314,30 +72,26 @@ class LandsatOliTirs(DataSource):
|
|
|
314
72
|
|
|
315
73
|
def __init__(
|
|
316
74
|
self,
|
|
317
|
-
username: str,
|
|
318
|
-
sort_by: str | None = None,
|
|
319
|
-
password: str | None = None,
|
|
75
|
+
username: str | None = None,
|
|
320
76
|
token: str | None = None,
|
|
77
|
+
sort_by: str | None = None,
|
|
321
78
|
timeout: timedelta = timedelta(seconds=10),
|
|
322
79
|
context: DataSourceContext = DataSourceContext(),
|
|
323
80
|
):
|
|
324
81
|
"""Initialize a new LandsatOliTirs instance.
|
|
325
82
|
|
|
326
83
|
Args:
|
|
327
|
-
username: EROS username
|
|
84
|
+
username: EROS username (see M2MAPIClient).
|
|
85
|
+
token: EROS application token (see M2MAPIClient).
|
|
328
86
|
sort_by: can be "cloud_cover", default arbitrary order; only has effect for
|
|
329
87
|
SpaceMode.WITHIN.
|
|
330
|
-
password: EROS password (see M2MAPIClient).
|
|
331
|
-
token: EROS application token (see M2MAPIClient).
|
|
332
88
|
timeout: timeout for requests.
|
|
333
89
|
context: the data source context.
|
|
334
90
|
"""
|
|
335
91
|
self.sort_by = sort_by
|
|
336
92
|
self.timeout = timeout
|
|
337
93
|
|
|
338
|
-
self.client = M2MAPIClient(
|
|
339
|
-
username, password=password, token=token, timeout=timeout
|
|
340
|
-
)
|
|
94
|
+
self.client = M2MAPIClient(username=username, token=token, timeout=timeout)
|
|
341
95
|
|
|
342
96
|
def _scene_metadata_to_item(self, result: dict[str, Any]) -> LandsatOliTirsItem:
|
|
343
97
|
"""Convert scene metadata from the API to a LandsatOliTirsItem."""
|
|
@@ -429,9 +183,8 @@ class LandsatOliTirs(DataSource):
|
|
|
429
183
|
)
|
|
430
184
|
return self._scene_metadata_to_item(scene_metadata)
|
|
431
185
|
|
|
432
|
-
def deserialize_item(self, serialized_item:
|
|
186
|
+
def deserialize_item(self, serialized_item: dict) -> Item:
|
|
433
187
|
"""Deserializes an item from JSON-decoded data."""
|
|
434
|
-
assert isinstance(serialized_item, dict)
|
|
435
188
|
return LandsatOliTirsItem.deserialize(serialized_item)
|
|
436
189
|
|
|
437
190
|
def _get_download_urls(self, item: Item) -> dict[str, tuple[str, str]]:
|
|
@@ -291,7 +291,7 @@ class WorldCereal(LocalFiles):
|
|
|
291
291
|
raise ValueError(f"No AEZ files found for {self.band}")
|
|
292
292
|
|
|
293
293
|
super().__init__(
|
|
294
|
-
src_dir=tif_dir,
|
|
294
|
+
src_dir=tif_dir.absolute().as_uri(),
|
|
295
295
|
raster_item_specs=item_specs,
|
|
296
296
|
layer_type=LayerType.RASTER,
|
|
297
297
|
context=context,
|
rslearn/data_sources/worldpop.py
CHANGED
|
@@ -80,7 +80,7 @@ class WorldPop(LocalFiles):
|
|
|
80
80
|
worldpop_upath.mkdir(parents=True, exist_ok=True)
|
|
81
81
|
self.download_worldpop_data(worldpop_upath, timeout)
|
|
82
82
|
super().__init__(
|
|
83
|
-
src_dir=worldpop_upath,
|
|
83
|
+
src_dir=worldpop_upath.absolute().as_uri(),
|
|
84
84
|
layer_type=LayerType.RASTER,
|
|
85
85
|
context=context,
|
|
86
86
|
)
|
|
@@ -19,7 +19,7 @@ from rslearn.config import LayerConfig, QueryConfig
|
|
|
19
19
|
from rslearn.dataset import Window
|
|
20
20
|
from rslearn.dataset.materialize import RasterMaterializer
|
|
21
21
|
from rslearn.tile_stores import TileStore, TileStoreWithLayer
|
|
22
|
-
from rslearn.utils import PixelBounds, Projection, STGeometry
|
|
22
|
+
from rslearn.utils import PixelBounds, Projection, STGeometry, get_global_raster_bounds
|
|
23
23
|
from rslearn.utils.array import copy_spatial_array
|
|
24
24
|
from rslearn.utils.raster_format import get_transform_from_projection_and_bounds
|
|
25
25
|
|
|
@@ -184,7 +184,7 @@ class XyzTiles(DataSource, TileStore):
|
|
|
184
184
|
groups.append(cur_groups)
|
|
185
185
|
return groups
|
|
186
186
|
|
|
187
|
-
def deserialize_item(self, serialized_item:
|
|
187
|
+
def deserialize_item(self, serialized_item: dict) -> Item:
|
|
188
188
|
"""Deserializes an item from JSON-decoded data."""
|
|
189
189
|
return Item.deserialize(serialized_item)
|
|
190
190
|
|
|
@@ -278,13 +278,9 @@ class XyzTiles(DataSource, TileStore):
|
|
|
278
278
|
Returns:
|
|
279
279
|
the bounds of the raster in the projection.
|
|
280
280
|
"""
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
int(geom.shp.bounds[1]),
|
|
285
|
-
int(geom.shp.bounds[2]),
|
|
286
|
-
int(geom.shp.bounds[3]),
|
|
287
|
-
)
|
|
281
|
+
# XyzTiles is a global data source, so we return global raster bounds based on
|
|
282
|
+
# the projection.
|
|
283
|
+
return get_global_raster_bounds(projection)
|
|
288
284
|
|
|
289
285
|
def read_raster(
|
|
290
286
|
self,
|
rslearn/dataset/storage/file.py
CHANGED
|
@@ -15,7 +15,7 @@ from rslearn.dataset.window import (
|
|
|
15
15
|
get_window_layer_dir,
|
|
16
16
|
)
|
|
17
17
|
from rslearn.log_utils import get_logger
|
|
18
|
-
from rslearn.utils.fsspec import open_atomic
|
|
18
|
+
from rslearn.utils.fsspec import iter_nonhidden_subdirs, open_atomic
|
|
19
19
|
from rslearn.utils.mp import star_imap_unordered
|
|
20
20
|
|
|
21
21
|
from .storage import WindowStorage, WindowStorageFactory
|
|
@@ -77,8 +77,8 @@ class FileWindowStorage(WindowStorage):
|
|
|
77
77
|
window_dirs = []
|
|
78
78
|
if not groups:
|
|
79
79
|
groups = []
|
|
80
|
-
for
|
|
81
|
-
groups.append(
|
|
80
|
+
for group_dir in iter_nonhidden_subdirs(self.path / "windows"):
|
|
81
|
+
groups.append(group_dir.name)
|
|
82
82
|
for group in groups:
|
|
83
83
|
group_dir = self.path / "windows" / group
|
|
84
84
|
if not group_dir.exists():
|
|
@@ -86,16 +86,20 @@ class FileWindowStorage(WindowStorage):
|
|
|
86
86
|
f"Skipping group directory {group_dir} since it does not exist"
|
|
87
87
|
)
|
|
88
88
|
continue
|
|
89
|
+
if not group_dir.is_dir():
|
|
90
|
+
logger.warning(
|
|
91
|
+
f"Skipping group path {group_dir} since it is not a directory"
|
|
92
|
+
)
|
|
93
|
+
continue
|
|
89
94
|
if names:
|
|
90
|
-
|
|
95
|
+
for window_name in names:
|
|
96
|
+
window_dir = group_dir / window_name
|
|
97
|
+
if not window_dir.is_dir():
|
|
98
|
+
continue
|
|
99
|
+
window_dirs.append(window_dir)
|
|
91
100
|
else:
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
cur_names.append(p.name)
|
|
95
|
-
|
|
96
|
-
for window_name in cur_names:
|
|
97
|
-
window_dir = group_dir / window_name
|
|
98
|
-
window_dirs.append(window_dir)
|
|
101
|
+
for window_dir in iter_nonhidden_subdirs(group_dir):
|
|
102
|
+
window_dirs.append(window_dir)
|
|
99
103
|
|
|
100
104
|
if workers == 0:
|
|
101
105
|
windows = [load_window(self, window_dir) for window_dir in window_dirs]
|
|
@@ -162,7 +166,7 @@ class FileWindowStorage(WindowStorage):
|
|
|
162
166
|
return []
|
|
163
167
|
|
|
164
168
|
completed_layers = []
|
|
165
|
-
for layer_dir in layers_directory
|
|
169
|
+
for layer_dir in iter_nonhidden_subdirs(layers_directory):
|
|
166
170
|
layer_name, group_idx = get_layer_and_group_from_dir_name(layer_dir.name)
|
|
167
171
|
if not self.is_layer_completed(group, name, layer_name, group_idx):
|
|
168
172
|
continue
|
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
from typing import Any
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
|
+
from einops import rearrange
|
|
6
7
|
|
|
7
8
|
from rslearn.train.model_context import ModelContext
|
|
8
9
|
|
|
@@ -79,7 +80,11 @@ class ConcatenateFeatures(IntermediateComponent):
|
|
|
79
80
|
)
|
|
80
81
|
|
|
81
82
|
add_data = torch.stack(
|
|
82
|
-
[
|
|
83
|
+
[
|
|
84
|
+
rearrange(input_data[self.key].image, "c t h w -> (c t) h w")
|
|
85
|
+
for input_data in context.inputs
|
|
86
|
+
],
|
|
87
|
+
dim=0,
|
|
83
88
|
)
|
|
84
89
|
add_features = self.conv_layers(add_data)
|
|
85
90
|
|
rslearn/tile_stores/default.py
CHANGED
|
@@ -15,6 +15,8 @@ from upath import UPath
|
|
|
15
15
|
from rslearn.const import WGS84_PROJECTION
|
|
16
16
|
from rslearn.utils.feature import Feature
|
|
17
17
|
from rslearn.utils.fsspec import (
|
|
18
|
+
iter_nonhidden_files,
|
|
19
|
+
iter_nonhidden_subdirs,
|
|
18
20
|
join_upath,
|
|
19
21
|
open_atomic,
|
|
20
22
|
open_rasterio_upath_reader,
|
|
@@ -129,7 +131,7 @@ class DefaultTileStore(TileStore):
|
|
|
129
131
|
ValueError: if no file is found.
|
|
130
132
|
"""
|
|
131
133
|
raster_dir = self._get_raster_dir(layer_name, item_name, bands)
|
|
132
|
-
for fname in raster_dir
|
|
134
|
+
for fname in iter_nonhidden_files(raster_dir):
|
|
133
135
|
# Ignore completed sentinel files, bands files, as well as temporary files created by
|
|
134
136
|
# open_atomic (in case this tile store is on local filesystem).
|
|
135
137
|
if fname.name == COMPLETED_FNAME:
|
|
@@ -175,7 +177,7 @@ class DefaultTileStore(TileStore):
|
|
|
175
177
|
return []
|
|
176
178
|
|
|
177
179
|
bands: list[list[str]] = []
|
|
178
|
-
for raster_dir in item_dir
|
|
180
|
+
for raster_dir in iter_nonhidden_subdirs(item_dir):
|
|
179
181
|
if not (raster_dir / BANDS_FNAME).exists():
|
|
180
182
|
# This is likely a legacy directory where the bands are only encoded in
|
|
181
183
|
# the directory name, so we have to rely on that.
|