rslearn 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/data_sources/aws_open_data.py +11 -15
- rslearn/data_sources/aws_sentinel2_element84.py +374 -0
- rslearn/data_sources/gcp_public_data.py +16 -0
- rslearn/data_sources/planetary_computer.py +78 -257
- rslearn/data_sources/soilgrids.py +331 -0
- rslearn/data_sources/stac.py +275 -0
- rslearn/main.py +4 -1
- rslearn/models/attention_pooling.py +5 -2
- rslearn/train/lightning_module.py +24 -11
- rslearn/train/tasks/embedding.py +2 -2
- rslearn/train/tasks/multi_task.py +8 -5
- rslearn/train/tasks/per_pixel_regression.py +1 -1
- rslearn/train/tasks/segmentation.py +143 -21
- rslearn/train/tasks/task.py +4 -2
- rslearn/utils/geometry.py +2 -2
- rslearn/utils/stac.py +173 -0
- {rslearn-0.0.21.dist-info → rslearn-0.0.23.dist-info}/METADATA +4 -1
- {rslearn-0.0.21.dist-info → rslearn-0.0.23.dist-info}/RECORD +23 -19
- {rslearn-0.0.21.dist-info → rslearn-0.0.23.dist-info}/WHEEL +1 -1
- {rslearn-0.0.21.dist-info → rslearn-0.0.23.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.21.dist-info → rslearn-0.0.23.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.21.dist-info → rslearn-0.0.23.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.21.dist-info → rslearn-0.0.23.dist-info}/top_level.txt +0 -0
rslearn/train/lightning_module.py
CHANGED

@@ -210,11 +210,30 @@ class RslearnLightningModule(L.LightningModule):
             # Fail silently for single-dataset case, which is okay
             pass
 
+    def on_validation_epoch_end(self) -> None:
+        """Compute and log validation metrics at epoch end.
+
+        We manually compute and log metrics here (instead of passing the MetricCollection
+        to log_dict) because MetricCollection.compute() properly flattens dict-returning
+        metrics, while log_dict expects each metric to return a scalar tensor.
+        """
+        metrics = self.val_metrics.compute()
+        self.log_dict(metrics)
+        self.val_metrics.reset()
+
     def on_test_epoch_end(self) -> None:
-        """
+        """Compute and log test metrics at epoch end, optionally save to file.
+
+        We manually compute and log metrics here (instead of passing the MetricCollection
+        to log_dict) because MetricCollection.compute() properly flattens dict-returning
+        metrics, while log_dict expects each metric to return a scalar tensor.
+        """
+        metrics = self.test_metrics.compute()
+        self.log_dict(metrics)
+        self.test_metrics.reset()
+
         if self.metrics_file:
             with open(self.metrics_file, "w") as f:
-                metrics = self.test_metrics.compute()
                 metrics_dict = {k: v.item() for k, v in metrics.items()}
                 json.dump(metrics_dict, f, indent=4)
             logger.info(f"Saved metrics to {self.metrics_file}")

@@ -300,9 +319,6 @@ class RslearnLightningModule(L.LightningModule):
             sync_dist=True,
         )
         self.val_metrics.update(outputs, targets)
-        self.log_dict(
-            self.val_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
-        )
 
     def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Compute the test loss and additional metrics.

@@ -340,19 +356,16 @@ class RslearnLightningModule(L.LightningModule):
             sync_dist=True,
         )
         self.test_metrics.update(outputs, targets)
-        self.log_dict(
-            self.test_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
-        )
 
         if self.visualize_dir:
-            for
-
+            for inp, target, output, metadata in zip(
+                inputs, targets, outputs, metadatas
             ):
                 images = self.task.visualize(inp, target, output)
                 for image_suffix, image in images.items():
                     out_fname = os.path.join(
                         self.visualize_dir,
-                        f"{metadata
+                        f"{metadata.window_name}_{metadata.patch_bounds[0]}_{metadata.patch_bounds[1]}_{image_suffix}.png",
                     )
                     Image.fromarray(image).save(out_fname)
 
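Note: the new epoch-end hooks exist because the metrics below can now return a dict of values rather than a single scalar. A minimal sketch of the pattern, using a hypothetical dict-returning metric (not anything from the package) to show why compute() is called manually before log_dict:

    import torch
    from torchmetrics import Metric, MetricCollection

    class PerClassScore(Metric):  # hypothetical stand-in for a dict-returning metric
        def __init__(self) -> None:
            super().__init__()
            self.add_state("total", default=torch.zeros(2), dist_reduce_fx="sum")
            self.add_state("count", default=torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, scores: torch.Tensor) -> None:
            self.total += scores
            self.count += 1

        def compute(self) -> dict:
            per_class = self.total / self.count
            return {"avg": per_class.mean(), "cls_0": per_class[0], "cls_1": per_class[1]}

    val_metrics = MetricCollection({"iou": PerClassScore()})
    val_metrics.update(torch.tensor([0.5, 0.9]))
    # compute() yields a flat dict of scalar tensors that log_dict can accept;
    # passing val_metrics itself to log_dict would assume one scalar per metric.
    print(val_metrics.compute())
    val_metrics.reset()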
rslearn/train/tasks/embedding.py
CHANGED

@@ -6,7 +6,7 @@ import numpy.typing as npt
 import torch
 from torchmetrics import MetricCollection
 
-from rslearn.models.component import FeatureMaps
+from rslearn.models.component import FeatureMaps, Predictor
 from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
 from rslearn.utils import Feature
 
@@ -83,7 +83,7 @@ class EmbeddingTask(Task):
         return MetricCollection({})
 
 
-class EmbeddingHead:
+class EmbeddingHead(Predictor):
     """Head for embedding task.
 
     It just adds a dummy loss to act as a Predictor.
rslearn/train/tasks/multi_task.py
CHANGED

@@ -118,13 +118,16 @@ class MultiTask(Task):
 
     def get_metrics(self) -> MetricCollection:
         """Get metrics for this task."""
-        metrics
+        # Flatten metrics into a single dict with task_name/ prefix to avoid nested
+        # MetricCollections. Nested collections cause issues because MetricCollection
+        # has postfix=None which breaks MetricCollection.compute().
+        all_metrics = {}
         for task_name, task in self.tasks.items():
-            cur_metrics = {}
             for metric_name, metric in task.get_metrics().items():
-
-
-
+                all_metrics[f"{task_name}/{metric_name}"] = MetricWrapper(
+                    task_name, metric
+                )
+        return MetricCollection(all_metrics)
 
 
 class MetricWrapper(Metric):
rslearn/train/tasks/per_pixel_regression.py
CHANGED

@@ -100,7 +100,7 @@ class PerPixelRegressionTask(BasicTask):
             raise ValueError(
                 f"PerPixelRegressionTask output must be an HW tensor, but got shape {raw_output.shape}"
             )
-        return (raw_output / self.scale_factor).cpu().numpy()
+        return (raw_output[None, :, :] / self.scale_factor).cpu().numpy()
 
     def visualize(
         self,
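Note: the change above means the returned array gains a leading band dimension. A quick illustration, assuming raw_output is an HxW tensor and scale_factor is a float (values here are made up):

    import torch

    raw_output = torch.rand(64, 64)   # HW model output
    scale_factor = 0.01               # illustrative value
    arr = (raw_output[None, :, :] / scale_factor).cpu().numpy()
    assert arr.shape == (1, 64, 64)   # 1xHxW instead of HxW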
rslearn/train/tasks/segmentation.py
CHANGED

@@ -53,6 +53,7 @@ class SegmentationTask(BasicTask):
         enable_accuracy_metric: bool = True,
         enable_miou_metric: bool = False,
         enable_f1_metric: bool = False,
+        report_metric_per_class: bool = False,
         f1_metric_thresholds: list[list[float]] = [[0.5]],
         metric_kwargs: dict[str, Any] = {},
         miou_metric_kwargs: dict[str, Any] = {},

@@ -74,6 +75,8 @@ class SegmentationTask(BasicTask):
             enable_accuracy_metric: whether to enable the accuracy metric (default
                 true).
             enable_f1_metric: whether to enable the F1 metric (default false).
+            report_metric_per_class: whether to report chosen metrics for each class, in
+                addition to the average score across classes.
             enable_miou_metric: whether to enable the mean IoU metric (default false).
             f1_metric_thresholds: list of list of thresholds to apply for F1 metric.
                 Each inner list is used to initialize a separate F1 metric where the

@@ -107,6 +110,7 @@ class SegmentationTask(BasicTask):
         self.enable_accuracy_metric = enable_accuracy_metric
         self.enable_f1_metric = enable_f1_metric
         self.enable_miou_metric = enable_miou_metric
+        self.report_metric_per_class = report_metric_per_class
         self.f1_metric_thresholds = f1_metric_thresholds
         self.metric_kwargs = metric_kwargs
         self.miou_metric_kwargs = miou_metric_kwargs

@@ -237,29 +241,41 @@ class SegmentationTask(BasicTask):
            # Metric name can't contain "." so change to ",".
            suffix = "_" + str(thresholds[0]).replace(".", ",")
 
+            # Create one metric per type - it returns a dict with "avg" and optionally per-class keys
             metrics["F1" + suffix] = SegmentationMetric(
-                F1Metric(
+                F1Metric(
+                    num_classes=self.num_classes,
+                    score_thresholds=thresholds,
+                    report_per_class=self.report_metric_per_class,
+                ),
             )
             metrics["precision" + suffix] = SegmentationMetric(
                 F1Metric(
                     num_classes=self.num_classes,
                     score_thresholds=thresholds,
                     metric_mode="precision",
-
+                    report_per_class=self.report_metric_per_class,
+                ),
             )
             metrics["recall" + suffix] = SegmentationMetric(
                 F1Metric(
                     num_classes=self.num_classes,
                     score_thresholds=thresholds,
                     metric_mode="recall",
-
+                    report_per_class=self.report_metric_per_class,
+                ),
             )
 
         if self.enable_miou_metric:
-            miou_metric_kwargs: dict[str, Any] = dict(
+            miou_metric_kwargs: dict[str, Any] = dict(
+                num_classes=self.num_classes,
+                report_per_class=self.report_metric_per_class,
+            )
             if self.nodata_value is not None:
                 miou_metric_kwargs["nodata_value"] = self.nodata_value
             miou_metric_kwargs.update(self.miou_metric_kwargs)
+
+            # Create one metric - it returns a dict with "avg" and optionally per-class keys
             metrics["mean_iou"] = SegmentationMetric(
                 MeanIoUMetric(**miou_metric_kwargs),
                 pass_probabilities=False,
@@ -274,6 +290,20 @@ class SegmentationTask(BasicTask):
 class SegmentationHead(Predictor):
     """Head for segmentation task."""
 
+    def __init__(self, weights: list[float] | None = None, dice_loss: bool = False):
+        """Initialize a new SegmentationTask.
+
+        Args:
+            weights: weights for cross entropy loss (Tensor of size C)
+            dice_loss: weather to add dice loss to cross entropy
+        """
+        super().__init__()
+        if weights is not None:
+            self.register_buffer("weights", torch.Tensor(weights))
+        else:
+            self.weights = None
+        self.dice_loss = dice_loss
+
     def forward(
         self,
         intermediates: Any,
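A minimal usage sketch of the new head options (example values, not from the package): class weights for the cross-entropy term plus the optional dice loss.

    from rslearn.train.tasks.segmentation import SegmentationHead

    # Up-weight two rare classes and add the dice term on top of cross entropy.
    head = SegmentationHead(weights=[1.0, 5.0, 5.0], dice_loss=True)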
@@ -308,7 +338,7 @@ class SegmentationHead(Predictor):
         labels = torch.stack([target["classes"] for target in targets], dim=0)
         mask = torch.stack([target["valid"] for target in targets], dim=0)
         per_pixel_loss = torch.nn.functional.cross_entropy(
-            logits, labels, reduction="none"
+            logits, labels, weight=self.weights, reduction="none"
         )
         mask_sum = torch.sum(mask)
         if mask_sum > 0:

@@ -318,6 +348,9 @@ class SegmentationHead(Predictor):
             # If there are no valid pixels, we avoid dividing by zero and just let
             # the summed mask loss be zero.
             losses["cls"] = torch.sum(per_pixel_loss * mask)
+        if self.dice_loss:
+            dice_loss = DiceLoss()(outputs, labels, mask)
+            losses["dice"] = dice_loss
 
         return ModelOutput(
             outputs=outputs,

@@ -333,6 +366,7 @@ class SegmentationMetric(Metric):
         metric: Metric,
         pass_probabilities: bool = True,
         class_idx: int | None = None,
+        output_key: str | None = None,
     ):
         """Initialize a new SegmentationMetric.
 
@@ -341,12 +375,19 @@ class SegmentationMetric(Metric):
                 classes from the targets and masking out invalid pixels.
             pass_probabilities: whether to pass predicted probabilities to the metric.
                 If False, argmax is applied to pass the predicted classes instead.
-            class_idx: if
+            class_idx: if set, return only this class index's value. For backward
+                compatibility with configs using standard torchmetrics. Internally
+                converted to output_key="cls_{class_idx}".
+            output_key: if the wrapped metric returns a dict (or a tensor that gets
+                converted to a dict), return only this key's value. For standard
+                torchmetrics with average=None, tensors are converted to dicts with
+                keys "cls_0", "cls_1", etc. If None, the full dict is returned.
         """
         super().__init__()
         self.metric = metric
         self.pass_probablities = pass_probabilities
         self.class_idx = class_idx
+        self.output_key = output_key
 
     def update(
         self, preds: list[Any] | torch.Tensor, targets: list[dict[str, Any]]

@@ -376,10 +417,32 @@ class SegmentationMetric(Metric):
         self.metric.update(preds, labels)
 
     def compute(self) -> Any:
-        """Returns the computed metric.
+        """Returns the computed metric.
+
+        If the wrapped metric returns a multi-element tensor (e.g., standard torchmetrics
+        with average=None), it is converted to a dict with keys like "cls_0", "cls_1", etc.
+        This allows uniform handling via output_key for both standard torchmetrics and
+        custom dict-returning metrics.
+        """
         result = self.metric.compute()
+
+        # Convert multi-element tensors to dict for uniform handling.
+        # This supports standard torchmetrics with average=None which return per-class tensors.
+        if isinstance(result, torch.Tensor) and result.ndim >= 1:
+            result = {f"cls_{i}": result[i] for i in range(len(result))}
+
+        if self.output_key is not None:
+            if not isinstance(result, dict):
+                raise TypeError(
+                    f"output_key is set to '{self.output_key}' but metric returned "
+                    f"{type(result).__name__} instead of dict"
+                )
+            return result[self.output_key]
         if self.class_idx is not None:
-
+            # For backward compatibility: class_idx can index into the converted dict
+            if isinstance(result, dict):
+                return result[f"cls_{self.class_idx}"]
+            return result[self.class_idx]
         return result
 
     def reset(self) -> None:
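A hedged sketch of the new output_key option, wrapping a standard torchmetric that returns per-class values (the metric and class count here are illustrative, not from the package):

    from torchmetrics.classification import MulticlassAccuracy
    from rslearn.train.tasks.segmentation import SegmentationMetric

    # With average=None the wrapped metric yields one value per class; compute()
    # converts that tensor to {"cls_0": ..., "cls_1": ..., "cls_2": ...} and
    # output_key picks out a single scalar.
    per_class_acc = SegmentationMetric(
        MulticlassAccuracy(num_classes=3, average=None),
        output_key="cls_1",
    )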
@@ -404,6 +467,7 @@ class F1Metric(Metric):
         num_classes: int,
         score_thresholds: list[float],
         metric_mode: str = "f1",
+        report_per_class: bool = False,
     ):
         """Create a new F1Metric.
 
@@ -413,11 +477,14 @@ class F1Metric(Metric):
                 metric is the best F1 across score thresholds.
             metric_mode: set to "precision" or "recall" to return that instead of F1
                 (default "f1")
+            report_per_class: whether to include per-class scores in the output dict.
+                If False, only returns the "avg" key.
         """
         super().__init__()
         self.num_classes = num_classes
         self.score_thresholds = score_thresholds
         self.metric_mode = metric_mode
+        self.report_per_class = report_per_class
 
         assert self.metric_mode in ["f1", "precision", "recall"]
 
@@ -462,9 +529,10 @@ class F1Metric(Metric):
         """Compute metric.
 
         Returns:
-
+            dict with "avg" key containing mean score across classes.
+            If report_per_class is True, also includes "cls_N" keys for each class N.
         """
-
+        cls_best_scores = {}
 
         for cls_idx in range(self.num_classes):
             best_score = None

@@ -501,9 +569,12 @@ class F1Metric(Metric):
                 if best_score is None or score > best_score:
                     best_score = score
 
-
+            cls_best_scores[f"cls_{cls_idx}"] = best_score
 
-
+        report_scores = {"avg": torch.mean(torch.stack(list(cls_best_scores.values())))}
+        if self.report_per_class:
+            report_scores.update(cls_best_scores)
+        return report_scores
 
 
 class MeanIoUMetric(Metric):

@@ -523,7 +594,7 @@ class MeanIoUMetric(Metric):
         num_classes: int,
         nodata_value: int | None = None,
         ignore_missing_classes: bool = False,
-
+        report_per_class: bool = False,
     ):
         """Create a new MeanIoUMetric.
 
@@ -535,15 +606,14 @@ class MeanIoUMetric(Metric):
             ignore_missing_classes: whether to ignore classes that don't appear in
                 either the predictions or the ground truth. If false, the IoU for a
                 missing class will be 0.
-
-
-                only supports scalar return values from metrics.
+            report_per_class: whether to include per-class IoU scores in the output dict.
+                If False, only returns the "avg" key.
         """
         super().__init__()
         self.num_classes = num_classes
         self.nodata_value = nodata_value
         self.ignore_missing_classes = ignore_missing_classes
-        self.
+        self.report_per_class = report_per_class
 
         self.add_state(
             "intersections", default=torch.zeros(self.num_classes), dist_reduce_fx="sum"

@@ -584,9 +654,11 @@ class MeanIoUMetric(Metric):
         """Compute metric.
 
         Returns:
-            the mean IoU across classes.
+            dict with "avg" containing the mean IoU across classes.
+            If report_per_class is True, also includes "cls_N" keys for each valid class N.
         """
-
+        cls_scores = {}
+        valid_scores = []
 
         for cls_idx in range(self.num_classes):
             # Check if nodata_value is set and is one of the classes

@@ -599,6 +671,56 @@ class MeanIoUMetric(Metric):
             if union == 0 and self.ignore_missing_classes:
                 continue
 
-
+            score = intersection / union
+            cls_scores[f"cls_{cls_idx}"] = score
+            valid_scores.append(score)
+
+        report_scores = {"avg": torch.mean(torch.stack(valid_scores))}
+        if self.report_per_class:
+            report_scores.update(cls_scores)
+        return report_scores
+
+
+class DiceLoss(torch.nn.Module):
+    """Mean Dice Loss for segmentation.
+
+    This is the mean of the per-class dice loss (1 - 2*intersection / union scores).
+    The per-class intersection is the number of pixels across all examples where
+    the predicted label and ground truth label are both that class, and the per-class
+    union is defined similarly.
+    """
+
+    def __init__(self, smooth: float = 1e-7):
+        """Initialize a new DiceLoss."""
+        super().__init__()
+        self.smooth = smooth
+
+    def forward(
+        self, inputs: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Compute Dice Loss.
+
+        Returns:
+            the mean Dicen Loss across classes
+        """
+        num_classes = inputs.shape[1]
+        targets_one_hot = (
+            torch.nn.functional.one_hot(targets, num_classes)
+            .permute(0, 3, 1, 2)
+            .float()
+        )
+
+        # Expand mask to [B, C, H, W]
+        mask = mask.unsqueeze(1).expand_as(inputs)
+
+        dice_per_class = []
+        for c in range(num_classes):
+            pred_c = inputs[:, c] * mask[:, c]
+            target_c = targets_one_hot[:, c] * mask[:, c]
+
+            intersection = (pred_c * target_c).sum()
+            union = pred_c.sum() + target_c.sum()
+            dice_c = (2.0 * intersection + self.smooth) / (union + self.smooth)
+            dice_per_class.append(dice_c)
 
-        return torch.
+        return 1 - torch.stack(dice_per_class).mean()
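A quick usage sketch of the new DiceLoss (shapes and values are illustrative): probabilities are BxCxHxW, while the labels and the valid-pixel mask are BxHxW.

    import torch
    from rslearn.train.tasks.segmentation import DiceLoss

    probs = torch.softmax(torch.randn(2, 3, 32, 32), dim=1)  # BxCxHxW probabilities
    labels = torch.randint(0, 3, (2, 32, 32))                 # BxHxW class indices
    valid = torch.ones(2, 32, 32)                             # BxHxW valid-pixel mask

    loss = DiceLoss()(probs, labels, valid)  # scalar, lower is better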
rslearn/train/tasks/task.py
CHANGED

@@ -108,8 +108,10 @@ class BasicTask(Task):
         Returns:
             a dictionary mapping image name to visualization image
         """
-
-
+        raster_image = input_dict["image"]
+        assert isinstance(raster_image, RasterImage)
+        # We don't really handle time series here, just use the first timestep.
+        image = raster_image.image.cpu()[self.image_bands, 0, :, :]
         if self.remap_values:
             factor = (self.remap_values[1][1] - self.remap_values[1][0]) / (
                 self.remap_values[0][1] - self.remap_values[0][0]
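The visualization now reads from a RasterImage; a rough illustration of the indexing, assuming the image tensor is laid out as bands x timesteps x H x W (made-up sizes and band selection):

    import torch

    image = torch.rand(4, 3, 64, 64)       # 4 bands, 3 timesteps
    image_bands = [2, 1, 0]                # assumed RGB band selection
    rgb = image[image_bands, 0, :, :]      # first timestep only, CxHxW
    assert rgb.shape == (3, 64, 64)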
rslearn/utils/geometry.py
CHANGED

@@ -153,8 +153,8 @@ class ResolutionFactor:
         else:
             return Projection(
                 projection.crs,
-                projection.x_resolution
-                projection.y_resolution
+                projection.x_resolution / self.numerator,
+                projection.y_resolution / self.numerator,
             )
 
     def multiply_bounds(self, bounds: PixelBounds) -> PixelBounds:
rslearn/utils/stac.py
ADDED

@@ -0,0 +1,173 @@
+"""STAC API client."""
+
+import logging
+from dataclasses import dataclass
+from datetime import datetime
+from typing import Any
+
+import requests
+
+logger = logging.getLogger(__name__)
+
+Bbox = tuple[float, float, float, float]
+
+
+@dataclass(frozen=True)
+class StacAsset:
+    """A STAC asset."""
+
+    href: str
+    title: str | None
+    type: str | None
+    roles: list[str] | None
+
+
+@dataclass(frozen=True)
+class StacItem:
+    """A STAC item."""
+
+    id: str
+    properties: dict[str, Any]
+    collection: str | None
+    bbox: Bbox | None
+    geometry: dict[str, Any] | None
+    assets: dict[str, StacAsset] | None
+    time_range: tuple[datetime, datetime] | None
+
+    @classmethod
+    def from_dict(cls, item: dict[str, Any]) -> "StacItem":
+        """Create a STAC item from the item dict returned from API."""
+        properties = item.get("properties", {})
+
+        # Parse bbox.
+        bbox: Bbox | None = None
+        if "bbox" in item:
+            if len(item["bbox"]) != 4:
+                raise NotImplementedError(
+                    f"got bbox with {len(item['bbox'])} coordinates but only 4 coordinates is implemented"
+                )
+            bbox = tuple(item["bbox"])
+
+        # Parse assets.
+        assets: dict[str, StacAsset] = {}
+        for name, asset in item.get("assets", {}).items():
+            assets[name] = StacAsset(
+                href=asset["href"],
+                title=asset.get("title"),
+                type=asset.get("type"),
+                roles=asset.get("roles"),
+            )
+
+        # Parse time range.
+        time_range: tuple[datetime, datetime] | None = None
+        if "start_datetime" in properties and "end_datetime" in properties:
+            time_range = (
+                datetime.fromisoformat(properties["start_datetime"]),
+                datetime.fromisoformat(properties["end_datetime"]),
+            )
+        elif "datetime" in properties:
+            ts = datetime.fromisoformat(properties["datetime"])
+            time_range = (ts, ts)
+
+        return cls(
+            id=item["id"],
+            properties=properties,
+            collection=item.get("collection"),
+            bbox=bbox,
+            geometry=item.get("geometry"),
+            assets=assets,
+            time_range=time_range,
+        )
+
+
+class StacClient:
+    """Limited functionality client for STAC APIs."""
+
+    def __init__(self, endpoint: str):
+        """Create a new StacClient.
+
+        Args:
+            endpoint: the STAC endpoint (base URL)
+        """
+        self.endpoint = endpoint
+        self.session = requests.Session()
+
+    def search(
+        self,
+        collections: list[str] | None = None,
+        bbox: Bbox | None = None,
+        intersects: dict[str, Any] | None = None,
+        date_time: datetime | tuple[datetime, datetime] | None = None,
+        ids: list[str] | None = None,
+        limit: int | None = None,
+        query: dict[str, Any] | None = None,
+    ) -> list[StacItem]:
+        """Execute a STAC item search.
+
+        We use the JSON POST API. Pagination is handled so the returned items are
+        concatenated across all available pages.
+
+        Args:
+            collections: only search within the provided collection(s).
+            bbox: only return features intersecting the provided bounding box.
+            intersects: only return features intersecting this GeoJSON geometry.
+            date_time: only return features that have a temporal property intersecting
+                the provided time range or timestamp.
+            ids: only return the provided item IDs.
+            limit: number of items per page. We will read all the pages.
+            query: query dict, if STAC query extension is supported by this API. See
+                https://github.com/stac-api-extensions/query.
+
+        Returns:
+            list of matching STAC items.
+        """
+        # Build JSON request data.
+        request_data: dict[str, Any] = {}
+        if collections is not None:
+            request_data["collections"] = collections
+        if bbox is not None:
+            request_data["bbox"] = bbox
+        if intersects is not None:
+            request_data["intersects"] = intersects
+        if date_time is not None:
+            if isinstance(date_time, tuple):
+                start_time = date_time[0].isoformat().replace("+00:00", "Z")
+                end_time = date_time[1].isoformat().replace("+00:00", "Z")
+                request_data["datetime"] = f"{start_time}/{end_time}"
+            else:
+                request_data["datetime"] = date_time.isoformat().replace("+00:00", "Z")
+        if ids is not None:
+            request_data["ids"] = ids
+        if limit is not None:
+            request_data["limit"] = limit
+        if query is not None:
+            request_data["query"] = query
+
+        # Handle pagination.
+        cur_url = self.endpoint + "/search"
+        items: list[StacItem] = []
+        while True:
+            logger.debug("Reading STAC items from %s", cur_url)
+            response = self.session.post(url=cur_url, json=request_data)
+            response.raise_for_status()
+            data = response.json()
+            for item_dict in data["features"]:
+                items.append(StacItem.from_dict(item_dict))
+
+            next_link = None
+            next_request_data: dict[str, Any] = {}
+            for link in data.get("links", []):
+                if "rel" not in link or link["rel"] != "next":
+                    continue
+                assert link["method"] == "POST"
+                next_link = link["href"]
+                next_request_data = link["body"]
+                break
+
+            if next_link is None:
+                break
+
+            cur_url = next_link
+            request_data = next_request_data
+
+        return items
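A short usage sketch (the endpoint and collection name below are examples, not something the package mandates); search() handles pagination internally and returns all matching items:

    from datetime import datetime, timezone
    from rslearn.utils.stac import StacClient

    client = StacClient("https://earth-search.aws.element84.com/v1")
    items = client.search(
        collections=["sentinel-2-l2a"],
        bbox=(-122.5, 47.5, -122.2, 47.7),
        date_time=(
            datetime(2024, 6, 1, tzinfo=timezone.utc),
            datetime(2024, 7, 1, tzinfo=timezone.utc),
        ),
        limit=100,
    )
    for item in items:
        print(item.id, item.time_range, list(item.assets or {}))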
{rslearn-0.0.21.dist-info → rslearn-0.0.23.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: rslearn
-Version: 0.0.21
+Version: 0.0.23
 Summary: A library for developing remote sensing datasets and models
 Author: OlmoEarth Team
 License: Apache License

@@ -218,11 +218,13 @@ Requires-Dist: fsspec>=2025.10.0
 Requires-Dist: jsonargparse>=4.35.0
 Requires-Dist: lightning>=2.5.1.post0
 Requires-Dist: Pillow>=11.3
+Requires-Dist: pydantic>=2
 Requires-Dist: pyproj>=3.7
 Requires-Dist: python-dateutil>=2.9
 Requires-Dist: pytimeparse>=1.1
 Requires-Dist: rasterio>=1.4
 Requires-Dist: shapely>=2.1
+Requires-Dist: soilgrids>=0.1.4
 Requires-Dist: torch>=2.7.0
 Requires-Dist: torchvision>=0.22.0
 Requires-Dist: tqdm>=4.67

@@ -317,6 +319,7 @@ for how to setup these data sources.
 - Xyz (Slippy) Tiles (e.g., Mapbox tiles)
 - Planet Labs (PlanetScope, SkySat)
 - ESA WorldCover 2021
+- ISRIC SoilGrids (WCS)
 
 rslearn can also be used to easily mosaic, crop, and re-project any sets of local
 raster and vector files you may have.