rslearn 0.0.26__py3-none-any.whl → 0.0.28__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as published to a public registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in that registry.
- rslearn/data_sources/__init__.py +2 -0
- rslearn/data_sources/aws_landsat.py +44 -161
- rslearn/data_sources/aws_open_data.py +2 -4
- rslearn/data_sources/aws_sentinel1.py +1 -3
- rslearn/data_sources/aws_sentinel2_element84.py +54 -165
- rslearn/data_sources/climate_data_store.py +1 -3
- rslearn/data_sources/copernicus.py +1 -2
- rslearn/data_sources/data_source.py +1 -1
- rslearn/data_sources/direct_materialize_data_source.py +336 -0
- rslearn/data_sources/earthdaily.py +52 -155
- rslearn/data_sources/earthdatahub.py +425 -0
- rslearn/data_sources/eurocrops.py +1 -2
- rslearn/data_sources/gcp_public_data.py +1 -2
- rslearn/data_sources/google_earth_engine.py +1 -2
- rslearn/data_sources/hf_srtm.py +595 -0
- rslearn/data_sources/local_files.py +1 -1
- rslearn/data_sources/openstreetmap.py +1 -1
- rslearn/data_sources/planet.py +1 -2
- rslearn/data_sources/planet_basemap.py +1 -2
- rslearn/data_sources/planetary_computer.py +183 -186
- rslearn/data_sources/soilgrids.py +3 -3
- rslearn/data_sources/stac.py +1 -2
- rslearn/data_sources/usda_cdl.py +1 -3
- rslearn/data_sources/usgs_landsat.py +7 -254
- rslearn/data_sources/worldcereal.py +1 -1
- rslearn/data_sources/worldcover.py +1 -1
- rslearn/data_sources/worldpop.py +1 -1
- rslearn/data_sources/xyz_tiles.py +5 -9
- rslearn/dataset/storage/file.py +16 -12
- rslearn/models/concatenate_features.py +6 -1
- rslearn/tile_stores/default.py +4 -2
- rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
- rslearn/train/data_module.py +36 -33
- rslearn/train/dataset.py +159 -68
- rslearn/train/lightning_module.py +60 -4
- rslearn/train/metrics.py +162 -0
- rslearn/train/model_context.py +3 -3
- rslearn/train/prediction_writer.py +69 -41
- rslearn/train/tasks/classification.py +14 -1
- rslearn/train/tasks/detection.py +5 -5
- rslearn/train/tasks/per_pixel_regression.py +19 -6
- rslearn/train/tasks/regression.py +19 -3
- rslearn/train/tasks/segmentation.py +17 -0
- rslearn/utils/__init__.py +2 -0
- rslearn/utils/fsspec.py +51 -1
- rslearn/utils/geometry.py +21 -0
- rslearn/utils/m2m_api.py +251 -0
- rslearn/utils/retry_session.py +43 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/METADATA +6 -3
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/RECORD +55 -50
- rslearn/data_sources/earthdata_srtm.py +0 -282
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/WHEEL +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/top_level.txt +0 -0
rslearn/train/prediction_writer.py
CHANGED

@@ -1,6 +1,7 @@
 """rslearn PredictionWriter implementation."""
 
 import json
+import warnings
 from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 from pathlib import Path
@@ -39,20 +40,20 @@ logger = get_logger(__name__)
 
 
 @dataclass
-class …
-    """A …
+class PendingCropOutput:
+    """A crop output that hasn't been merged yet."""
 
     bounds: PixelBounds
     output: Any
 
 
-class PatchPredictionMerger:
-    """Base class for merging predictions from multiple …
+class CropPredictionMerger:
+    """Base class for merging predictions from multiple crops."""
 
     def merge(
         self,
         window: Window,
-        outputs: Sequence[…
+        outputs: Sequence[PendingCropOutput],
         layer_config: LayerConfig,
     ) -> Any:
         """Merge the outputs.
@@ -68,39 +69,60 @@ class PatchPredictionMerger:
         raise NotImplementedError
 
 
-class VectorMerger(PatchPredictionMerger):
+class VectorMerger(CropPredictionMerger):
     """Merger for vector data that simply concatenates the features."""
 
     def merge(
         self,
         window: Window,
-        outputs: Sequence[…
+        outputs: Sequence[PendingCropOutput],
         layer_config: LayerConfig,
     ) -> list[Feature]:
         """Concatenate the vector features."""
         return [feat for output in outputs for feat in output.output]
 
 
-class RasterMerger(PatchPredictionMerger):
+class RasterMerger(CropPredictionMerger):
     """Merger for raster data that copies the rasters to the output."""
 
-    def __init__(…
+    def __init__(
+        self,
+        overlap_pixels: int | None = None,
+        downsample_factor: int = 1,
+        # Deprecated parameter (for backwards compatibility)
+        padding: int | None = None,
+    ):
         """Create a new RasterMerger.
 
         Args:
-            …
-            …
-            …
+            overlap_pixels: the number of pixels shared between adjacent crops during
+                sliding window inference. Half of this overlap is removed from each
+                crop during merging (except at window boundaries where the full crop
+                is retained).
             downsample_factor: the factor by which the rasters output by the task are
                 lower in resolution relative to the window resolution.
+            padding: deprecated, use overlap_pixels instead. The old padding value
+                equals overlap_pixels // 2.
         """
-        …
+        # Handle deprecated padding parameter
+        if padding is not None:
+            warnings.warn(
+                "padding is deprecated, use overlap_pixels instead. "
+                "Note: overlap_pixels = padding * 2",
+                FutureWarning,
+                stacklevel=2,
+            )
+            if overlap_pixels is not None:
+                raise ValueError("Cannot specify both padding and overlap_pixels")
+            overlap_pixels = padding * 2
+
+        self.overlap_pixels = overlap_pixels
         self.downsample_factor = downsample_factor
 
     def merge(
         self,
         window: Window,
-        outputs: Sequence[…
+        outputs: Sequence[PendingCropOutput],
         layer_config: LayerConfig,
     ) -> npt.NDArray:
         """Merge the raster outputs."""
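
The deprecation shim above maps the old padding argument onto the new overlap_pixels (overlap_pixels = padding * 2). A minimal migration sketch, assuming RasterMerger keeps its import path under rslearn.train.prediction_writer:

import warnings

from rslearn.train.prediction_writer import RasterMerger

# Old style: `padding` pixels were trimmed from each interior crop edge; passing it
# now emits a FutureWarning and is translated to overlap_pixels = padding * 2.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", FutureWarning)
    legacy = RasterMerger(padding=16)

# New style: declare the overlap between adjacent crops directly.
current = RasterMerger(overlap_pixels=32)

# Both configurations trim 32 // 2 = 16 pixels from each non-boundary crop edge.
assert legacy.overlap_pixels == current.overlap_pixels == 32
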
@@ -114,6 +136,12 @@ class RasterMerger(PatchPredictionMerger):
             dtype=layer_config.band_sets[0].dtype.get_numpy_dtype(),
         )
 
+        # Compute how many pixels to trim from each side.
+        # We remove half of the overlap from each side (not at window boundaries).
+        trim_pixels = (
+            self.overlap_pixels // 2 if self.overlap_pixels is not None else None
+        )
+
         # Ensure the outputs are sorted by height then width.
         # This way when we merge we can be sure that outputs that are lower or further
         # to the right will overwrite earlier outputs.
@@ -123,18 +151,18 @@ class RasterMerger(PatchPredictionMerger):
         for output in sorted_outputs:
             # So now we just need to compute the src_offset to copy.
             # If the output is not on the left or top boundary, then we should apply
-            # the …
+            # the trim (if set).
             src = output.output
             src_offset = (
                 output.bounds[0] // self.downsample_factor,
                 output.bounds[1] // self.downsample_factor,
             )
-            if …
-                src = src[:, :, …
-                src_offset = (src_offset[0] + …
-            if …
-                src = src[:, …
-                src_offset = (src_offset[0], src_offset[1] + …
+            if trim_pixels is not None and output.bounds[0] != window.bounds[0]:
+                src = src[:, :, trim_pixels:]
+                src_offset = (src_offset[0] + trim_pixels, src_offset[1])
+            if trim_pixels is not None and output.bounds[1] != window.bounds[1]:
+                src = src[:, trim_pixels:, :]
+                src_offset = (src_offset[0], src_offset[1] + trim_pixels)
 
             copy_spatial_array(
                 src=src,
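
To make the trimming above concrete, here is a small self-contained sketch with hypothetical numbers (overlap_pixels=32, so trim_pixels=16, and two horizontally adjacent 128-pixel-wide crops): the right crop loses its first 16 columns and its destination offset shifts by 16, so the seam lands in the middle of the overlap region rather than at a crop edge.

import numpy as np

overlap_pixels = 32
trim = overlap_pixels // 2                              # 16
merged = np.zeros((1, 1, 224), dtype=np.int32)          # window is 224 pixels wide

left_crop = np.full((1, 1, 128), 1, dtype=np.int32)     # crop bounds x in [0, 128)
right_crop = np.full((1, 1, 128), 2, dtype=np.int32)    # crop bounds x in [96, 224)

# The left crop touches the window boundary, so it is copied untrimmed.
merged[:, :, 0:128] = left_crop
# The right crop is not on the left boundary: drop its first `trim` columns and
# shift its destination offset accordingly (96 + 16 = 112).
merged[:, :, 96 + trim : 224] = right_crop[:, :, trim:]

# The seam sits at x=112, the midpoint of the 32-pixel overlap region.
assert (merged[0, 0, :112] == 1).all() and (merged[0, 0, 112:] == 2).all()
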
@@ -162,7 +190,7 @@ class RslearnWriter(BasePredictionWriter):
         output_layer: str,
         path_options: dict[str, Any] | None = None,
         selector: list[str] | None = None,
-        merger: …
+        merger: CropPredictionMerger | None = None,
         output_path: str | Path | None = None,
         layer_config: LayerConfig | None = None,
         storage_config: StorageConfig | None = None,
@@ -175,7 +203,7 @@ class RslearnWriter(BasePredictionWriter):
             path_options: additional options for path to pass to fsspec
             selector: keys to access the desired output in the output dict if needed.
                 e.g ["key1", "key2"] gets output["key1"]["key2"]
-            merger: merger to use to merge outputs from overlapped …
+            merger: merger to use to merge outputs from overlapped crops.
             output_path: optional custom path for writing predictions. If provided,
                 predictions will be written to this path instead of deriving from dataset path.
             layer_config: optional layer configuration. If provided, this config will be
@@ -217,9 +245,9 @@ class RslearnWriter(BasePredictionWriter):
             self.merger = VectorMerger()
 
         # Map from window name to pending data to write.
-        # This is used when windows are split up into …
-        # …
-        self.pending_outputs: dict[str, list[…
+        # This is used when windows are split up into crops, so the data from all the
+        # crops of each window need to be reconstituted.
+        self.pending_outputs: dict[str, list[PendingCropOutput]] = {}
 
     def _get_layer_config_and_dataset_storage(
         self,
@@ -327,7 +355,7 @@ class RslearnWriter(BasePredictionWriter):
                 will be processed by the task to obtain a vector (list[Feature]) or
                 raster (npt.NDArray) output.
             metadatas: corresponding list of metadatas from the batch describing the
-                …
+                crops that were processed.
         """
         # Process the predictions into outputs that can be written.
         outputs: list = [
@@ -349,17 +377,17 @@ class RslearnWriter(BasePredictionWriter):
             )
             self.process_output(
                 window,
-                metadata.…
-                metadata.…
-                metadata.…
+                metadata.crop_idx,
+                metadata.num_crops_in_window,
+                metadata.crop_bounds,
                 output,
             )
 
     def process_output(
         self,
         window: Window,
-        …
-        …
+        crop_idx: int,
+        num_crops: int,
         cur_bounds: PixelBounds,
         output: npt.NDArray | list[Feature],
     ) -> None:
@@ -367,28 +395,28 @@ class RslearnWriter(BasePredictionWriter):
 
         Args:
             window: the window that the output pertains to.
-            …
-            …
-            cur_bounds: the bounds of the current …
+            crop_idx: the index of this crop for the window.
+            num_crops: the total number of crops to be processed for the window.
+            cur_bounds: the bounds of the current crop.
             output: the output data.
         """
-        # Incorporate the output into our list of pending …
+        # Incorporate the output into our list of pending crop outputs.
         if window.name not in self.pending_outputs:
             self.pending_outputs[window.name] = []
-        self.pending_outputs[window.name].append(…
+        self.pending_outputs[window.name].append(PendingCropOutput(cur_bounds, output))
         logger.debug(
-            f"Stored …
+            f"Stored PendingCropOutput for crop #{crop_idx}/{num_crops} at window {window.name}"
         )
 
-        if …
+        if crop_idx < num_crops - 1:
             return
 
-        # This is the last …
+        # This is the last crop so it's time to write it.
         # First get the pending output and clear it.
         pending_output = self.pending_outputs[window.name]
         del self.pending_outputs[window.name]
 
-        # Merge outputs from overlapped …
+        # Merge outputs from overlapped crops if merger is set.
         logger.debug(f"Merging and writing for window {window.name}")
         merged_output = self.merger.merge(window, pending_output, self.layer_config)
 
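
The merge API above (the window, the accumulated PendingCropOutput records, and the layer config) is also the extension point for custom mergers. The following is only a sketch of a hypothetical alternative, not part of rslearn: it averages raster crops where they overlap instead of letting later crops overwrite earlier ones, and it assumes PixelBounds is an (x0, y0, x1, y1) tuple, CHW arrays, crops that lie fully inside the window, and a downsample_factor of 1.

import numpy as np

from rslearn.train.prediction_writer import CropPredictionMerger


class AveragingRasterMerger(CropPredictionMerger):
    """Hypothetical merger that averages crops where they overlap."""

    def merge(self, window, outputs, layer_config):
        x0, y0, x1, y1 = window.bounds
        channels = outputs[0].output.shape[0]
        acc = np.zeros((channels, y1 - y0, x1 - x0), dtype=np.float64)
        counts = np.zeros((1, y1 - y0, x1 - x0), dtype=np.float64)
        for pending in outputs:
            # Offset of this crop within the window, assuming bounds = (x0, y0, x1, y1).
            cx, cy = pending.bounds[0] - x0, pending.bounds[1] - y0
            crop = pending.output  # assumed to be a (C, H, W) array
            acc[:, cy : cy + crop.shape[1], cx : cx + crop.shape[2]] += crop
            counts[:, cy : cy + crop.shape[1], cx : cx + crop.shape[2]] += 1
        return acc / np.maximum(counts, 1)
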
rslearn/train/tasks/classification.py
CHANGED

@@ -16,6 +16,7 @@ from torchmetrics.classification import (
 )
 
 from rslearn.models.component import FeatureVector, Predictor
+from rslearn.train.metrics import ConfusionMatrixMetric
 from rslearn.train.model_context import (
     ModelContext,
     ModelOutput,
@@ -44,6 +45,7 @@ class ClassificationTask(BasicTask):
         f1_metric_kwargs: dict[str, Any] = {},
         positive_class: str | None = None,
         positive_class_threshold: float = 0.5,
+        enable_confusion_matrix: bool = False,
         **kwargs: Any,
     ):
         """Initialize a new ClassificationTask.
@@ -69,6 +71,8 @@ class ClassificationTask(BasicTask):
             positive_class: positive class name.
             positive_class_threshold: threshold for classifying the positive class in
                 binary classification (default 0.5).
+            enable_confusion_matrix: whether to compute confusion matrix (default false).
+                If true, it requires wandb to be initialized for logging.
             kwargs: other arguments to pass to BasicTask
         """
         super().__init__(**kwargs)
@@ -84,6 +88,7 @@ class ClassificationTask(BasicTask):
         self.f1_metric_kwargs = f1_metric_kwargs
         self.positive_class = positive_class
         self.positive_class_threshold = positive_class_threshold
+        self.enable_confusion_matrix = enable_confusion_matrix
 
         if self.positive_class_threshold != 0.5:
             # Must be binary classification
@@ -201,7 +206,7 @@ class ClassificationTask(BasicTask):
         feature = Feature(
             STGeometry(
                 metadata.projection,
-                shapely.Point(metadata.…
+                shapely.Point(metadata.crop_bounds[0], metadata.crop_bounds[1]),
                 None,
             ),
             {
@@ -278,6 +283,14 @@ class ClassificationTask(BasicTask):
             )
             metrics["f1"] = ClassificationMetric(MulticlassF1Score(**kwargs))
 
+        if self.enable_confusion_matrix:
+            metrics["confusion_matrix"] = ClassificationMetric(
+                ConfusionMatrixMetric(
+                    num_classes=len(self.classes),
+                    class_names=self.classes,
+                ),
+            )
+
         return MetricCollection(metrics)
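
The diff only shows how ConfusionMatrixMetric is constructed (num_classes plus class_names) and that its logging expects wandb to be initialized; the implementation itself lives in the new rslearn/train/metrics.py (+162 lines) and is not reproduced here. Purely as an intuition-builder, a plain torchmetrics confusion matrix keyed by class names could look like this sketch:

import torch
from torchmetrics.classification import MulticlassConfusionMatrix

class_names = ["water", "forest", "urban"]     # hypothetical classes
cm = MulticlassConfusionMatrix(num_classes=len(class_names))

preds = torch.tensor([0, 1, 2, 2, 1])
target = torch.tensor([0, 1, 1, 2, 1])
matrix = cm(preds, target)                     # rows = true class, columns = predicted
for name, row in zip(class_names, matrix):
    print(name, row.tolist())
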
rslearn/train/tasks/detection.py
CHANGED
@@ -128,7 +128,7 @@ class DetectionTask(BasicTask):
         if not load_targets:
             return {}, {}
 
-        bounds = metadata.…
+        bounds = metadata.crop_bounds
 
         boxes = []
         class_labels = []
@@ -244,10 +244,10 @@ class DetectionTask(BasicTask):
         features = []
         for box, class_id, score in zip(boxes, class_ids, scores):
             shp = shapely.box(
-                metadata.…
-                metadata.…
-                metadata.…
-                metadata.…
+                metadata.crop_bounds[0] + float(box[0]),
+                metadata.crop_bounds[1] + float(box[1]),
+                metadata.crop_bounds[0] + float(box[2]),
+                metadata.crop_bounds[1] + float(box[3]),
             )
             geom = STGeometry(metadata.projection, shp, None)
             properties: dict[str, Any] = {
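
A small sketch of the coordinate shift above, with hypothetical numbers: the model emits boxes in crop-relative pixels, and adding the crop's x/y origin (crop_bounds[0] and crop_bounds[1], assuming bounds are (x0, y0, x1, y1)) places them in the window's pixel coordinate system before the shapely geometry is built.

import shapely

crop_bounds = (256, 512, 384, 640)   # hypothetical (x0, y0, x1, y1) of one crop
box = [10.0, 20.0, 40.0, 60.0]       # model output, relative to the crop origin

shp = shapely.box(
    crop_bounds[0] + box[0],
    crop_bounds[1] + box[1],
    crop_bounds[0] + box[2],
    crop_bounds[1] + box[3],
)
print(shp.bounds)  # (266.0, 532.0, 296.0, 572.0)
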
rslearn/train/tasks/per_pixel_regression.py
CHANGED

@@ -149,22 +149,28 @@ class PerPixelRegressionHead(Predictor):
     """Head for per-pixel regression task."""
 
     def __init__(
-        self, …
+        self,
+        loss_mode: Literal["mse", "l1", "huber"] = "mse",
+        use_sigmoid: bool = False,
+        huber_delta: float = 1.0,
     ):
-        """Initialize a new …
+        """Initialize a new PerPixelRegressionHead.
 
         Args:
-            loss_mode: the loss function to use
+            loss_mode: the loss function to use: "mse" (default), "l1", or "huber".
             use_sigmoid: whether to apply a sigmoid activation on the output. This
                 requires targets to be between 0-1.
+            huber_delta: delta parameter for Huber loss (only used when
+                loss_mode="huber").
         """
         super().__init__()
 
-        if loss_mode not in ["mse", "l1"]:
-            raise ValueError("invalid loss mode")
+        if loss_mode not in ["mse", "l1", "huber"]:
+            raise ValueError(f"invalid loss mode {loss_mode}")
 
         self.loss_mode = loss_mode
         self.use_sigmoid = use_sigmoid
+        self.huber_delta = huber_delta
 
     def forward(
         self,
@@ -217,8 +223,15 @@ class PerPixelRegressionHead(Predictor):
             scores = torch.square(outputs - labels)
         elif self.loss_mode == "l1":
             scores = torch.abs(outputs - labels)
+        elif self.loss_mode == "huber":
+            scores = torch.nn.functional.huber_loss(
+                outputs,
+                labels,
+                reduction="none",
+                delta=self.huber_delta,
+            )
         else:
-            …
+            raise ValueError(f"unknown loss mode {self.loss_mode}")
 
         # Compute average but only over valid pixels.
         mask_total = mask.sum()
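
A minimal sketch of the new "huber" branch above: torch.nn.functional.huber_loss with reduction="none" yields a per-pixel score tensor, which is then averaged over valid pixels only, the same masking pattern used by the surrounding head code. Shapes and the delta value here are hypothetical.

import torch
import torch.nn.functional as F

outputs = torch.tensor([[0.2, 1.5], [3.0, 0.0]])
labels = torch.tensor([[0.0, 1.0], [0.0, 0.0]])
mask = torch.tensor([[1.0, 1.0], [0.0, 1.0]])   # one pixel is marked invalid

# Quadratic (0.5 * err^2) below delta, linear (delta * (|err| - 0.5 * delta)) above it.
scores = F.huber_loss(outputs, labels, reduction="none", delta=1.0)
loss = (scores * mask).sum() / mask.sum()
print(scores)
print(loss)  # averages only the three valid pixels
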
rslearn/train/tasks/regression.py
CHANGED

@@ -130,7 +130,7 @@ class RegressionTask(BasicTask):
         feature = Feature(
             STGeometry(
                 metadata.projection,
-                shapely.Point(metadata.…
+                shapely.Point(metadata.crop_bounds[0], metadata.crop_bounds[1]),
                 None,
             ),
             {
@@ -196,18 +196,24 @@ class RegressionHead(Predictor):
     """Head for regression task."""
 
     def __init__(
-        self, …
+        self,
+        loss_mode: Literal["mse", "l1", "huber"] = "mse",
+        use_sigmoid: bool = False,
+        huber_delta: float = 1.0,
     ):
         """Initialize a new RegressionHead.
 
         Args:
-            loss_mode: the loss function to use
+            loss_mode: the loss function to use: "mse" (default), "l1", or "huber".
             use_sigmoid: whether to apply a sigmoid activation on the output. This
                 requires targets to be between 0-1.
+            huber_delta: delta parameter for Huber loss (only used when
+                loss_mode="huber").
         """
         super().__init__()
         self.loss_mode = loss_mode
         self.use_sigmoid = use_sigmoid
+        self.huber_delta = huber_delta
 
     def forward(
         self,
@@ -251,6 +257,16 @@ class RegressionHead(Predictor):
             losses["regress"] = torch.mean(torch.square(outputs - labels) * mask)
         elif self.loss_mode == "l1":
             losses["regress"] = torch.mean(torch.abs(outputs - labels) * mask)
+        elif self.loss_mode == "huber":
+            losses["regress"] = torch.mean(
+                torch.nn.functional.huber_loss(
+                    outputs,
+                    labels,
+                    reduction="none",
+                    delta=self.huber_delta,
+                )
+                * mask
+            )
         else:
             raise ValueError(f"unknown loss mode {self.loss_mode}")
 
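
A quick numeric illustration (hypothetical values) of why the "huber" option can help: past delta the Huber score grows only linearly, so a single outlier residual dominates the loss far less than under MSE.

import torch
import torch.nn.functional as F

outputs = torch.tensor([1.0, 2.0, 10.0])
labels = torch.tensor([1.1, 2.2, 2.0])    # the last residual (8.0) is an outlier

mse = torch.square(outputs - labels)
l1 = torch.abs(outputs - labels)
huber = F.huber_loss(outputs, labels, reduction="none", delta=1.0)

print(mse.tolist())    # approx [0.01, 0.04, 64.0]
print(l1.tolist())     # approx [0.10, 0.20, 8.0]
print(huber.tolist())  # approx [0.005, 0.02, 7.5]
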
rslearn/train/tasks/segmentation.py
CHANGED

@@ -10,6 +10,7 @@ import torchmetrics.classification
 from torchmetrics import Metric, MetricCollection
 
 from rslearn.models.component import FeatureMaps, Predictor
+from rslearn.train.metrics import ConfusionMatrixMetric
 from rslearn.train.model_context import (
     ModelContext,
     ModelOutput,
@@ -43,6 +44,8 @@ class SegmentationTask(BasicTask):
         other_metrics: dict[str, Metric] = {},
         output_probs: bool = False,
         output_class_idx: int | None = None,
+        enable_confusion_matrix: bool = False,
+        class_names: list[str] | None = None,
         **kwargs: Any,
     ) -> None:
         """Initialize a new SegmentationTask.
@@ -80,6 +83,10 @@ class SegmentationTask(BasicTask):
                 during prediction.
             output_class_idx: if set along with output_probs, only output the probability
                 for this specific class index (single-channel output).
+            enable_confusion_matrix: whether to compute confusion matrix (default false).
+                If true, it requires wandb to be initialized for logging.
+            class_names: optional list of class names for labeling confusion matrix axes.
+                If not provided, classes will be labeled as "class_0", "class_1", etc.
             kwargs: additional arguments to pass to BasicTask
         """
         super().__init__(**kwargs)
@@ -106,6 +113,8 @@ class SegmentationTask(BasicTask):
         self.other_metrics = other_metrics
         self.output_probs = output_probs
         self.output_class_idx = output_class_idx
+        self.enable_confusion_matrix = enable_confusion_matrix
+        self.class_names = class_names
 
     def process_inputs(
         self,
@@ -285,6 +294,14 @@ class SegmentationTask(BasicTask):
         if self.other_metrics:
             metrics.update(self.other_metrics)
 
+        if self.enable_confusion_matrix:
+            metrics["confusion_matrix"] = SegmentationMetric(
+                ConfusionMatrixMetric(
+                    num_classes=self.num_classes,
+                    class_names=self.class_names,
+                ),
+            )
+
         return MetricCollection(metrics)
rslearn/utils/__init__.py
CHANGED
@@ -7,6 +7,7 @@ from .geometry import (
     PixelBounds,
     Projection,
     STGeometry,
+    get_global_raster_bounds,
     is_same_resolution,
     shp_intersects,
 )
@@ -23,6 +24,7 @@ __all__ = (
     "Projection",
     "STGeometry",
     "daterange",
+    "get_global_raster_bounds",
     "get_utm_ups_crs",
     "is_same_resolution",
     "logger",
rslearn/utils/fsspec.py
CHANGED
@@ -2,7 +2,7 @@
 
 import os
 import tempfile
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from contextlib import contextmanager
 from typing import Any
 
@@ -16,6 +16,56 @@ from rslearn.log_utils import get_logger
 logger = get_logger(__name__)
 
 
+def iter_nonhidden(path: UPath) -> Iterator[UPath]:
+    """Iterate over non-hidden entries in a directory.
+
+    Hidden entries are those whose basename starts with "." (e.g. ".DS_Store").
+
+    Args:
+        path: the directory to iterate.
+
+    Yields:
+        non-hidden UPath entries in the directory.
+    """
+    try:
+        it = path.iterdir()
+    except (FileNotFoundError, NotADirectoryError):
+        return
+
+    for p in it:
+        if p.name.startswith("."):
+            continue
+        yield p
+
+
+def iter_nonhidden_subdirs(path: UPath) -> Iterator[UPath]:
+    """Iterate over non-hidden subdirectories in a directory.
+
+    Args:
+        path: the directory to iterate.
+
+    Yields:
+        non-hidden subdirectories in the directory.
+    """
+    for p in iter_nonhidden(path):
+        if p.is_dir():
+            yield p
+
+
+def iter_nonhidden_files(path: UPath) -> Iterator[UPath]:
+    """Iterate over non-hidden files in a directory.
+
+    Args:
+        path: the directory to iterate.
+
+    Yields:
+        non-hidden files in the directory.
+    """
+    for p in iter_nonhidden(path):
+        if p.is_file():
+            yield p
+
+
 @contextmanager
 def get_upath_local(
     path: UPath, extra_paths: list[UPath] = []
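
A usage sketch of the new helpers, assuming UPath comes from the universal_pathlib package that rslearn's fsspec utilities are built around: hidden entries such as .DS_Store are skipped, and the subdir/file variants additionally filter by entry type.

from upath import UPath

from rslearn.utils.fsspec import iter_nonhidden_files, iter_nonhidden_subdirs

root = UPath("/tmp/example_dataset")   # hypothetical path; any fsspec-backed URL works
for subdir in iter_nonhidden_subdirs(root):
    print("dir:", subdir.name)
for f in iter_nonhidden_files(root):
    print("file:", f.name)
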
rslearn/utils/geometry.py
CHANGED
@@ -116,6 +116,27 @@ class Projection:
 WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)
 
 
+def get_global_raster_bounds(projection: Projection) -> PixelBounds:
+    """Get very large pixel bounds for a global raster in the given projection.
+
+    This is useful for data sources that cover the entire world and don't want to
+    compute exact bounds in arbitrary projections (which can fail for projections
+    like UTM that only cover part of the world).
+
+    Args:
+        projection: the projection to get bounds in.
+
+    Returns:
+        Pixel bounds that will intersect with any reasonable window. We assume that the
+        absolute value of CRS coordinates is at most 2^32, and adjust it based on the
+        resolution in the Projection in case very fine-grained resolutions are used.
+    """
+    crs_bound = 2**32
+    pixel_bound_x = int(crs_bound / abs(projection.x_resolution))
+    pixel_bound_y = int(crs_bound / abs(projection.y_resolution))
+    return (-pixel_bound_x, -pixel_bound_y, pixel_bound_x, pixel_bound_y)
+
+
 class ResolutionFactor:
     """Multiplier for the resolution in a Projection.
 
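
A small sketch of the bounds arithmetic above, using a hypothetical 10-unit-per-pixel projection: with the assumed CRS coordinate limit of 2^32, the returned pixel bounds are plus or minus 2^32 / 10 in each axis, which comfortably contains any real window.

crs_bound = 2**32
x_resolution = 10.0    # hypothetical: 10 CRS units (e.g. meters) per pixel
y_resolution = -10.0   # y resolution is typically negative for north-up rasters

pixel_bound_x = int(crs_bound / abs(x_resolution))   # 429496729
pixel_bound_y = int(crs_bound / abs(y_resolution))   # 429496729
bounds = (-pixel_bound_x, -pixel_bound_y, pixel_bound_x, pixel_bound_y)
print(bounds)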