rslearn 0.0.20__py3-none-any.whl → 0.0.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/data_sources/aws_open_data.py +11 -15
- rslearn/data_sources/aws_sentinel2_element84.py +374 -0
- rslearn/data_sources/climate_data_store.py +216 -29
- rslearn/data_sources/gcp_public_data.py +16 -0
- rslearn/data_sources/planetary_computer.py +28 -257
- rslearn/data_sources/soilgrids.py +331 -0
- rslearn/data_sources/stac.py +255 -0
- rslearn/models/attention_pooling.py +5 -2
- rslearn/models/olmoearth_pretrain/model.py +7 -8
- rslearn/train/dataset.py +44 -35
- rslearn/train/lightning_module.py +3 -3
- rslearn/train/tasks/embedding.py +2 -2
- rslearn/train/tasks/task.py +4 -2
- rslearn/utils/geometry.py +2 -2
- rslearn/utils/stac.py +173 -0
- {rslearn-0.0.20.dist-info → rslearn-0.0.22.dist-info}/METADATA +4 -1
- {rslearn-0.0.20.dist-info → rslearn-0.0.22.dist-info}/RECORD +22 -18
- {rslearn-0.0.20.dist-info → rslearn-0.0.22.dist-info}/WHEEL +0 -0
- {rslearn-0.0.20.dist-info → rslearn-0.0.22.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.20.dist-info → rslearn-0.0.22.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.20.dist-info → rslearn-0.0.22.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.20.dist-info → rslearn-0.0.22.dist-info}/top_level.txt +0 -0
rslearn/data_sources/stac.py
ADDED

@@ -0,0 +1,255 @@
+"""A partial data source implementation providing get_items using a STAC API."""
+
+import json
+from typing import Any
+
+import shapely
+from upath import UPath
+
+from rslearn.config import QueryConfig
+from rslearn.const import WGS84_PROJECTION
+from rslearn.data_sources.data_source import Item, ItemLookupDataSource
+from rslearn.data_sources.utils import match_candidate_items_to_window
+from rslearn.log_utils import get_logger
+from rslearn.utils.geometry import STGeometry
+from rslearn.utils.stac import StacClient, StacItem
+
+logger = get_logger(__name__)
+
+
+class SourceItem(Item):
+    """An item in the StacDataSource data source."""
+
+    def __init__(
+        self,
+        name: str,
+        geometry: STGeometry,
+        asset_urls: dict[str, str],
+        properties: dict[str, str],
+    ):
+        """Creates a new SourceItem.
+
+        Args:
+            name: unique name of the item
+            geometry: the spatial and temporal extent of the item
+            asset_urls: map from asset key to the unsigned asset URL.
+            properties: properties requested by the data source implementation.
+        """
+        super().__init__(name, geometry)
+        self.asset_urls = asset_urls
+        self.properties = properties
+
+    def serialize(self) -> dict[str, Any]:
+        """Serializes the item to a JSON-encodable dictionary."""
+        d = super().serialize()
+        d["asset_urls"] = self.asset_urls
+        d["properties"] = self.properties
+        return d
+
+    @staticmethod
+    def deserialize(d: dict[str, Any]) -> "SourceItem":
+        """Deserializes an item from a JSON-decoded dictionary."""
+        item = super(SourceItem, SourceItem).deserialize(d)
+        return SourceItem(
+            name=item.name,
+            geometry=item.geometry,
+            asset_urls=d["asset_urls"],
+            properties=d["properties"],
+        )
+
+
+class StacDataSource(ItemLookupDataSource[SourceItem]):
+    """A partial data source implementing get_items using a STAC API.
+
+    This is a helper class that full implementations can extend to not have to worry
+    about the get_items and get_item_by_name implementation.
+    """
+
+    def __init__(
+        self,
+        endpoint: str,
+        collection_name: str,
+        query: dict[str, Any] | None = None,
+        sort_by: str | None = None,
+        sort_ascending: bool = True,
+        required_assets: list[str] | None = None,
+        cache_dir: UPath | None = None,
+        limit: int = 100,
+        properties_to_record: list[str] = [],
+    ):
+        """Create a new StacDataSource.
+
+        Args:
+            endpoint: the STAC endpoint to use.
+            collection_name: the STAC collection name.
+            query: optional STAC query dict to include in searches, e.g. {"eo:cloud_cover": {"lt": 50}}.
+            sort_by: sort results by this STAC property.
+            sort_ascending: if sort_by is set, sort in ascending order (default).
+                Otherwise sort in descending order.
+            required_assets: if set, we ignore items that do not have all of these
+                asset keys.
+            cache_dir: optional cache directory to cache items. This is recommended if
+                allowing direct materialization from the data source, since it will
+                likely be necessary to make lots of get_item_by_name calls during
+                materialization. TODO: give direct materialization access to the Item
+                object.
+            limit: limit to pass to search queries.
+            properties_to_record: if these properties on the STAC item exist, they are
+                are retained in the SourceItem when we initialize it.
+        """
+        self.client = StacClient(endpoint)
+        self.collection_name = collection_name
+        self.query = query
+        self.sort_by = sort_by
+        self.sort_ascending = sort_ascending
+        self.required_assets = required_assets
+        self.cache_dir = cache_dir
+        self.limit = limit
+        self.properties_to_record = properties_to_record
+
+    def _stac_item_to_item(self, stac_item: StacItem) -> SourceItem:
+        # Make sure geometry, time range, and assets are set.
+        if stac_item.geometry is None:
+            raise ValueError("got unexpected item with no geometry")
+        if stac_item.time_range is None:
+            raise ValueError("got unexpected item with no time range")
+        if stac_item.assets is None:
+            raise ValueError("got unexpected item with no assets")
+
+        shp = shapely.geometry.shape(stac_item.geometry)
+        geom = STGeometry(WGS84_PROJECTION, shp, stac_item.time_range)
+        asset_urls = {
+            asset_key: asset_obj.href
+            for asset_key, asset_obj in stac_item.assets.items()
+        }
+
+        # Keep any properties requested by the data source implementation.
+        properties = {}
+        for prop_name in self.properties_to_record:
+            if prop_name not in stac_item.properties:
+                continue
+            properties[prop_name] = stac_item.properties[prop_name]
+
+        return SourceItem(stac_item.id, geom, asset_urls, properties)
+
+    def get_item_by_name(self, name: str) -> SourceItem:
+        """Gets an item by name.
+
+        Args:
+            name: the name of the item to get
+
+        Returns:
+            the item object
+        """
+        # If cache_dir is set, we cache the item. First here we check if it is already
+        # in the cache.
+        cache_fname: UPath | None = None
+        if self.cache_dir:
+            cache_fname = self.cache_dir / f"{name}.json"
+        if cache_fname is not None and cache_fname.exists():
+            with cache_fname.open() as f:
+                return SourceItem.deserialize(json.load(f))
+
+        # No cache or not in cache, so we need to make the STAC request.
+        logger.debug(f"Getting STAC item {name}")
+        stac_items = self.client.search(ids=[name], collections=[self.collection_name])
+
+        if len(stac_items) == 0:
+            raise ValueError(
+                f"Item {name} not found in collection {self.collection_name}"
+            )
+        if len(stac_items) > 1:
+            raise ValueError(
+                f"Multiple items found for ID {name} in collection {self.collection_name}"
+            )
+
+        stac_item = stac_items[0]
+        item = self._stac_item_to_item(stac_item)
+
+        # Finally we cache it if cache_dir is set.
+        if cache_fname is not None:
+            with cache_fname.open("w") as f:
+                json.dump(item.serialize(), f)
+
+        return item
+
+    def get_items(
+        self, geometries: list[STGeometry], query_config: QueryConfig
+    ) -> list[list[list[SourceItem]]]:
+        """Get a list of items in the data source intersecting the given geometries.
+
+        Args:
+            geometries: the spatiotemporal geometries
+            query_config: the query configuration
+
+        Returns:
+            List of groups of items that should be retrieved for each geometry.
+        """
+        groups = []
+        for geometry in geometries:
+            # Get potentially relevant items from the collection by performing one search
+            # for each requested geometry.
+            wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
+            logger.debug("performing STAC search for geometry %s", wgs84_geometry)
+            stac_items = self.client.search(
+                collections=[self.collection_name],
+                intersects=json.loads(shapely.to_geojson(wgs84_geometry.shp)),
+                date_time=wgs84_geometry.time_range,
+                query=self.query,
+                limit=self.limit,
+            )
+            logger.debug("STAC search yielded %d items", len(stac_items))
+
+            if self.required_assets is not None:
+                # Filter out items that are missing any of the assets in self.asset_bands.
+                good_stac_items = []
+                for stac_item in stac_items:
+                    if stac_item.assets is None:
+                        raise ValueError(f"got STAC item {stac_item.id} with no assets")
+
+                    good = True
+                    for asset_key in self.required_assets:
+                        if asset_key in stac_item.assets:
+                            continue
+                        good = False
+                        break
+                    if good:
+                        good_stac_items.append(stac_item)
+                logger.debug(
+                    "required_assets filter from %d to %d items",
+                    len(stac_items),
+                    len(good_stac_items),
+                )
+                stac_items = good_stac_items
+
+            if self.sort_by is not None:
+                sort_by = self.sort_by
+                stac_items.sort(
+                    key=lambda stac_item: stac_item.properties[sort_by],
+                    reverse=not self.sort_ascending,
+                )
+
+            candidate_items = [
+                self._stac_item_to_item(stac_item) for stac_item in stac_items
+            ]
+
+            # Since we made the STAC request, might as well save these to the cache.
+            if self.cache_dir is not None:
+                for item in candidate_items:
+                    cache_fname = self.cache_dir / f"{item.name}.json"
+                    if cache_fname.exists():
+                        continue
+                    with cache_fname.open("w") as f:
+                        json.dump(item.serialize(), f)
+
+            cur_groups = match_candidate_items_to_window(
+                geometry, candidate_items, query_config
+            )
+            groups.append(cur_groups)
+
+        return groups
+
+    def deserialize_item(self, serialized_item: Any) -> SourceItem:
+        """Deserializes an item from JSON-decoded data."""
+        assert isinstance(serialized_item, dict)
+        return SourceItem.deserialize(serialized_item)
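A minimal usage sketch of the new class; the endpoint, collection, asset keys, and item ID below are illustrative assumptions, not values from this diff:

    from rslearn.data_sources.stac import StacDataSource

    # Prefer the least cloudy scenes and require the RGB assets to be present.
    data_source = StacDataSource(
        endpoint="https://earth-search.aws.element84.com/v1",
        collection_name="sentinel-2-l2a",
        query={"eo:cloud_cover": {"lt": 50}},
        sort_by="eo:cloud_cover",
        sort_ascending=True,
        required_assets=["red", "green", "blue"],
        properties_to_record=["eo:cloud_cover"],
    )
    # get_item_by_name fetches (and, if cache_dir is set, caches) a single item.
    item = data_source.get_item_by_name("S2B_13SDV_20240601_0_L2A")  # hypothetical ID
    print(item.asset_urls["red"], item.properties)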
rslearn/models/attention_pooling.py
CHANGED

@@ -150,8 +150,11 @@ class AttentionPool(IntermediateComponent):
             D // self.num_heads
         )
         attn_weights = F.softmax(attn_scores, dim=-1)
-        x = torch.matmul(attn_weights, v)  # [B,
-
+        x = torch.matmul(attn_weights, v)  # [B*H*W, num_heads, 1, D_head]
+        x = x.squeeze(-2)  # [B*H*W, num_heads, D_head]
+        return rearrange(
+            x, "(b h w) nh dh -> b (nh dh) h w", b=B, h=H, w=W
+        )  # [B, D, H, W]

     def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
         """Forward pass for attention pooling linear probe.
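A standalone sketch of the shape bookkeeping the new lines perform, using the same einops rearrange pattern as the hunk (the dimension sizes are assumed for illustration):

    import torch
    from einops import rearrange

    B, H, W, num_heads, D_head = 2, 4, 4, 8, 16
    x = torch.randn(B * H * W, num_heads, 1, D_head)  # attention output per spatial position
    x = x.squeeze(-2)  # [B*H*W, num_heads, D_head]
    out = rearrange(x, "(b h w) nh dh -> b (nh dh) h w", b=B, h=H, w=W)
    assert out.shape == (B, num_heads * D_head, H, W)  # [B, D, H, W]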
rslearn/models/olmoearth_pretrain/model.py
CHANGED

@@ -159,20 +159,19 @@ class OlmoEarth(FeatureExtractor):
     that contains the distributed checkpoint. This is the format produced by
     pre-training runs in olmoearth_pretrain.
     """
-    # We avoid loading the train module here because it depends on running within
-    # olmo_core.
-    # Only pull in olmo_core when trying to load a distributed checkpoint to avoid dependency.
-    require_olmo_core("_load_model_from_checkpoint")
-    from olmo_core.distributed.checkpoint import load_model_and_optim_state
-
     with (checkpoint_upath / "config.json").open() as f:
         config_dict = json.load(f)
     model_config = Config.from_dict(config_dict["model"])

     model = model_config.build()

-    # Load the checkpoint.
+    # Load the checkpoint (requires olmo_core for distributed checkpoint loading).
     if not random_initialization:
+        require_olmo_core(
+            "_load_model_from_checkpoint with random_initialization=False"
+        )
+        from olmo_core.distributed.checkpoint import load_model_and_optim_state
+
         train_module_dir = checkpoint_upath / "model_and_optim"
         load_model_and_optim_state(str(train_module_dir), model)
         logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")

@@ -242,7 +241,7 @@ class OlmoEarth(FeatureExtractor):
             present_modalities.append(modality)
             tensors = []
             for idx, inp in enumerate(context.inputs):
-                assert isinstance(inp, RasterImage)
+                assert isinstance(inp[modality], RasterImage)
                 tensors.append(inp[modality].image)
                 cur_timestamps = inp[modality].timestamps
                 if cur_timestamps is not None and len(cur_timestamps) > len(
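The first hunk moves the olmo_core import inside the random_initialization check, so the optional dependency is only needed when a distributed checkpoint is actually loaded. A minimal sketch of that deferred-import pattern (a hypothetical helper, not code from the diff):

    def load_pretrained(model, checkpoint_dir: str, random_initialization: bool) -> None:
        if random_initialization:
            return  # weights stay random; olmo_core is never imported
        # Only now pull in the optional dependency.
        from olmo_core.distributed.checkpoint import load_model_and_optim_state
        load_model_and_optim_state(f"{checkpoint_dir}/model_and_optim", model)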
|
rslearn/train/dataset.py
CHANGED
|
@@ -205,8 +205,7 @@ def read_raster_layer_for_data_input(
|
|
|
205
205
|
group_idx: int,
|
|
206
206
|
layer_config: LayerConfig,
|
|
207
207
|
data_input: DataInput,
|
|
208
|
-
|
|
209
|
-
) -> tuple[torch.Tensor, tuple[datetime, datetime] | None]:
|
|
208
|
+
) -> torch.Tensor:
|
|
210
209
|
"""Read a raster layer for a DataInput.
|
|
211
210
|
|
|
212
211
|
This scans the available rasters for the layer at the window to determine which
|
|
@@ -219,11 +218,9 @@ def read_raster_layer_for_data_input(
|
|
|
219
218
|
group_idx: the item group.
|
|
220
219
|
layer_config: the layer configuration.
|
|
221
220
|
data_input: the DataInput that specifies the bands and dtype.
|
|
222
|
-
layer_data: the WindowLayerData associated with this layer and window.
|
|
223
221
|
|
|
224
222
|
Returns:
|
|
225
|
-
|
|
226
|
-
with that data.
|
|
223
|
+
Raster data as a tensor.
|
|
227
224
|
"""
|
|
228
225
|
# See what different sets of bands we need to read to get all the
|
|
229
226
|
# configured bands.
|
|
@@ -294,34 +291,46 @@ def read_raster_layer_for_data_input(
|
|
|
294
291
|
src[src_indexes, :, :].astype(data_input.dtype.get_numpy_dtype())
|
|
295
292
|
)
|
|
296
293
|
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
294
|
+
return image
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def read_layer_time_range(
|
|
298
|
+
layer_data: WindowLayerData | None, group_idx: int
|
|
299
|
+
) -> tuple[datetime, datetime] | None:
|
|
300
|
+
"""Extract the combined time range from all items in a layer data group.
|
|
301
|
+
|
|
302
|
+
Returns the min start time and max end time across all items, or None if
|
|
303
|
+
no items have time ranges.
|
|
304
|
+
|
|
305
|
+
Raises:
|
|
306
|
+
ValueError: If some items have time_range and others don't.
|
|
307
|
+
"""
|
|
308
|
+
if layer_data is None:
|
|
309
|
+
return None
|
|
310
|
+
|
|
311
|
+
serialized_items = layer_data.serialized_item_groups[group_idx]
|
|
312
|
+
if not serialized_items:
|
|
313
|
+
return None
|
|
314
|
+
|
|
315
|
+
first_item = Item.deserialize(serialized_items[0])
|
|
316
|
+
if first_item.geometry.time_range is None:
|
|
317
|
+
return None
|
|
318
|
+
|
|
319
|
+
# If the first item has a time_range, all items must have one
|
|
320
|
+
time_ranges: list[tuple[datetime, datetime]] = []
|
|
321
|
+
for serialized_item in serialized_items:
|
|
322
|
+
item = Item.deserialize(serialized_item)
|
|
323
|
+
if item.geometry.time_range is None:
|
|
324
|
+
raise ValueError(
|
|
325
|
+
f"Item '{item.name}' has no time_range, but first item does. "
|
|
326
|
+
"All items in a group must consistently have or lack time_range."
|
|
322
327
|
)
|
|
328
|
+
time_ranges.append(item.geometry.time_range)
|
|
323
329
|
|
|
324
|
-
return
|
|
330
|
+
return (
|
|
331
|
+
min(tr[0] for tr in time_ranges),
|
|
332
|
+
max(tr[1] for tr in time_ranges),
|
|
333
|
+
)
|
|
325
334
|
|
|
326
335
|
|
|
327
336
|
def read_data_input(
|
|
@@ -378,17 +387,17 @@ def read_data_input(
|
|
|
378
387
|
time_ranges: list[tuple[datetime, datetime] | None] = []
|
|
379
388
|
for layer_name, group_idx in layers_to_read:
|
|
380
389
|
layer_config = dataset.layers[layer_name]
|
|
381
|
-
image
|
|
390
|
+
image = read_raster_layer_for_data_input(
|
|
382
391
|
window,
|
|
383
392
|
bounds,
|
|
384
393
|
layer_name,
|
|
385
394
|
group_idx,
|
|
386
395
|
layer_config,
|
|
387
396
|
data_input,
|
|
388
|
-
# some layers (e.g. "label_raster") won't have associated
|
|
389
|
-
# layer datas
|
|
390
|
-
layer_datas[layer_name] if layer_name in layer_datas else None,
|
|
391
397
|
)
|
|
398
|
+
# some layers (e.g. "label_raster") won't have associated layer datas
|
|
399
|
+
layer_data = layer_datas.get(layer_name)
|
|
400
|
+
time_range = read_layer_time_range(layer_data, group_idx)
|
|
392
401
|
if len(time_ranges) > 0:
|
|
393
402
|
if type(time_ranges[-1]) is not type(time_range):
|
|
394
403
|
raise ValueError(
|
|
rslearn/train/lightning_module.py
CHANGED

@@ -345,14 +345,14 @@ class RslearnLightningModule(L.LightningModule):
         )

         if self.visualize_dir:
-            for
-
+            for inp, target, output, metadata in zip(
+                inputs, targets, outputs, metadatas
             ):
                 images = self.task.visualize(inp, target, output)
                 for image_suffix, image in images.items():
                     out_fname = os.path.join(
                         self.visualize_dir,
-                        f"{metadata
+                        f"{metadata.window_name}_{metadata.patch_bounds[0]}_{metadata.patch_bounds[1]}_{image_suffix}.png",
                     )
                     Image.fromarray(image).save(out_fname)

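For example, with a window_name of "seattle_0", patch_bounds starting at (512, 768), and image suffix "gt" (illustrative values), the visualization would be saved under visualize_dir as seattle_0_512_768_gt.png.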
rslearn/train/tasks/embedding.py
CHANGED

@@ -6,7 +6,7 @@ import numpy.typing as npt
 import torch
 from torchmetrics import MetricCollection

-from rslearn.models.component import FeatureMaps
+from rslearn.models.component import FeatureMaps, Predictor
 from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
 from rslearn.utils import Feature

@@ -83,7 +83,7 @@ class EmbeddingTask(Task):
         return MetricCollection({})


-class EmbeddingHead:
+class EmbeddingHead(Predictor):
     """Head for embedding task.

     It just adds a dummy loss to act as a Predictor.
rslearn/train/tasks/task.py
CHANGED

@@ -108,8 +108,10 @@ class BasicTask(Task):
         Returns:
             a dictionary mapping image name to visualization image
         """
-
-
+        raster_image = input_dict["image"]
+        assert isinstance(raster_image, RasterImage)
+        # We don't really handle time series here, just use the first timestep.
+        image = raster_image.image.cpu()[self.image_bands, 0, :, :]
         if self.remap_values:
             factor = (self.remap_values[1][1] - self.remap_values[1][0]) / (
                 self.remap_values[0][1] - self.remap_values[0][0]
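The factor in this hunk is the scale of a linear remap from the configured source range to the display range. A worked sketch with an assumed remap_values of [[0.0, 0.3], [0, 255]] (source range, then target range):

    remap_values = [[0.0, 0.3], [0, 255]]
    factor = (remap_values[1][1] - remap_values[1][0]) / (
        remap_values[0][1] - remap_values[0][0]
    )
    print(factor)  # 850.0: an input value of 0.3 scales to 255 in the visualization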
rslearn/utils/geometry.py
CHANGED

@@ -153,8 +153,8 @@ class ResolutionFactor:
         else:
             return Projection(
                 projection.crs,
-                projection.x_resolution
-                projection.y_resolution
+                projection.x_resolution / self.numerator,
+                projection.y_resolution / self.numerator,
             )

     def multiply_bounds(self, bounds: PixelBounds) -> PixelBounds:
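A quick check of the corrected arithmetic (the base resolution is assumed for illustration): dividing a 10 m/pixel projection by numerator=2 yields 5 m/pixel, i.e. twice the pixel density:

    x_resolution, y_resolution = 10.0, -10.0  # assumed base resolution
    numerator = 2
    print(x_resolution / numerator, y_resolution / numerator)  # 5.0 -5.0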
rslearn/utils/stac.py
ADDED

@@ -0,0 +1,173 @@
+"""STAC API client."""
+
+import logging
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Any
+
+import requests
+
+logger = logging.getLogger(__name__)
+
+Bbox = tuple[float, float, float, float]
+
+
+@dataclass(frozen=True)
+class StacAsset:
+    """A STAC asset."""
+
+    href: str
+    title: str | None
+    type: str | None
+    roles: list[str] | None
+
+
+@dataclass(frozen=True)
+class StacItem:
+    """A STAC item."""
+
+    id: str
+    properties: dict[str, Any]
+    collection: str | None
+    bbox: Bbox | None
+    geometry: dict[str, Any] | None
+    assets: dict[str, StacAsset] | None
+    time_range: tuple[datetime, datetime] | None
+
+    @classmethod
+    def from_dict(cls, item: dict[str, Any]) -> "StacItem":
+        """Create a STAC item from the item dict returned from API."""
+        properties = item.get("properties", {})
+
+        # Parse bbox.
+        bbox: Bbox | None = None
+        if "bbox" in item:
+            if len(item["bbox"]) != 4:
+                raise NotImplementedError(
+                    f"got bbox with {len(item['bbox'])} coordinates but only 4 coordinates is implemented"
+                )
+            bbox = tuple(item["bbox"])
+
+        # Parse assets.
+        assets: dict[str, StacAsset] = {}
+        for name, asset in item.get("assets", {}).items():
+            assets[name] = StacAsset(
+                href=asset["href"],
+                title=asset.get("title"),
+                type=asset.get("type"),
+                roles=asset.get("roles"),
+            )
+
+        # Parse time range.
+        time_range: tuple[datetime, datetime] | None = None
+        if "start_datetime" in properties and "end_datetime" in properties:
+            time_range = (
+                datetime.fromisoformat(properties["start_datetime"]),
+                datetime.fromisoformat(properties["end_datetime"]),
+            )
+        elif "datetime" in properties:
+            ts = datetime.fromisoformat(properties["datetime"])
+            time_range = (ts, ts)
+
+        return cls(
+            id=item["id"],
+            properties=properties,
+            collection=item.get("collection"),
+            bbox=bbox,
+            geometry=item.get("geometry"),
+            assets=assets,
+            time_range=time_range,
+        )
+
+
+class StacClient:
+    """Limited functionality client for STAC APIs."""
+
+    def __init__(self, endpoint: str):
+        """Create a new StacClient.
+
+        Args:
+            endpoint: the STAC endpoint (base URL)
+        """
+        self.endpoint = endpoint
+        self.session = requests.Session()
+
+    def search(
+        self,
+        collections: list[str] | None = None,
+        bbox: Bbox | None = None,
+        intersects: dict[str, Any] | None = None,
+        date_time: datetime | tuple[datetime, datetime] | None = None,
+        ids: list[str] | None = None,
+        limit: int | None = None,
+        query: dict[str, Any] | None = None,
+    ) -> list[StacItem]:
+        """Execute a STAC item search.
+
+        We use the JSON POST API. Pagination is handled so the returned items are
+        concatenated across all available pages.
+
+        Args:
+            collections: only search within the provided collection(s).
+            bbox: only return features intersecting the provided bounding box.
+            intersects: only return features intersecting this GeoJSON geometry.
+            date_time: only return features that have a temporal property intersecting
+                the provided time range or timestamp.
+            ids: only return the provided item IDs.
+            limit: number of items per page. We will read all the pages.
+            query: query dict, if STAC query extension is supported by this API. See
+                https://github.com/stac-api-extensions/query.
+
+        Returns:
+            list of matching STAC items.
+        """
+        # Build JSON request data.
+        request_data: dict[str, Any] = {}
+        if collections is not None:
+            request_data["collections"] = collections
+        if bbox is not None:
+            request_data["bbox"] = bbox
+        if intersects is not None:
+            request_data["intersects"] = intersects
+        if date_time is not None:
+            if isinstance(date_time, tuple):
+                start_time = date_time[0].isoformat().replace("+00:00", "Z")
+                end_time = date_time[1].isoformat().replace("+00:00", "Z")
+                request_data["datetime"] = f"{start_time}/{end_time}"
+            else:
+                request_data["datetime"] = date_time.isoformat().replace("+00:00", "Z")
+        if ids is not None:
+            request_data["ids"] = ids
+        if limit is not None:
+            request_data["limit"] = limit
+        if query is not None:
+            request_data["query"] = query
+
+        # Handle pagination.
+        cur_url = self.endpoint + "/search"
+        items: list[StacItem] = []
+        while True:
+            logger.debug("Reading STAC items from %s", cur_url)
+            response = self.session.post(url=cur_url, json=request_data)
+            response.raise_for_status()
+            data = response.json()
+            for item_dict in data["features"]:
+                items.append(StacItem.from_dict(item_dict))
+
+            next_link = None
+            next_request_data: dict[str, Any] = {}
+            for link in data.get("links", []):
+                if "rel" not in link or link["rel"] != "next":
+                    continue
+                assert link["method"] == "POST"
+                next_link = link["href"]
+                next_request_data = link["body"]
+                break
+
+            if next_link is None:
+                break
+
+            cur_url = next_link
+            request_data = next_request_data
+
+        return items
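A minimal usage sketch of the new client; the endpoint, collection, bounding box, and cloud-cover property below are illustrative assumptions, not values from this diff:

    from datetime import datetime, timezone
    from rslearn.utils.stac import StacClient

    client = StacClient("https://earth-search.aws.element84.com/v1")
    items = client.search(
        collections=["sentinel-2-l2a"],
        bbox=(-122.5, 47.4, -122.2, 47.7),
        date_time=(
            datetime(2024, 6, 1, tzinfo=timezone.utc),
            datetime(2024, 7, 1, tzinfo=timezone.utc),
        ),
        query={"eo:cloud_cover": {"lt": 20}},  # requires the STAC query extension
        limit=100,  # page size; search() follows "next" links across all pages
    )
    for item in items:
        print(item.id, item.time_range, sorted(item.assets or {}))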