rslearn 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/data_sources/planetary_computer.py +149 -1
- rslearn/data_sources/stac.py +24 -3
- rslearn/main.py +4 -1
- rslearn/models/simple_time_series.py +1 -1
- rslearn/train/lightning_module.py +21 -8
- rslearn/train/tasks/multi_task.py +8 -5
- rslearn/train/tasks/per_pixel_regression.py +1 -1
- rslearn/train/tasks/segmentation.py +163 -22
- rslearn/utils/raster_format.py +17 -0
- rslearn/utils/stac.py +4 -0
- {rslearn-0.0.22.dist-info → rslearn-0.0.24.dist-info}/METADATA +1 -1
- {rslearn-0.0.22.dist-info → rslearn-0.0.24.dist-info}/RECORD +17 -17
- {rslearn-0.0.22.dist-info → rslearn-0.0.24.dist-info}/WHEEL +1 -1
- {rslearn-0.0.22.dist-info → rslearn-0.0.24.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.22.dist-info → rslearn-0.0.24.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.22.dist-info → rslearn-0.0.24.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.22.dist-info → rslearn-0.0.24.dist-info}/top_level.txt +0 -0
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
import os
|
|
4
4
|
import tempfile
|
|
5
5
|
import xml.etree.ElementTree as ET
|
|
6
|
-
from datetime import timedelta
|
|
6
|
+
from datetime import datetime, timedelta
|
|
7
7
|
from typing import Any
|
|
8
8
|
|
|
9
9
|
import affine
|
|
@@ -12,6 +12,7 @@ import planetary_computer
|
|
|
12
12
|
import rasterio
|
|
13
13
|
import requests
|
|
14
14
|
from rasterio.enums import Resampling
|
|
15
|
+
from typing_extensions import override
|
|
15
16
|
from upath import UPath
|
|
16
17
|
|
|
17
18
|
from rslearn.config import LayerConfig
|
|
@@ -24,11 +25,104 @@ from rslearn.tile_stores import TileStore, TileStoreWithLayer
|
|
|
24
25
|
from rslearn.utils.fsspec import join_upath
|
|
25
26
|
from rslearn.utils.geometry import PixelBounds, Projection, STGeometry
|
|
26
27
|
from rslearn.utils.raster_format import get_raster_projection_and_bounds
|
|
28
|
+
from rslearn.utils.stac import StacClient, StacItem
|
|
27
29
|
|
|
28
30
|
from .copernicus import get_harmonize_callback
|
|
29
31
|
|
|
30
32
|
logger = get_logger(__name__)
|
|
31
33
|
|
|
34
|
+
# Max limit accepted by Planetary Computer API.
PLANETARY_COMPUTER_LIMIT = 1000


class PlanetaryComputerStacClient(StacClient):
    """A StacClient subclass that handles Planetary Computer's pagination limits.

    Planetary Computer STAC API does not support standard pagination and has a max
    limit of 1000. If the initial query returns 1000 items, this client paginates
    by sorting by ID and using gt (greater than) queries to fetch subsequent pages.

    Every request is issued with ``limit=PLANETARY_COMPUTER_LIMIT``; the ``limit``
    argument of ``search`` is accepted for interface compatibility but is not
    forwarded to the underlying requests.
    """

    @override
    def search(
        self,
        collections: list[str] | None = None,
        bbox: tuple[float, float, float, float] | None = None,
        intersects: dict[str, Any] | None = None,
        date_time: datetime | tuple[datetime, datetime] | None = None,
        ids: list[str] | None = None,
        limit: int | None = None,
        query: dict[str, Any] | None = None,
        sortby: list[dict[str, str]] | None = None,
    ) -> list[StacItem]:
        """Search the STAC API, paginating by item ID when the PC limit is reached.

        Args:
            collections: collections to search.
            bbox: optional bounding box filter.
            intersects: optional GeoJSON geometry filter.
            date_time: optional timestamp or (start, end) filter.
            ids: optional item IDs to restrict the search to.
            limit: ignored; see the class docstring.
            query: optional STAC query-extension filter. It is combined with an
                ``id > last_id`` constraint while paginating.
            sortby: must be None, since sorting is reserved for ID pagination.

        Returns:
            all matching STAC items.

        Raises:
            ValueError: if sortby is set.
            RuntimeError: if an ID-sorted page fails to advance past the previous
                page's last ID. Without this check the loop would never end.
        """
        # We will use sortby for pagination, so the caller must not set it.
        if sortby is not None:
            raise ValueError("sortby must not be set for PlanetaryComputerStacClient")

        # First, try a simple query with the PC limit to detect if pagination is needed.
        # We always use PLANETARY_COMPUTER_LIMIT for the request because PC doesn't
        # support standard pagination, and we need to detect when we hit the limit
        # to switch to ID-based pagination.
        # We could just start sorting by ID here and do pagination, but we treat it
        # as a special case to avoid sorting since that seems to speed up the query.
        stac_items = super().search(
            collections=collections,
            bbox=bbox,
            intersects=intersects,
            date_time=date_time,
            ids=ids,
            limit=PLANETARY_COMPUTER_LIMIT,
            query=query,
        )

        # If we got fewer than the PC limit, we have all the results.
        if len(stac_items) < PLANETARY_COMPUTER_LIMIT:
            return stac_items

        # We hit the limit, so we need to paginate by ID.
        # Re-fetch with sorting by ID to ensure consistent ordering for pagination.
        logger.debug(
            "Initial request returned %d items (at limit), switching to ID pagination",
            len(stac_items),
        )

        all_items: list[StacItem] = []
        last_id: str | None = None

        while True:
            stac_items = super().search(
                collections=collections,
                bbox=bbox,
                intersects=intersects,
                date_time=date_time,
                ids=ids,
                limit=PLANETARY_COMPUTER_LIMIT,
                query=self._build_page_query(query, last_id),
                sortby=[{"field": "id", "direction": "asc"}],
            )

            all_items.extend(stac_items)

            # If we got fewer than the limit, we've fetched everything.
            if len(stac_items) < PLANETARY_COMPUTER_LIMIT:
                break

            # Otherwise, paginate using the last item's ID. If the page did not move
            # past the previous cursor, the server is not honoring the id filter.
            # Re-issuing the same request would then loop forever.
            next_last_id = stac_items[-1].id
            if next_last_id == last_id:
                raise RuntimeError(
                    f"ID pagination did not advance past {last_id!r}; the STAC API "
                    "may be ignoring the id filter"
                )
            last_id = next_last_id
            logger.debug(
                "Got %d items, paginating with id > %s",
                len(stac_items),
                last_id,
            )

        logger.debug("Total items fetched: %d", len(all_items))
        return all_items

    @staticmethod
    def _build_page_query(
        query: dict[str, Any] | None, last_id: str | None
    ) -> dict[str, Any] | None:
        """Return the caller's query with an ``id > last_id`` constraint merged in.

        Args:
            query: the caller's query-extension filter (not mutated).
            last_id: ID of the last item from the previous page, or None for the
                first page.

        Returns:
            the combined query, or None if it would be empty.
        """
        combined: dict[str, Any] = dict(query) if query else {}
        if last_id is not None:
            # Merge with any operators the caller already set on "id" rather than
            # replacing them, so caller-level ID constraints still apply.
            combined["id"] = {**combined.get("id", {}), "gt": last_id}
        return combined if combined else None
|
|
125
|
+
|
|
32
126
|
|
|
33
127
|
class PlanetaryComputer(StacDataSource, TileStore):
|
|
34
128
|
"""Modality-agnostic data source for data on Microsoft Planetary Computer.
|
|
@@ -100,6 +194,10 @@ class PlanetaryComputer(StacDataSource, TileStore):
|
|
|
100
194
|
required_assets=required_assets,
|
|
101
195
|
cache_dir=cache_upath,
|
|
102
196
|
)
|
|
197
|
+
|
|
198
|
+
# Replace the client with PlanetaryComputerStacClient to handle PC's pagination limits.
|
|
199
|
+
self.client = PlanetaryComputerStacClient(self.STAC_ENDPOINT)
|
|
200
|
+
|
|
103
201
|
self.asset_bands = asset_bands
|
|
104
202
|
self.timeout = timeout
|
|
105
203
|
self.skip_items_missing_assets = skip_items_missing_assets
|
|
@@ -567,3 +665,53 @@ class Naip(PlanetaryComputer):
|
|
|
567
665
|
context=context,
|
|
568
666
|
**kwargs,
|
|
569
667
|
)
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
class CopDemGlo30(PlanetaryComputer):
    """Copernicus DEM GLO-30 (30 m) data source backed by Microsoft Planetary Computer.

    See https://planetarycomputer.microsoft.com/dataset/cop-dem-glo-30.
    """

    COLLECTION_NAME = "cop-dem-glo-30"
    DATA_ASSET = "data"

    def __init__(
        self,
        band_name: str = "DEM",
        context: DataSourceContext = DataSourceContext(),
        **kwargs: Any,
    ):
        """Create a CopDemGlo30 data source.

        Args:
            band_name: band name to use when the context carries no layer config.
                When a layer config is present, its single band name takes
                precedence.
            context: the data source context.
            kwargs: extra keyword arguments forwarded to PlanetaryComputer.

        Raises:
            ValueError: if the layer config does not have exactly one band set, or
                that band set does not have exactly one band.
        """
        layer_config = context.layer_config
        if layer_config is not None:
            band_sets = layer_config.band_sets
            if len(band_sets) != 1:
                raise ValueError("expected a single band set")
            configured_bands = band_sets[0].bands
            if len(configured_bands) != 1:
                raise ValueError("expected band set to have a single band")
            band_name = configured_bands[0]

        super().__init__(
            collection_name=self.COLLECTION_NAME,
            asset_bands={self.DATA_ASSET: [band_name]},
            # Every item is expected to carry the same asset(s), so items missing
            # them are simply skipped.
            skip_items_missing_assets=True,
            context=context,
            **kwargs,
        )

    def _stac_item_to_item(self, stac_item: Any) -> SourceItem:
        """Convert a STAC item and clear its time range.

        The DEM is static. Dropping the timestamp lets the item match windows
        covering any time.
        """
        converted = super()._stac_item_to_item(stac_item)
        geom = converted.geometry
        converted.geometry = STGeometry(geom.projection, geom.shp, None)
        return converted

    def _get_search_time_range(self, geometry: STGeometry) -> None:
        """Disable temporal filtering of STAC searches, since the DEM is static."""
        return None
|
rslearn/data_sources/stac.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""A partial data source implementation providing get_items using a STAC API."""
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
+
from datetime import datetime
|
|
4
5
|
from typing import Any
|
|
5
6
|
|
|
6
7
|
import shapely
|
|
@@ -11,6 +12,7 @@ from rslearn.const import WGS84_PROJECTION
|
|
|
11
12
|
from rslearn.data_sources.data_source import Item, ItemLookupDataSource
|
|
12
13
|
from rslearn.data_sources.utils import match_candidate_items_to_window
|
|
13
14
|
from rslearn.log_utils import get_logger
|
|
15
|
+
from rslearn.utils.fsspec import open_atomic
|
|
14
16
|
from rslearn.utils.geometry import STGeometry
|
|
15
17
|
from rslearn.utils.stac import StacClient, StacItem
|
|
16
18
|
|
|
@@ -132,6 +134,24 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):
|
|
|
132
134
|
|
|
133
135
|
return SourceItem(stac_item.id, geom, asset_urls, properties)
|
|
134
136
|
|
|
137
|
+
def _get_search_time_range(
    self, geometry: STGeometry
) -> datetime | tuple[datetime, datetime] | None:
    """Return the temporal constraint to send with the STAC API search.

    The default restricts searches to the time range of ``geometry``. Subclasses
    for static datasets can override this and return None, in which case no
    temporal filter is sent with the search request.

    Args:
        geometry: the geometry being searched for.

    Returns:
        a timestamp or (start, end) tuple, used as the ``date_time`` argument of
        ``StacClient.search``, or None to skip temporal filtering.
    """
    search_range = geometry.time_range
    return search_range
|
|
154
|
+
|
|
135
155
|
def get_item_by_name(self, name: str) -> SourceItem:
|
|
136
156
|
"""Gets an item by name.
|
|
137
157
|
|
|
@@ -168,7 +188,7 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):
|
|
|
168
188
|
|
|
169
189
|
# Finally we cache it if cache_dir is set.
|
|
170
190
|
if cache_fname is not None:
|
|
171
|
-
with cache_fname
|
|
191
|
+
with open_atomic(cache_fname, "w") as f:
|
|
172
192
|
json.dump(item.serialize(), f)
|
|
173
193
|
|
|
174
194
|
return item
|
|
@@ -191,10 +211,11 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):
|
|
|
191
211
|
# for each requested geometry.
|
|
192
212
|
wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
|
|
193
213
|
logger.debug("performing STAC search for geometry %s", wgs84_geometry)
|
|
214
|
+
search_time_range = self._get_search_time_range(wgs84_geometry)
|
|
194
215
|
stac_items = self.client.search(
|
|
195
216
|
collections=[self.collection_name],
|
|
196
217
|
intersects=json.loads(shapely.to_geojson(wgs84_geometry.shp)),
|
|
197
|
-
date_time=
|
|
218
|
+
date_time=search_time_range,
|
|
198
219
|
query=self.query,
|
|
199
220
|
limit=self.limit,
|
|
200
221
|
)
|
|
@@ -239,7 +260,7 @@ class StacDataSource(ItemLookupDataSource[SourceItem]):
|
|
|
239
260
|
cache_fname = self.cache_dir / f"{item.name}.json"
|
|
240
261
|
if cache_fname.exists():
|
|
241
262
|
continue
|
|
242
|
-
with cache_fname
|
|
263
|
+
with open_atomic(cache_fname, "w") as f:
|
|
243
264
|
json.dump(item.serialize(), f)
|
|
244
265
|
|
|
245
266
|
cur_groups = match_candidate_items_to_window(
|
rslearn/main.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import argparse
|
|
4
4
|
import multiprocessing
|
|
5
|
+
import os
|
|
5
6
|
import random
|
|
6
7
|
import sys
|
|
7
8
|
import time
|
|
@@ -45,6 +46,7 @@ handler_registry = {}
|
|
|
45
46
|
ItemType = TypeVar("ItemType", bound="Item")
|
|
46
47
|
|
|
47
48
|
MULTIPROCESSING_CONTEXT = "forkserver"
|
|
49
|
+
MP_CONTEXT_ENV_VAR = "RSLEARN_MULTIPROCESSING_CONTEXT"
|
|
48
50
|
|
|
49
51
|
|
|
50
52
|
def register_handler(category: Any, command: str) -> Callable:
|
|
@@ -837,7 +839,8 @@ def model_predict() -> None:
|
|
|
837
839
|
def main() -> None:
|
|
838
840
|
"""CLI entrypoint."""
|
|
839
841
|
try:
|
|
840
|
-
|
|
842
|
+
mp_context = os.environ.get(MP_CONTEXT_ENV_VAR, MULTIPROCESSING_CONTEXT)
|
|
843
|
+
multiprocessing.set_start_method(mp_context)
|
|
841
844
|
except RuntimeError as e:
|
|
842
845
|
logger.error(
|
|
843
846
|
f"Multiprocessing context already set to {multiprocessing.get_context()}: "
|
|
@@ -180,7 +180,7 @@ class SimpleTimeSeries(FeatureExtractor):
|
|
|
180
180
|
# want to pass 2 timesteps to the model.
|
|
181
181
|
# TODO is probably to make this behaviour clearer but lets leave it like
|
|
182
182
|
# this for now to not break things.
|
|
183
|
-
num_timesteps = images.shape[1]
|
|
183
|
+
num_timesteps = image_channels // images.shape[1]
|
|
184
184
|
batched_timesteps = images.shape[2] // num_timesteps
|
|
185
185
|
images = rearrange(
|
|
186
186
|
images,
|
|
@@ -210,11 +210,30 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
210
210
|
# Fail silently for single-dataset case, which is okay
|
|
211
211
|
pass
|
|
212
212
|
|
|
213
|
+
def on_validation_epoch_end(self) -> None:
    """Log the aggregated validation metrics, then reset them for the next epoch.

    The MetricCollection is computed here instead of being passed straight to
    log_dict. MetricCollection.compute() flattens dict-returning metrics into
    separate entries, while log_dict expects each metric to produce a scalar
    tensor.
    """
    collection = self.val_metrics
    self.log_dict(collection.compute())
    collection.reset()
|
|
223
|
+
|
|
213
224
|
def on_test_epoch_end(self) -> None:
|
|
214
|
-
"""
|
|
225
|
+
"""Compute and log test metrics at epoch end, optionally save to file.
|
|
226
|
+
|
|
227
|
+
We manually compute and log metrics here (instead of passing the MetricCollection
|
|
228
|
+
to log_dict) because MetricCollection.compute() properly flattens dict-returning
|
|
229
|
+
metrics, while log_dict expects each metric to return a scalar tensor.
|
|
230
|
+
"""
|
|
231
|
+
metrics = self.test_metrics.compute()
|
|
232
|
+
self.log_dict(metrics)
|
|
233
|
+
self.test_metrics.reset()
|
|
234
|
+
|
|
215
235
|
if self.metrics_file:
|
|
216
236
|
with open(self.metrics_file, "w") as f:
|
|
217
|
-
metrics = self.test_metrics.compute()
|
|
218
237
|
metrics_dict = {k: v.item() for k, v in metrics.items()}
|
|
219
238
|
json.dump(metrics_dict, f, indent=4)
|
|
220
239
|
logger.info(f"Saved metrics to {self.metrics_file}")
|
|
@@ -300,9 +319,6 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
300
319
|
sync_dist=True,
|
|
301
320
|
)
|
|
302
321
|
self.val_metrics.update(outputs, targets)
|
|
303
|
-
self.log_dict(
|
|
304
|
-
self.val_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
|
|
305
|
-
)
|
|
306
322
|
|
|
307
323
|
def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
|
|
308
324
|
"""Compute the test loss and additional metrics.
|
|
@@ -340,9 +356,6 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
340
356
|
sync_dist=True,
|
|
341
357
|
)
|
|
342
358
|
self.test_metrics.update(outputs, targets)
|
|
343
|
-
self.log_dict(
|
|
344
|
-
self.test_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
|
|
345
|
-
)
|
|
346
359
|
|
|
347
360
|
if self.visualize_dir:
|
|
348
361
|
for inp, target, output, metadata in zip(
|
|
@@ -118,13 +118,16 @@ class MultiTask(Task):
|
|
|
118
118
|
|
|
119
119
|
def get_metrics(self) -> MetricCollection:
    """Get metrics for this task.

    Every sub-task's metrics are gathered into one flat MetricCollection.
    Keys have the form "<task_name>/<metric_name>", and each metric is wrapped
    in a MetricWrapper for its task. A flat collection is used instead of
    nesting one MetricCollection per task; per the original author, nested
    collections break MetricCollection.compute().
    """
    flat_metrics = {
        f"{task_name}/{metric_name}": MetricWrapper(task_name, metric)
        for task_name, task in self.tasks.items()
        for metric_name, metric in task.get_metrics().items()
    }
    return MetricCollection(flat_metrics)
|
|
128
131
|
|
|
129
132
|
|
|
130
133
|
class MetricWrapper(Metric):
|
|
@@ -100,7 +100,7 @@ class PerPixelRegressionTask(BasicTask):
|
|
|
100
100
|
raise ValueError(
|
|
101
101
|
f"PerPixelRegressionTask output must be an HW tensor, but got shape {raw_output.shape}"
|
|
102
102
|
)
|
|
103
|
-
return (raw_output / self.scale_factor).cpu().numpy()
|
|
103
|
+
return (raw_output[None, :, :] / self.scale_factor).cpu().numpy()
|
|
104
104
|
|
|
105
105
|
def visualize(
|
|
106
106
|
self,
|
|
@@ -53,11 +53,14 @@ class SegmentationTask(BasicTask):
|
|
|
53
53
|
enable_accuracy_metric: bool = True,
|
|
54
54
|
enable_miou_metric: bool = False,
|
|
55
55
|
enable_f1_metric: bool = False,
|
|
56
|
+
report_metric_per_class: bool = False,
|
|
56
57
|
f1_metric_thresholds: list[list[float]] = [[0.5]],
|
|
57
58
|
metric_kwargs: dict[str, Any] = {},
|
|
58
59
|
miou_metric_kwargs: dict[str, Any] = {},
|
|
59
60
|
prob_scales: list[float] | None = None,
|
|
60
61
|
other_metrics: dict[str, Metric] = {},
|
|
62
|
+
output_probs: bool = False,
|
|
63
|
+
output_class_idx: int | None = None,
|
|
61
64
|
**kwargs: Any,
|
|
62
65
|
) -> None:
|
|
63
66
|
"""Initialize a new SegmentationTask.
|
|
@@ -74,6 +77,8 @@ class SegmentationTask(BasicTask):
|
|
|
74
77
|
enable_accuracy_metric: whether to enable the accuracy metric (default
|
|
75
78
|
true).
|
|
76
79
|
enable_f1_metric: whether to enable the F1 metric (default false).
|
|
80
|
+
report_metric_per_class: whether to report chosen metrics for each class, in
|
|
81
|
+
addition to the average score across classes.
|
|
77
82
|
enable_miou_metric: whether to enable the mean IoU metric (default false).
|
|
78
83
|
f1_metric_thresholds: list of list of thresholds to apply for F1 metric.
|
|
79
84
|
Each inner list is used to initialize a separate F1 metric where the
|
|
@@ -89,6 +94,10 @@ class SegmentationTask(BasicTask):
|
|
|
89
94
|
this is only applied during prediction, not when computing val or test
|
|
90
95
|
metrics.
|
|
91
96
|
other_metrics: additional metrics to configure on this task.
|
|
97
|
+
output_probs: if True, output raw softmax probabilities instead of class IDs
|
|
98
|
+
during prediction.
|
|
99
|
+
output_class_idx: if set along with output_probs, only output the probability
|
|
100
|
+
for this specific class index (single-channel output).
|
|
92
101
|
kwargs: additional arguments to pass to BasicTask
|
|
93
102
|
"""
|
|
94
103
|
super().__init__(**kwargs)
|
|
@@ -107,11 +116,14 @@ class SegmentationTask(BasicTask):
|
|
|
107
116
|
self.enable_accuracy_metric = enable_accuracy_metric
|
|
108
117
|
self.enable_f1_metric = enable_f1_metric
|
|
109
118
|
self.enable_miou_metric = enable_miou_metric
|
|
119
|
+
self.report_metric_per_class = report_metric_per_class
|
|
110
120
|
self.f1_metric_thresholds = f1_metric_thresholds
|
|
111
121
|
self.metric_kwargs = metric_kwargs
|
|
112
122
|
self.miou_metric_kwargs = miou_metric_kwargs
|
|
113
123
|
self.prob_scales = prob_scales
|
|
114
124
|
self.other_metrics = other_metrics
|
|
125
|
+
self.output_probs = output_probs
|
|
126
|
+
self.output_class_idx = output_class_idx
|
|
115
127
|
|
|
116
128
|
def process_inputs(
|
|
117
129
|
self,
|
|
@@ -167,7 +179,9 @@ class SegmentationTask(BasicTask):
|
|
|
167
179
|
metadata: metadata about the patch being read
|
|
168
180
|
|
|
169
181
|
Returns:
|
|
170
|
-
CHW numpy array
|
|
182
|
+
CHW numpy array. If output_probs is False, returns one channel with
|
|
183
|
+
predicted class IDs. If output_probs is True, returns softmax probabilities
|
|
184
|
+
(num_classes channels, or 1 channel if output_class_idx is set).
|
|
171
185
|
"""
|
|
172
186
|
if not isinstance(raw_output, torch.Tensor) or len(raw_output.shape) != 3:
|
|
173
187
|
raise ValueError("the output for SegmentationTask must be a CHW tensor")
|
|
@@ -179,6 +193,15 @@ class SegmentationTask(BasicTask):
|
|
|
179
193
|
self.prob_scales, device=raw_output.device, dtype=raw_output.dtype
|
|
180
194
|
)[:, None, None]
|
|
181
195
|
)
|
|
196
|
+
|
|
197
|
+
if self.output_probs:
|
|
198
|
+
# Return raw softmax probabilities
|
|
199
|
+
probs = raw_output.cpu().numpy()
|
|
200
|
+
if self.output_class_idx is not None:
|
|
201
|
+
# Return only the specified class probability
|
|
202
|
+
return probs[self.output_class_idx : self.output_class_idx + 1, :, :]
|
|
203
|
+
return probs
|
|
204
|
+
|
|
182
205
|
classes = raw_output.argmax(dim=0).cpu().numpy()
|
|
183
206
|
return classes[None, :, :]
|
|
184
207
|
|
|
@@ -237,29 +260,41 @@ class SegmentationTask(BasicTask):
|
|
|
237
260
|
# Metric name can't contain "." so change to ",".
|
|
238
261
|
suffix = "_" + str(thresholds[0]).replace(".", ",")
|
|
239
262
|
|
|
263
|
+
# Create one metric per type - it returns a dict with "avg" and optionally per-class keys
|
|
240
264
|
metrics["F1" + suffix] = SegmentationMetric(
|
|
241
|
-
F1Metric(
|
|
265
|
+
F1Metric(
|
|
266
|
+
num_classes=self.num_classes,
|
|
267
|
+
score_thresholds=thresholds,
|
|
268
|
+
report_per_class=self.report_metric_per_class,
|
|
269
|
+
),
|
|
242
270
|
)
|
|
243
271
|
metrics["precision" + suffix] = SegmentationMetric(
|
|
244
272
|
F1Metric(
|
|
245
273
|
num_classes=self.num_classes,
|
|
246
274
|
score_thresholds=thresholds,
|
|
247
275
|
metric_mode="precision",
|
|
248
|
-
|
|
276
|
+
report_per_class=self.report_metric_per_class,
|
|
277
|
+
),
|
|
249
278
|
)
|
|
250
279
|
metrics["recall" + suffix] = SegmentationMetric(
|
|
251
280
|
F1Metric(
|
|
252
281
|
num_classes=self.num_classes,
|
|
253
282
|
score_thresholds=thresholds,
|
|
254
283
|
metric_mode="recall",
|
|
255
|
-
|
|
284
|
+
report_per_class=self.report_metric_per_class,
|
|
285
|
+
),
|
|
256
286
|
)
|
|
257
287
|
|
|
258
288
|
if self.enable_miou_metric:
|
|
259
|
-
miou_metric_kwargs: dict[str, Any] = dict(
|
|
289
|
+
miou_metric_kwargs: dict[str, Any] = dict(
|
|
290
|
+
num_classes=self.num_classes,
|
|
291
|
+
report_per_class=self.report_metric_per_class,
|
|
292
|
+
)
|
|
260
293
|
if self.nodata_value is not None:
|
|
261
294
|
miou_metric_kwargs["nodata_value"] = self.nodata_value
|
|
262
295
|
miou_metric_kwargs.update(self.miou_metric_kwargs)
|
|
296
|
+
|
|
297
|
+
# Create one metric - it returns a dict with "avg" and optionally per-class keys
|
|
263
298
|
metrics["mean_iou"] = SegmentationMetric(
|
|
264
299
|
MeanIoUMetric(**miou_metric_kwargs),
|
|
265
300
|
pass_probabilities=False,
|
|
@@ -274,6 +309,20 @@ class SegmentationTask(BasicTask):
|
|
|
274
309
|
class SegmentationHead(Predictor):
|
|
275
310
|
"""Head for segmentation task."""
|
|
276
311
|
|
|
312
|
+
def __init__(self, weights: list[float] | None = None, dice_loss: bool = False):
    """Initialize a new SegmentationHead.

    Args:
        weights: optional per-class weights for the cross entropy loss, given
            as a list of floats with one entry per class. They are converted to
            a tensor and passed as ``weight`` to ``cross_entropy``.
        dice_loss: whether to add a Dice loss term (reported as "dice")
            alongside the cross entropy loss.
    """
    super().__init__()
    if weights is not None:
        # Registered as a buffer rather than a parameter. It moves with the
        # module across devices and is saved in state_dict, but the optimizer
        # never updates it.
        self.register_buffer("weights", torch.Tensor(weights))
    else:
        self.weights = None
    self.dice_loss = dice_loss
|
|
325
|
+
|
|
277
326
|
def forward(
|
|
278
327
|
self,
|
|
279
328
|
intermediates: Any,
|
|
@@ -308,7 +357,7 @@ class SegmentationHead(Predictor):
|
|
|
308
357
|
labels = torch.stack([target["classes"] for target in targets], dim=0)
|
|
309
358
|
mask = torch.stack([target["valid"] for target in targets], dim=0)
|
|
310
359
|
per_pixel_loss = torch.nn.functional.cross_entropy(
|
|
311
|
-
logits, labels, reduction="none"
|
|
360
|
+
logits, labels, weight=self.weights, reduction="none"
|
|
312
361
|
)
|
|
313
362
|
mask_sum = torch.sum(mask)
|
|
314
363
|
if mask_sum > 0:
|
|
@@ -318,6 +367,9 @@ class SegmentationHead(Predictor):
|
|
|
318
367
|
# If there are no valid pixels, we avoid dividing by zero and just let
|
|
319
368
|
# the summed mask loss be zero.
|
|
320
369
|
losses["cls"] = torch.sum(per_pixel_loss * mask)
|
|
370
|
+
if self.dice_loss:
|
|
371
|
+
dice_loss = DiceLoss()(outputs, labels, mask)
|
|
372
|
+
losses["dice"] = dice_loss
|
|
321
373
|
|
|
322
374
|
return ModelOutput(
|
|
323
375
|
outputs=outputs,
|
|
@@ -333,6 +385,7 @@ class SegmentationMetric(Metric):
|
|
|
333
385
|
metric: Metric,
|
|
334
386
|
pass_probabilities: bool = True,
|
|
335
387
|
class_idx: int | None = None,
|
|
388
|
+
output_key: str | None = None,
|
|
336
389
|
):
|
|
337
390
|
"""Initialize a new SegmentationMetric.
|
|
338
391
|
|
|
@@ -341,12 +394,19 @@ class SegmentationMetric(Metric):
|
|
|
341
394
|
classes from the targets and masking out invalid pixels.
|
|
342
395
|
pass_probabilities: whether to pass predicted probabilities to the metric.
|
|
343
396
|
If False, argmax is applied to pass the predicted classes instead.
|
|
344
|
-
class_idx: if
|
|
397
|
+
class_idx: if set, return only this class index's value. For backward
|
|
398
|
+
compatibility with configs using standard torchmetrics. Internally
|
|
399
|
+
converted to output_key="cls_{class_idx}".
|
|
400
|
+
output_key: if the wrapped metric returns a dict (or a tensor that gets
|
|
401
|
+
converted to a dict), return only this key's value. For standard
|
|
402
|
+
torchmetrics with average=None, tensors are converted to dicts with
|
|
403
|
+
keys "cls_0", "cls_1", etc. If None, the full dict is returned.
|
|
345
404
|
"""
|
|
346
405
|
super().__init__()
|
|
347
406
|
self.metric = metric
|
|
348
407
|
self.pass_probablities = pass_probabilities
|
|
349
408
|
self.class_idx = class_idx
|
|
409
|
+
self.output_key = output_key
|
|
350
410
|
|
|
351
411
|
def update(
|
|
352
412
|
self, preds: list[Any] | torch.Tensor, targets: list[dict[str, Any]]
|
|
@@ -376,10 +436,32 @@ class SegmentationMetric(Metric):
|
|
|
376
436
|
self.metric.update(preds, labels)
|
|
377
437
|
|
|
378
438
|
def compute(self) -> Any:
    """Return the wrapped metric's value, optionally selecting a single entry.

    A wrapped metric that returns a tensor with at least one dimension (e.g.
    standard torchmetrics with average=None) is turned into a dict keyed
    "cls_0", "cls_1", ... This lets standard torchmetrics and custom
    dict-returning metrics both be selected by key.

    Selection rules:
      * output_key set: return that dict entry; raise TypeError if the result
        is not a dict.
      * otherwise class_idx set: return entry "cls_<class_idx>" from a dict
        result, or index a non-dict result directly.
      * otherwise: return the (possibly converted) result unchanged.
    """
    value = self.metric.compute()

    # Per-class tensors become {"cls_i": value_i} for uniform key-based access.
    if isinstance(value, torch.Tensor) and value.ndim > 0:
        value = {f"cls_{idx}": entry for idx, entry in enumerate(value)}

    if self.output_key is not None:
        if isinstance(value, dict):
            return value[self.output_key]
        raise TypeError(
            f"output_key is set to '{self.output_key}' but metric returned "
            f"{type(value).__name__} instead of dict"
        )

    if self.class_idx is None:
        return value

    # Backward compatibility: class_idx indexes into the converted dict when
    # available, otherwise into the raw result.
    if isinstance(value, dict):
        return value[f"cls_{self.class_idx}"]
    return value[self.class_idx]
|
|
384
466
|
|
|
385
467
|
def reset(self) -> None:
|
|
@@ -404,6 +486,7 @@ class F1Metric(Metric):
|
|
|
404
486
|
num_classes: int,
|
|
405
487
|
score_thresholds: list[float],
|
|
406
488
|
metric_mode: str = "f1",
|
|
489
|
+
report_per_class: bool = False,
|
|
407
490
|
):
|
|
408
491
|
"""Create a new F1Metric.
|
|
409
492
|
|
|
@@ -413,11 +496,14 @@ class F1Metric(Metric):
|
|
|
413
496
|
metric is the best F1 across score thresholds.
|
|
414
497
|
metric_mode: set to "precision" or "recall" to return that instead of F1
|
|
415
498
|
(default "f1")
|
|
499
|
+
report_per_class: whether to include per-class scores in the output dict.
|
|
500
|
+
If False, only returns the "avg" key.
|
|
416
501
|
"""
|
|
417
502
|
super().__init__()
|
|
418
503
|
self.num_classes = num_classes
|
|
419
504
|
self.score_thresholds = score_thresholds
|
|
420
505
|
self.metric_mode = metric_mode
|
|
506
|
+
self.report_per_class = report_per_class
|
|
421
507
|
|
|
422
508
|
assert self.metric_mode in ["f1", "precision", "recall"]
|
|
423
509
|
|
|
@@ -462,9 +548,10 @@ class F1Metric(Metric):
|
|
|
462
548
|
"""Compute metric.
|
|
463
549
|
|
|
464
550
|
Returns:
|
|
465
|
-
|
|
551
|
+
dict with "avg" key containing mean score across classes.
|
|
552
|
+
If report_per_class is True, also includes "cls_N" keys for each class N.
|
|
466
553
|
"""
|
|
467
|
-
|
|
554
|
+
cls_best_scores = {}
|
|
468
555
|
|
|
469
556
|
for cls_idx in range(self.num_classes):
|
|
470
557
|
best_score = None
|
|
@@ -501,9 +588,12 @@ class F1Metric(Metric):
|
|
|
501
588
|
if best_score is None or score > best_score:
|
|
502
589
|
best_score = score
|
|
503
590
|
|
|
504
|
-
|
|
591
|
+
cls_best_scores[f"cls_{cls_idx}"] = best_score
|
|
505
592
|
|
|
506
|
-
|
|
593
|
+
report_scores = {"avg": torch.mean(torch.stack(list(cls_best_scores.values())))}
|
|
594
|
+
if self.report_per_class:
|
|
595
|
+
report_scores.update(cls_best_scores)
|
|
596
|
+
return report_scores
|
|
507
597
|
|
|
508
598
|
|
|
509
599
|
class MeanIoUMetric(Metric):
|
|
@@ -523,7 +613,7 @@ class MeanIoUMetric(Metric):
|
|
|
523
613
|
num_classes: int,
|
|
524
614
|
nodata_value: int | None = None,
|
|
525
615
|
ignore_missing_classes: bool = False,
|
|
526
|
-
|
|
616
|
+
report_per_class: bool = False,
|
|
527
617
|
):
|
|
528
618
|
"""Create a new MeanIoUMetric.
|
|
529
619
|
|
|
@@ -535,15 +625,14 @@ class MeanIoUMetric(Metric):
|
|
|
535
625
|
ignore_missing_classes: whether to ignore classes that don't appear in
|
|
536
626
|
either the predictions or the ground truth. If false, the IoU for a
|
|
537
627
|
missing class will be 0.
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
only supports scalar return values from metrics.
|
|
628
|
+
report_per_class: whether to include per-class IoU scores in the output dict.
|
|
629
|
+
If False, only returns the "avg" key.
|
|
541
630
|
"""
|
|
542
631
|
super().__init__()
|
|
543
632
|
self.num_classes = num_classes
|
|
544
633
|
self.nodata_value = nodata_value
|
|
545
634
|
self.ignore_missing_classes = ignore_missing_classes
|
|
546
|
-
self.
|
|
635
|
+
self.report_per_class = report_per_class
|
|
547
636
|
|
|
548
637
|
self.add_state(
|
|
549
638
|
"intersections", default=torch.zeros(self.num_classes), dist_reduce_fx="sum"
|
|
@@ -584,9 +673,11 @@ class MeanIoUMetric(Metric):
|
|
|
584
673
|
"""Compute metric.
|
|
585
674
|
|
|
586
675
|
Returns:
|
|
587
|
-
the mean IoU across classes.
|
|
676
|
+
dict with "avg" containing the mean IoU across classes.
|
|
677
|
+
If report_per_class is True, also includes "cls_N" keys for each valid class N.
|
|
588
678
|
"""
|
|
589
|
-
|
|
679
|
+
cls_scores = {}
|
|
680
|
+
valid_scores = []
|
|
590
681
|
|
|
591
682
|
for cls_idx in range(self.num_classes):
|
|
592
683
|
# Check if nodata_value is set and is one of the classes
|
|
@@ -599,6 +690,56 @@ class MeanIoUMetric(Metric):
|
|
|
599
690
|
if union == 0 and self.ignore_missing_classes:
|
|
600
691
|
continue
|
|
601
692
|
|
|
602
|
-
|
|
693
|
+
score = intersection / union
|
|
694
|
+
cls_scores[f"cls_{cls_idx}"] = score
|
|
695
|
+
valid_scores.append(score)
|
|
696
|
+
|
|
697
|
+
report_scores = {"avg": torch.mean(torch.stack(valid_scores))}
|
|
698
|
+
if self.report_per_class:
|
|
699
|
+
report_scores.update(cls_scores)
|
|
700
|
+
return report_scores
|
|
701
|
+
|
|
702
|
+
|
|
703
|
+
class DiceLoss(torch.nn.Module):
    """Mean soft Dice loss for segmentation.

    The loss is ``1 - mean_c(dice_c)`` where, for each class ``c``::

        dice_c = (2 * intersection_c + smooth) / (cardinality_c + smooth)

    ``intersection_c`` is the sum, over all examples and pixels, of the class-``c``
    prediction times the class-``c`` one-hot target. ``cardinality_c`` is the sum of
    the class-``c`` predictions plus the number of class-``c`` target pixels. Only
    pixels selected by ``mask`` contribute to either quantity.
    """

    def __init__(self, smooth: float = 1e-7):
        """Initialize a new DiceLoss.

        Args:
            smooth: constant added to both numerator and denominator of each
                per-class score. For a class absent from both the masked
                predictions and targets, the score becomes smooth / smooth = 1.
        """
        super().__init__()
        self.smooth = smooth

    def forward(
        self, inputs: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Compute Dice Loss.

        Args:
            inputs: per-class scores of shape [B, C, H, W]. No activation is applied
                here, so these are presumably probabilities (e.g. a softmax output).
            targets: integer class labels of shape [B, H, W].
            mask: mask of shape [B, H, W]; pixels where it is zero are ignored. The
                target labels at ignored pixels are never used, so they may lie
                outside [0, C) (e.g. an ignore/nodata label).

        Returns:
            scalar tensor with the mean Dice loss across classes.
        """
        num_classes = inputs.shape[1]

        # one_hot rejects labels outside [0, num_classes). Ignored pixels are zeroed
        # by the mask below, so their label is irrelevant; map it to a valid class.
        safe_targets = targets.masked_fill(mask == 0, 0)
        targets_one_hot = (
            torch.nn.functional.one_hot(safe_targets, num_classes)
            .permute(0, 3, 1, 2)
            .float()
        )

        # [B, H, W] -> [B, 1, H, W] so the mask broadcasts across the class dim.
        valid = mask.unsqueeze(1)
        pred = inputs * valid
        target = targets_one_hot * valid

        # Reduce over batch and spatial dims at once, giving per-class totals [C].
        dims = (0, 2, 3)
        intersection = (pred * target).sum(dim=dims)
        cardinality = pred.sum(dim=dims) + target.sum(dim=dims)
        dice_per_class = (2.0 * intersection + self.smooth) / (
            cardinality + self.smooth
        )

        return 1 - dice_per_class.mean()
|
rslearn/utils/raster_format.py
CHANGED
|
@@ -476,6 +476,7 @@ class GeotiffRasterFormat(RasterFormat):
|
|
|
476
476
|
bounds: PixelBounds,
|
|
477
477
|
array: npt.NDArray[Any],
|
|
478
478
|
fname: str | None = None,
|
|
479
|
+
nodata_val: int | float | None = None,
|
|
479
480
|
) -> None:
|
|
480
481
|
"""Encodes raster data.
|
|
481
482
|
|
|
@@ -485,6 +486,7 @@ class GeotiffRasterFormat(RasterFormat):
|
|
|
485
486
|
bounds: the bounds of the raster data in the projection
|
|
486
487
|
array: the raster data
|
|
487
488
|
fname: override the filename to save as
|
|
489
|
+
nodata_val: set the nodata value when writing the raster.
|
|
488
490
|
"""
|
|
489
491
|
if fname is None:
|
|
490
492
|
fname = self.fname
|
|
@@ -520,6 +522,9 @@ class GeotiffRasterFormat(RasterFormat):
|
|
|
520
522
|
profile["tiled"] = True
|
|
521
523
|
profile["blockxsize"] = self.block_size
|
|
522
524
|
profile["blockysize"] = self.block_size
|
|
525
|
+
# Set nodata_val if provided.
|
|
526
|
+
if nodata_val is not None:
|
|
527
|
+
profile["nodata"] = nodata_val
|
|
523
528
|
|
|
524
529
|
profile.update(self.geotiff_options)
|
|
525
530
|
|
|
@@ -535,6 +540,7 @@ class GeotiffRasterFormat(RasterFormat):
|
|
|
535
540
|
bounds: PixelBounds,
|
|
536
541
|
resampling: Resampling = Resampling.bilinear,
|
|
537
542
|
fname: str | None = None,
|
|
543
|
+
nodata_val: int | float | None = None,
|
|
538
544
|
) -> npt.NDArray[Any]:
|
|
539
545
|
"""Decodes raster data.
|
|
540
546
|
|
|
@@ -544,6 +550,16 @@ class GeotiffRasterFormat(RasterFormat):
|
|
|
544
550
|
bounds: the bounds to read in the given projection.
|
|
545
551
|
resampling: resampling method to use in case resampling is needed.
|
|
546
552
|
fname: override the filename to read from
|
|
553
|
+
nodata_val: override the nodata value in the raster when reading. Pixels in
|
|
554
|
+
bounds that are not present in the source raster will be initialized to
|
|
555
|
+
this value. Note that, if the raster specifies a nodata value, and
|
|
556
|
+
some source pixels have that value, they will still be read under their
|
|
557
|
+
original value; overriding the nodata value is primarily useful if the
|
|
558
|
+
user wants out of bounds pixels to have a different value from the
|
|
559
|
+
source pixels, e.g. if the source data has background and foreground
|
|
560
|
+
classes (with background being nodata) but we want to read it in a
|
|
561
|
+
different projection and have out of bounds pixels be a third "invalid"
|
|
562
|
+
value.
|
|
547
563
|
|
|
548
564
|
Returns:
|
|
549
565
|
the raster data
|
|
@@ -561,6 +577,7 @@ class GeotiffRasterFormat(RasterFormat):
|
|
|
561
577
|
width=bounds[2] - bounds[0],
|
|
562
578
|
height=bounds[3] - bounds[1],
|
|
563
579
|
resampling=resampling,
|
|
580
|
+
src_nodata=nodata_val,
|
|
564
581
|
) as vrt:
|
|
565
582
|
return vrt.read()
|
|
566
583
|
|
rslearn/utils/stac.py
CHANGED
|
@@ -101,6 +101,7 @@ class StacClient:
|
|
|
101
101
|
ids: list[str] | None = None,
|
|
102
102
|
limit: int | None = None,
|
|
103
103
|
query: dict[str, Any] | None = None,
|
|
104
|
+
sortby: list[dict[str, str]] | None = None,
|
|
104
105
|
) -> list[StacItem]:
|
|
105
106
|
"""Execute a STAC item search.
|
|
106
107
|
|
|
@@ -117,6 +118,7 @@ class StacClient:
|
|
|
117
118
|
limit: number of items per page. We will read all the pages.
|
|
118
119
|
query: query dict, if STAC query extension is supported by this API. See
|
|
119
120
|
https://github.com/stac-api-extensions/query.
|
|
121
|
+
sortby: list of sort specifications, e.g. [{"field": "id", "direction": "asc"}].
|
|
120
122
|
|
|
121
123
|
Returns:
|
|
122
124
|
list of matching STAC items.
|
|
@@ -142,6 +144,8 @@ class StacClient:
|
|
|
142
144
|
request_data["limit"] = limit
|
|
143
145
|
if query is not None:
|
|
144
146
|
request_data["query"] = query
|
|
147
|
+
if sortby is not None:
|
|
148
|
+
request_data["sortby"] = sortby
|
|
145
149
|
|
|
146
150
|
# Handle pagination.
|
|
147
151
|
cur_url = self.endpoint + "/search"
|
|
@@ -3,7 +3,7 @@ rslearn/arg_parser.py,sha256=Go1MyEflcau_cziirmNd7Yhxa0WtXTAljIVE4f5H1GE,1194
|
|
|
3
3
|
rslearn/const.py,sha256=FUCfsvFAs-QarEDJ0grdy0C1HjUjLpNFYGo5I2Vpc5Y,449
|
|
4
4
|
rslearn/lightning_cli.py,sha256=1eTeffUlFqBe2KnyuYyXJdNKYQClCA-PV1xr0vJyJao,17972
|
|
5
5
|
rslearn/log_utils.py,sha256=unD9gShiuO7cx5Nnq8qqVQ4qrbOOwFVgcHxN5bXuiAo,941
|
|
6
|
-
rslearn/main.py,sha256=
|
|
6
|
+
rslearn/main.py,sha256=rrDEoa0xCkDflH-HN2SaHt0hb-rLfXWP-kJKISZAe9s,28714
|
|
7
7
|
rslearn/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
8
|
rslearn/template_params.py,sha256=Vop0Ha-S44ctCa9lvSZRjrMETznJZlR5y_gJrVIwrPg,791
|
|
9
9
|
rslearn/config/__init__.py,sha256=n1qpZ0ImshTtLYl5mC73BORYyUcjPJyHiyZkqUY1hiY,474
|
|
@@ -25,9 +25,9 @@ rslearn/data_sources/local_files.py,sha256=mo5W_BxBl89EPTIHNDEpXM6qBjrP225KK0Pcm
|
|
|
25
25
|
rslearn/data_sources/openstreetmap.py,sha256=TzZfouc2Z4_xjx2v_uv7aPn4tVW3flRVQN4qBfl507E,18161
|
|
26
26
|
rslearn/data_sources/planet.py,sha256=6FWQ0bl1k3jwvwp4EVGi2qs3OD1QhnKOKP36mN4HELI,9446
|
|
27
27
|
rslearn/data_sources/planet_basemap.py,sha256=e9R6FlagJjg8Z6Rc1dC6zK3xMkCohz8eohXqXmd29xg,9670
|
|
28
|
-
rslearn/data_sources/planetary_computer.py,sha256=
|
|
28
|
+
rslearn/data_sources/planetary_computer.py,sha256=k6CO5Yim2I-frlD8r2_uBo0CQFw89mN3_5mrv0Xk2WU,26449
|
|
29
29
|
rslearn/data_sources/soilgrids.py,sha256=rwO4goFPQ7lx420FvYBHYFXdihnZqn_-IjdqtxQ9j2g,12455
|
|
30
|
-
rslearn/data_sources/stac.py,sha256=
|
|
30
|
+
rslearn/data_sources/stac.py,sha256=Xty1JDueAAonNVLRo8vfNBhlHrVLjhmZ-uRBYbrGvtA,10753
|
|
31
31
|
rslearn/data_sources/usda_cdl.py,sha256=_WvxZkm0fbXfniRs6NT8iVCbTTmVPflDhsFT2ci6_Dk,6879
|
|
32
32
|
rslearn/data_sources/usgs_landsat.py,sha256=kPOb3hsZe5-guUcFZZkwzcRpYZ3Zo7Bk4E829q_xiyU,18516
|
|
33
33
|
rslearn/data_sources/utils.py,sha256=v_90ALOuts7RHNcx-j8o-aQ_aFjh8ZhXrmsaa9uEGDA,11651
|
|
@@ -69,7 +69,7 @@ rslearn/models/prithvi.py,sha256=J45eC1pd4l5AGlr19Qjrjrw5PPwvYE9bNM5qCFoznmg,403
|
|
|
69
69
|
rslearn/models/resize_features.py,sha256=U7ZIVwwToJJnwchFG59wLWWP9eikHDB_1c4OtpubxHU,1693
|
|
70
70
|
rslearn/models/sam2_enc.py,sha256=WZOtlp0FjaVztW4gEVIcsFQdKArS9iblRODP0b6Oc8M,3641
|
|
71
71
|
rslearn/models/satlaspretrain.py,sha256=2R48ulbtd44Qy2FYJCkllE2Wk35eZxkc79ruSgkmgcQ,3384
|
|
72
|
-
rslearn/models/simple_time_series.py,sha256=
|
|
72
|
+
rslearn/models/simple_time_series.py,sha256=farQwt_nJVyAbgaM2UzdyqpDuIO0SLmHr9e9EVPSWCE,14678
|
|
73
73
|
rslearn/models/singletask.py,sha256=9DM9a9-Mv3vVQqRhPOIXG2HHuVqVa_zuvgafeeYh4r0,1903
|
|
74
74
|
rslearn/models/ssl4eo_s12.py,sha256=DOlpIj6NfjIlWyJ27m9Xo8TMlovBDstFq0ARnmAJ6qY,3919
|
|
75
75
|
rslearn/models/swin.py,sha256=Xqr3SswbHP6IhwT2atZMAPF2TUzQqfxvihksb8WSeRo,6065
|
|
@@ -116,7 +116,7 @@ rslearn/train/__init__.py,sha256=fnJyY4aHs5zQqbDKSfXsJZXY_M9fbTsf7dRYaPwZr2M,30
|
|
|
116
116
|
rslearn/train/all_patches_dataset.py,sha256=EVoYCmS3g4OfWPt5CZzwHVx9isbnWh5HIGA0RBqPFeA,21145
|
|
117
117
|
rslearn/train/data_module.py,sha256=pgut8rEWHIieZ7RR8dUvhtlNqk0egEdznYF3tCvqdHg,23552
|
|
118
118
|
rslearn/train/dataset.py,sha256=Jy1jU3GigfHaFeX9rbveX9bqy2Pd5Wh_vquD6_aFnS8,36522
|
|
119
|
-
rslearn/train/lightning_module.py,sha256=
|
|
119
|
+
rslearn/train/lightning_module.py,sha256=V4YoEg9PrwrgG4q9Dmv_9OBrSIK-SRPzjWtZRIfmPFg,15366
|
|
120
120
|
rslearn/train/model_context.py,sha256=6o66BY6okBK-D5e0JUwPd7fxD_XehVaqxdQkJJKmQ3E,2580
|
|
121
121
|
rslearn/train/optimizer.py,sha256=EKSqkmERalDA0bF32Gey7n6z69KLyaUWKlRsGJfKBmE,927
|
|
122
122
|
rslearn/train/prediction_writer.py,sha256=rW0BUaYT_F1QqmpnQlbrLiLya1iBfC5Pb78G_NlF-vA,15956
|
|
@@ -130,10 +130,10 @@ rslearn/train/tasks/__init__.py,sha256=dag1u72x1-me6y0YcOubUo5MYZ0Tjf6-dOir9UeFN
|
|
|
130
130
|
rslearn/train/tasks/classification.py,sha256=72ZBcbunMsdPYQN53S-4GfiLIDrr1X3Hni07dBJ0pu0,14261
|
|
131
131
|
rslearn/train/tasks/detection.py,sha256=B0tfB7UGIbRtjnye3PhzLmfeQ4X7ImO3A-_LeNhBA54,21988
|
|
132
132
|
rslearn/train/tasks/embedding.py,sha256=NdJEAaDWlWYzvOBVf7eIHfFOzqTgavfFH1J1gMbAMVo,3891
|
|
133
|
-
rslearn/train/tasks/multi_task.py,sha256=
|
|
134
|
-
rslearn/train/tasks/per_pixel_regression.py,sha256=
|
|
133
|
+
rslearn/train/tasks/multi_task.py,sha256=32hvwyVsHqt7N_M3zXsTErK1K7-0-BPHzt7iGNehyaI,6314
|
|
134
|
+
rslearn/train/tasks/per_pixel_regression.py,sha256=Clrod6LQGjgNC0IAR4HLY7eCGWMHj2mk4d4moZCl4Qc,10209
|
|
135
135
|
rslearn/train/tasks/regression.py,sha256=bVS_ApZSpbL0NaaM8Mu5Bsu4SBUyLpVtrPslulvvZHs,12695
|
|
136
|
-
rslearn/train/tasks/segmentation.py,sha256=
|
|
136
|
+
rslearn/train/tasks/segmentation.py,sha256=LZeuveHhMQsjNOQfMcwqSI4Ux3k9zfa58A2eZHSif8Y,29391
|
|
137
137
|
rslearn/train/tasks/task.py,sha256=nMPunl9OlnOimr48saeTnwKMQ7Du4syGrwNKVQq4FL4,4110
|
|
138
138
|
rslearn/train/transforms/__init__.py,sha256=BkCAzm4f-8TEhPIuyvCj7eJGh36aMkZFYlq-H_jkSvY,778
|
|
139
139
|
rslearn/train/transforms/concatenate.py,sha256=hVVBaxIdk1Cx8JHPirj54TGpbWAJx5y_xD7k1rmGmT0,3166
|
|
@@ -155,17 +155,17 @@ rslearn/utils/get_utm_ups_crs.py,sha256=kUrcyjCK7KWvuP1XR-nURPeRqYeRO-3L8QUJ1QTF
|
|
|
155
155
|
rslearn/utils/grid_index.py,sha256=hRmrtgpqN1pLa-djnZtgSXqKJlbgGyttGnCEmPLD0zo,2347
|
|
156
156
|
rslearn/utils/jsonargparse.py,sha256=TRyZA151KzhjJlZczIHYguEA-YxCDYaZ2IwCRgx76nM,4791
|
|
157
157
|
rslearn/utils/mp.py,sha256=XYmVckI5TOQuCKc49NJyirDJyFgvb4AI-gGypG2j680,1399
|
|
158
|
-
rslearn/utils/raster_format.py,sha256=
|
|
158
|
+
rslearn/utils/raster_format.py,sha256=fwotJBadwqYSdK8UokiKOk1pOF8JMim3kP_VwLWivPI,27382
|
|
159
159
|
rslearn/utils/rtree_index.py,sha256=j0Zwrq3pXuAJ-hKpiRFQ7VNtvO3fZYk-Em2uBPAqfx4,6460
|
|
160
160
|
rslearn/utils/spatial_index.py,sha256=eomJAUgzmjir8j9HZnSgQoJHwN9H0wGTjmJkMkLLfsU,762
|
|
161
161
|
rslearn/utils/sqlite_index.py,sha256=YGOJi66544e6JNtfSft6YIlHklFdSJO2duxQ4TJ2iu4,2920
|
|
162
|
-
rslearn/utils/stac.py,sha256=
|
|
162
|
+
rslearn/utils/stac.py,sha256=c8NwOCKWvUwA-FSKlxZn-t7RZYweuye53OufT0bAK4A,5996
|
|
163
163
|
rslearn/utils/time.py,sha256=2ilSLG94_sxLP3y5RSV5L5CG8CoND_dbdzYEHVtN-I8,387
|
|
164
164
|
rslearn/utils/vector_format.py,sha256=4ZDYpfBLLxguJkiIaavTagiQK2Sv4Rz9NumbHlq-3Lw,15041
|
|
165
|
-
rslearn-0.0.
|
|
166
|
-
rslearn-0.0.
|
|
167
|
-
rslearn-0.0.
|
|
168
|
-
rslearn-0.0.
|
|
169
|
-
rslearn-0.0.
|
|
170
|
-
rslearn-0.0.
|
|
171
|
-
rslearn-0.0.
|
|
165
|
+
rslearn-0.0.24.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
|
|
166
|
+
rslearn-0.0.24.dist-info/licenses/NOTICE,sha256=wLPr6rwV_jCg-xEknNGwhnkfRfuoOE9MZ-lru2yZyLI,5070
|
|
167
|
+
rslearn-0.0.24.dist-info/METADATA,sha256=gV5mgeYPYiKWrEu7D8acOubWvg76Nn_4ICvlD7iTpcs,37936
|
|
168
|
+
rslearn-0.0.24.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
|
|
169
|
+
rslearn-0.0.24.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
|
|
170
|
+
rslearn-0.0.24.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
|
|
171
|
+
rslearn-0.0.24.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|