rslearn 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -210,11 +210,30 @@ class RslearnLightningModule(L.LightningModule):
  # Fail silently for single-dataset case, which is okay
  pass

+ def on_validation_epoch_end(self) -> None:
+ """Compute and log validation metrics at epoch end.
+
+ We manually compute and log metrics here (instead of passing the MetricCollection
+ to log_dict) because MetricCollection.compute() properly flattens dict-returning
+ metrics, while log_dict expects each metric to return a scalar tensor.
+ """
+ metrics = self.val_metrics.compute()
+ self.log_dict(metrics)
+ self.val_metrics.reset()
+
  def on_test_epoch_end(self) -> None:
- """Optionally save the test metrics to a file."""
+ """Compute and log test metrics at epoch end, optionally save to file.
+
+ We manually compute and log metrics here (instead of passing the MetricCollection
+ to log_dict) because MetricCollection.compute() properly flattens dict-returning
+ metrics, while log_dict expects each metric to return a scalar tensor.
+ """
+ metrics = self.test_metrics.compute()
+ self.log_dict(metrics)
+ self.test_metrics.reset()
+
  if self.metrics_file:
  with open(self.metrics_file, "w") as f:
- metrics = self.test_metrics.compute()
  metrics_dict = {k: v.item() for k, v in metrics.items()}
  json.dump(metrics_dict, f, indent=4)
  logger.info(f"Saved metrics to {self.metrics_file}")
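The compute/log_dict/reset pattern above exists because some metrics introduced later in this diff return dicts rather than scalars. A minimal sketch of that interaction, using a toy metric rather than anything from rslearn:

```python
# Illustrative sketch only: a metric that returns a dict (like MeanIoUMetric /
# F1Metric further down in this diff) cannot be handed to log_dict() via the
# MetricCollection itself, but MetricCollection.compute() flattens the dict into
# scalar tensors that log_dict() can record.
import torch
from torchmetrics import Metric, MetricCollection


class ToyDictMetric(Metric):
    """Toy metric returning a dict of scalars."""

    def __init__(self) -> None:
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, value: torch.Tensor) -> None:
        self.total = self.total + value

    def compute(self) -> dict[str, torch.Tensor]:
        return {"avg": self.total, "cls_0": self.total / 2}


collection = MetricCollection({"seg/mean_iou": ToyDictMetric()})
collection.update(torch.tensor(4.0))
# compute() returns flattened scalar entries (exact key naming depends on the
# torchmetrics version), which is what on_validation_epoch_end() logs.
print(collection.compute())
```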
@@ -300,9 +319,6 @@ class RslearnLightningModule(L.LightningModule):
  sync_dist=True,
  )
  self.val_metrics.update(outputs, targets)
- self.log_dict(
- self.val_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
- )

  def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
  """Compute the test loss and additional metrics.
@@ -340,19 +356,16 @@ class RslearnLightningModule(L.LightningModule):
  sync_dist=True,
  )
  self.test_metrics.update(outputs, targets)
- self.log_dict(
- self.test_metrics, batch_size=batch_size, on_epoch=True, sync_dist=True
- )

  if self.visualize_dir:
- for idx, (inp, target, output, metadata) in enumerate(
- zip(inputs, targets, outputs, metadatas)
+ for inp, target, output, metadata in zip(
+ inputs, targets, outputs, metadatas
  ):
  images = self.task.visualize(inp, target, output)
  for image_suffix, image in images.items():
  out_fname = os.path.join(
  self.visualize_dir,
- f"{metadata['window_name']}_{metadata['bounds'][0]}_{metadata['bounds'][1]}_{image_suffix}.png",
+ f"{metadata.window_name}_{metadata.patch_bounds[0]}_{metadata.patch_bounds[1]}_{image_suffix}.png",
  )
  Image.fromarray(image).save(out_fname)

@@ -6,7 +6,7 @@ import numpy.typing as npt
  import torch
  from torchmetrics import MetricCollection

- from rslearn.models.component import FeatureMaps
+ from rslearn.models.component import FeatureMaps, Predictor
  from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
  from rslearn.utils import Feature

@@ -83,7 +83,7 @@ class EmbeddingTask(Task):
  return MetricCollection({})


- class EmbeddingHead:
+ class EmbeddingHead(Predictor):
  """Head for embedding task.

  It just adds a dummy loss to act as a Predictor.
@@ -118,13 +118,16 @@ class MultiTask(Task):

  def get_metrics(self) -> MetricCollection:
  """Get metrics for this task."""
- metrics = []
+ # Flatten metrics into a single dict with task_name/ prefix to avoid nested
+ # MetricCollections. Nested collections cause issues because MetricCollection
+ # has postfix=None which breaks MetricCollection.compute().
+ all_metrics = {}
  for task_name, task in self.tasks.items():
- cur_metrics = {}
  for metric_name, metric in task.get_metrics().items():
- cur_metrics[metric_name] = MetricWrapper(task_name, metric)
- metrics.append(MetricCollection(cur_metrics, prefix=f"{task_name}/"))
- return MetricCollection(metrics)
+ all_metrics[f"{task_name}/{metric_name}"] = MetricWrapper(
+ task_name, metric
+ )
+ return MetricCollection(all_metrics)


  class MetricWrapper(Metric):
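For context on the flattened layout produced above, here is a sketch (task and metric names are made up, and MetricWrapper is omitted) showing that every per-task metric now lives in one flat MetricCollection keyed by `task_name/metric_name`, so compute() yields a single level of keys:

```python
# Sketch with made-up task names; the point is only the flat
# "{task_name}/{metric_name}" key layout used by MultiTask.get_metrics() above.
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy

tasks = {
    "landcover": {"accuracy": MulticlassAccuracy(num_classes=4)},
    "crops": {"accuracy": MulticlassAccuracy(num_classes=7)},
}
all_metrics = {
    f"{task_name}/{metric_name}": metric
    for task_name, task_metrics in tasks.items()
    for metric_name, metric in task_metrics.items()
}
collection = MetricCollection(all_metrics)
print(sorted(collection.keys()))  # ['crops/accuracy', 'landcover/accuracy']
```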
@@ -100,7 +100,7 @@ class PerPixelRegressionTask(BasicTask):
  raise ValueError(
  f"PerPixelRegressionTask output must be an HW tensor, but got shape {raw_output.shape}"
  )
- return (raw_output / self.scale_factor).cpu().numpy()
+ return (raw_output[None, :, :] / self.scale_factor).cpu().numpy()

  def visualize(
  self,
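A quick sketch of the shape change in this hunk (values are illustrative): the regression output now carries an explicit leading channel dimension.

```python
# Illustrative only: the HW prediction becomes a 1xHxW array after the change.
import torch

raw_output = torch.rand(64, 64)   # HW tensor, as validated above
scale_factor = 100.0
old_shape = (raw_output / scale_factor).cpu().numpy().shape              # (64, 64)
new_shape = (raw_output[None, :, :] / scale_factor).cpu().numpy().shape  # (1, 64, 64)
print(old_shape, new_shape)
```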
@@ -53,6 +53,7 @@ class SegmentationTask(BasicTask):
  enable_accuracy_metric: bool = True,
  enable_miou_metric: bool = False,
  enable_f1_metric: bool = False,
+ report_metric_per_class: bool = False,
  f1_metric_thresholds: list[list[float]] = [[0.5]],
  metric_kwargs: dict[str, Any] = {},
  miou_metric_kwargs: dict[str, Any] = {},
@@ -74,6 +75,8 @@ class SegmentationTask(BasicTask):
  enable_accuracy_metric: whether to enable the accuracy metric (default
  true).
  enable_f1_metric: whether to enable the F1 metric (default false).
+ report_metric_per_class: whether to report chosen metrics for each class, in
+ addition to the average score across classes.
  enable_miou_metric: whether to enable the mean IoU metric (default false).
  f1_metric_thresholds: list of list of thresholds to apply for F1 metric.
  Each inner list is used to initialize a separate F1 metric where the
@@ -107,6 +110,7 @@ class SegmentationTask(BasicTask):
  self.enable_accuracy_metric = enable_accuracy_metric
  self.enable_f1_metric = enable_f1_metric
  self.enable_miou_metric = enable_miou_metric
+ self.report_metric_per_class = report_metric_per_class
  self.f1_metric_thresholds = f1_metric_thresholds
  self.metric_kwargs = metric_kwargs
  self.miou_metric_kwargs = miou_metric_kwargs
@@ -237,29 +241,41 @@ class SegmentationTask(BasicTask):
  # Metric name can't contain "." so change to ",".
  suffix = "_" + str(thresholds[0]).replace(".", ",")

+ # Create one metric per type - it returns a dict with "avg" and optionally per-class keys
  metrics["F1" + suffix] = SegmentationMetric(
- F1Metric(num_classes=self.num_classes, score_thresholds=thresholds)
+ F1Metric(
+ num_classes=self.num_classes,
+ score_thresholds=thresholds,
+ report_per_class=self.report_metric_per_class,
+ ),
  )
  metrics["precision" + suffix] = SegmentationMetric(
  F1Metric(
  num_classes=self.num_classes,
  score_thresholds=thresholds,
  metric_mode="precision",
- )
+ report_per_class=self.report_metric_per_class,
+ ),
  )
  metrics["recall" + suffix] = SegmentationMetric(
  F1Metric(
  num_classes=self.num_classes,
  score_thresholds=thresholds,
  metric_mode="recall",
- )
+ report_per_class=self.report_metric_per_class,
+ ),
  )

  if self.enable_miou_metric:
- miou_metric_kwargs: dict[str, Any] = dict(num_classes=self.num_classes)
+ miou_metric_kwargs: dict[str, Any] = dict(
+ num_classes=self.num_classes,
+ report_per_class=self.report_metric_per_class,
+ )
  if self.nodata_value is not None:
  miou_metric_kwargs["nodata_value"] = self.nodata_value
  miou_metric_kwargs.update(self.miou_metric_kwargs)
+
+ # Create one metric - it returns a dict with "avg" and optionally per-class keys
  metrics["mean_iou"] = SegmentationMetric(
  MeanIoUMetric(**miou_metric_kwargs),
  pass_probabilities=False,
@@ -274,6 +290,20 @@ class SegmentationTask(BasicTask):
  class SegmentationHead(Predictor):
  """Head for segmentation task."""

+ def __init__(self, weights: list[float] | None = None, dice_loss: bool = False):
+ """Initialize a new SegmentationHead.
+
+ Args:
+ weights: weights for cross entropy loss (Tensor of size C)
+ dice_loss: whether to add dice loss to cross entropy
+ """
+ super().__init__()
+ if weights is not None:
+ self.register_buffer("weights", torch.Tensor(weights))
+ else:
+ self.weights = None
+ self.dice_loss = dice_loss
+
  def forward(
  self,
  intermediates: Any,
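A hedged usage sketch for the new SegmentationHead constructor options; the import path is assumed from the file being diffed and is not stated in this hunk:

```python
# Assumed import path; the weights/dice_loss values are examples only.
from rslearn.train.tasks.segmentation import SegmentationHead

# Upweight rare classes in the cross-entropy term and add the Dice term.
head = SegmentationHead(weights=[1.0, 5.0, 2.0], dice_loss=True)

# Because the class weights are stored with register_buffer, they move with the
# module, e.g. head.to("cuda") also moves head.weights (when a GPU is available).
```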
@@ -308,7 +338,7 @@ class SegmentationHead(Predictor):
  labels = torch.stack([target["classes"] for target in targets], dim=0)
  mask = torch.stack([target["valid"] for target in targets], dim=0)
  per_pixel_loss = torch.nn.functional.cross_entropy(
- logits, labels, reduction="none"
+ logits, labels, weight=self.weights, reduction="none"
  )
  mask_sum = torch.sum(mask)
  if mask_sum > 0:
@@ -318,6 +348,9 @@ class SegmentationHead(Predictor):
  # If there are no valid pixels, we avoid dividing by zero and just let
  # the summed mask loss be zero.
  losses["cls"] = torch.sum(per_pixel_loss * mask)
+ if self.dice_loss:
+ dice_loss = DiceLoss()(outputs, labels, mask)
+ losses["dice"] = dice_loss

  return ModelOutput(
  outputs=outputs,
@@ -333,6 +366,7 @@ class SegmentationMetric(Metric):
  metric: Metric,
  pass_probabilities: bool = True,
  class_idx: int | None = None,
+ output_key: str | None = None,
  ):
  """Initialize a new SegmentationMetric.

@@ -341,12 +375,19 @@ class SegmentationMetric(Metric):
  classes from the targets and masking out invalid pixels.
  pass_probabilities: whether to pass predicted probabilities to the metric.
  If False, argmax is applied to pass the predicted classes instead.
- class_idx: if metric returns value for multiple classes, select this class.
+ class_idx: if set, return only this class index's value. For backward
+ compatibility with configs using standard torchmetrics. Internally
+ converted to output_key="cls_{class_idx}".
+ output_key: if the wrapped metric returns a dict (or a tensor that gets
+ converted to a dict), return only this key's value. For standard
+ torchmetrics with average=None, tensors are converted to dicts with
+ keys "cls_0", "cls_1", etc. If None, the full dict is returned.
  """
  super().__init__()
  self.metric = metric
  self.pass_probablities = pass_probabilities
  self.class_idx = class_idx
+ self.output_key = output_key

  def update(
  self, preds: list[Any] | torch.Tensor, targets: list[dict[str, Any]]
@@ -376,10 +417,32 @@ class SegmentationMetric(Metric):
  self.metric.update(preds, labels)

  def compute(self) -> Any:
- """Returns the computed metric."""
+ """Returns the computed metric.
+
+ If the wrapped metric returns a multi-element tensor (e.g., standard torchmetrics
+ with average=None), it is converted to a dict with keys like "cls_0", "cls_1", etc.
+ This allows uniform handling via output_key for both standard torchmetrics and
+ custom dict-returning metrics.
+ """
  result = self.metric.compute()
+
+ # Convert multi-element tensors to dict for uniform handling.
+ # This supports standard torchmetrics with average=None which return per-class tensors.
+ if isinstance(result, torch.Tensor) and result.ndim >= 1:
+ result = {f"cls_{i}": result[i] for i in range(len(result))}
+
+ if self.output_key is not None:
+ if not isinstance(result, dict):
+ raise TypeError(
+ f"output_key is set to '{self.output_key}' but metric returned "
+ f"{type(result).__name__} instead of dict"
+ )
+ return result[self.output_key]
  if self.class_idx is not None:
- result = result[self.class_idx]
+ # For backward compatibility: class_idx can index into the converted dict
+ if isinstance(result, dict):
+ return result[f"cls_{self.class_idx}"]
+ return result[self.class_idx]
  return result

  def reset(self) -> None:
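The tensor-to-dict conversion above can be sketched in isolation (values are illustrative):

```python
# Illustrative only: a torchmetrics metric constructed with average=None returns a
# per-class tensor; SegmentationMetric.compute() converts it to "cls_N" keys so that
# output_key (or the legacy class_idx) selects a single scalar uniformly.
import torch

result = torch.tensor([0.9, 0.4, 0.7])  # per-class values, e.g. from average=None
as_dict = {f"cls_{i}": result[i] for i in range(len(result))}
print(as_dict["cls_1"])  # tensor(0.4000), what output_key="cls_1" would return
```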
@@ -404,6 +467,7 @@ class F1Metric(Metric):
  num_classes: int,
  score_thresholds: list[float],
  metric_mode: str = "f1",
+ report_per_class: bool = False,
  ):
  """Create a new F1Metric.

@@ -413,11 +477,14 @@ class F1Metric(Metric):
  metric is the best F1 across score thresholds.
  metric_mode: set to "precision" or "recall" to return that instead of F1
  (default "f1")
+ report_per_class: whether to include per-class scores in the output dict.
+ If False, only returns the "avg" key.
  """
  super().__init__()
  self.num_classes = num_classes
  self.score_thresholds = score_thresholds
  self.metric_mode = metric_mode
+ self.report_per_class = report_per_class

  assert self.metric_mode in ["f1", "precision", "recall"]

@@ -462,9 +529,10 @@ class F1Metric(Metric):
  """Compute metric.

  Returns:
- the best F1 score across score thresholds and classes.
+ dict with "avg" key containing mean score across classes.
+ If report_per_class is True, also includes "cls_N" keys for each class N.
  """
- best_scores = []
+ cls_best_scores = {}

  for cls_idx in range(self.num_classes):
  best_score = None
@@ -501,9 +569,12 @@ class F1Metric(Metric):
  if best_score is None or score > best_score:
  best_score = score

- best_scores.append(best_score)
+ cls_best_scores[f"cls_{cls_idx}"] = best_score

- return torch.mean(torch.stack(best_scores))
+ report_scores = {"avg": torch.mean(torch.stack(list(cls_best_scores.values())))}
+ if self.report_per_class:
+ report_scores.update(cls_best_scores)
+ return report_scores


  class MeanIoUMetric(Metric):
@@ -523,7 +594,7 @@ class MeanIoUMetric(Metric):
  num_classes: int,
  nodata_value: int | None = None,
  ignore_missing_classes: bool = False,
- class_idx: int | None = None,
+ report_per_class: bool = False,
  ):
  """Create a new MeanIoUMetric.

@@ -535,15 +606,14 @@ class MeanIoUMetric(Metric):
  ignore_missing_classes: whether to ignore classes that don't appear in
  either the predictions or the ground truth. If false, the IoU for a
  missing class will be 0.
- class_idx: only compute and return the IoU for this class. This option is
- provided so the user can get per-class IoU results, since Lightning
- only supports scalar return values from metrics.
+ report_per_class: whether to include per-class IoU scores in the output dict.
+ If False, only returns the "avg" key.
  """
  super().__init__()
  self.num_classes = num_classes
  self.nodata_value = nodata_value
  self.ignore_missing_classes = ignore_missing_classes
- self.class_idx = class_idx
+ self.report_per_class = report_per_class

  self.add_state(
  "intersections", default=torch.zeros(self.num_classes), dist_reduce_fx="sum"
@@ -584,9 +654,11 @@ class MeanIoUMetric(Metric):
  """Compute metric.

  Returns:
- the mean IoU across classes.
+ dict with "avg" containing the mean IoU across classes.
+ If report_per_class is True, also includes "cls_N" keys for each valid class N.
  """
- per_class_scores = []
+ cls_scores = {}
+ valid_scores = []

  for cls_idx in range(self.num_classes):
  # Check if nodata_value is set and is one of the classes
@@ -599,6 +671,56 @@ class MeanIoUMetric(Metric):
  if union == 0 and self.ignore_missing_classes:
  continue

- per_class_scores.append(intersection / union)
+ score = intersection / union
+ cls_scores[f"cls_{cls_idx}"] = score
+ valid_scores.append(score)
+
+ report_scores = {"avg": torch.mean(torch.stack(valid_scores))}
+ if self.report_per_class:
+ report_scores.update(cls_scores)
+ return report_scores
+
+
+ class DiceLoss(torch.nn.Module):
+ """Mean Dice Loss for segmentation.
+
+ This is the mean of the per-class dice loss (1 - 2*intersection / union scores).
+ The per-class intersection is the number of pixels across all examples where
+ the predicted label and ground truth label are both that class, and the per-class
+ union is defined similarly.
+ """
+
+ def __init__(self, smooth: float = 1e-7):
+ """Initialize a new DiceLoss."""
+ super().__init__()
+ self.smooth = smooth
+
+ def forward(
+ self, inputs: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
+ ) -> torch.Tensor:
+ """Compute Dice Loss.
+
+ Returns:
+ the mean Dice Loss across classes
+ """
+ num_classes = inputs.shape[1]
+ targets_one_hot = (
+ torch.nn.functional.one_hot(targets, num_classes)
+ .permute(0, 3, 1, 2)
+ .float()
+ )
+
+ # Expand mask to [B, C, H, W]
+ mask = mask.unsqueeze(1).expand_as(inputs)
+
+ dice_per_class = []
+ for c in range(num_classes):
+ pred_c = inputs[:, c] * mask[:, c]
+ target_c = targets_one_hot[:, c] * mask[:, c]
+
+ intersection = (pred_c * target_c).sum()
+ union = pred_c.sum() + target_c.sum()
+ dice_c = (2.0 * intersection + self.smooth) / (union + self.smooth)
+ dice_per_class.append(dice_c)

- return torch.mean(torch.stack(per_class_scores))
+ return 1 - torch.stack(dice_per_class).mean()
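A hedged usage sketch for the new DiceLoss, matching how SegmentationHead calls it above; the shapes and the import path are assumptions:

```python
# Assumed import path; inputs are per-class probabilities (B, C, H, W), targets are
# integer labels (B, H, W), and mask flags valid pixels (B, H, W).
import torch

from rslearn.train.tasks.segmentation import DiceLoss

probs = torch.softmax(torch.randn(2, 3, 8, 8), dim=1)
labels = torch.randint(0, 3, (2, 8, 8))
valid = torch.ones(2, 8, 8)
loss = DiceLoss()(probs, labels, valid)
print(loss)  # scalar: 1 minus the mean per-class Dice score
```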
@@ -108,8 +108,10 @@ class BasicTask(Task):
  Returns:
  a dictionary mapping image name to visualization image
  """
- image = input_dict["image"].cpu()
- image = image[self.image_bands, :, :]
+ raster_image = input_dict["image"]
+ assert isinstance(raster_image, RasterImage)
+ # We don't really handle time series here, just use the first timestep.
+ image = raster_image.image.cpu()[self.image_bands, 0, :, :]
  if self.remap_values:
  factor = (self.remap_values[1][1] - self.remap_values[1][0]) / (
  self.remap_values[0][1] - self.remap_values[0][0]
rslearn/utils/geometry.py CHANGED
@@ -153,8 +153,8 @@ class ResolutionFactor:
  else:
  return Projection(
  projection.crs,
- projection.x_resolution // self.numerator,
- projection.y_resolution // self.numerator,
+ projection.x_resolution / self.numerator,
+ projection.y_resolution / self.numerator,
  )

  def multiply_bounds(self, bounds: PixelBounds) -> PixelBounds:
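The switch from `//` to `/` above matters whenever the resolution does not divide evenly by the factor; a small sketch with illustrative numbers:

```python
# Illustrative numbers only: floor division truncates the per-pixel resolution,
# true division preserves the fractional value the projection actually needs.
x_resolution = 10.0
numerator = 3
print(x_resolution // numerator)  # 3.0 (old behavior, truncated)
print(x_resolution / numerator)   # 3.3333333333333335 (new behavior)
```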
rslearn/utils/stac.py ADDED
@@ -0,0 +1,173 @@
+ """STAC API client."""
+
+ import logging
+ from dataclasses import dataclass
+ from datetime import datetime
+ from typing import Any
+
+ import requests
+
+ logger = logging.getLogger(__name__)
+
+ Bbox = tuple[float, float, float, float]
+
+
+ @dataclass(frozen=True)
+ class StacAsset:
+ """A STAC asset."""
+
+ href: str
+ title: str | None
+ type: str | None
+ roles: list[str] | None
+
+
+ @dataclass(frozen=True)
+ class StacItem:
+ """A STAC item."""
+
+ id: str
+ properties: dict[str, Any]
+ collection: str | None
+ bbox: Bbox | None
+ geometry: dict[str, Any] | None
+ assets: dict[str, StacAsset] | None
+ time_range: tuple[datetime, datetime] | None
+
+ @classmethod
+ def from_dict(cls, item: dict[str, Any]) -> "StacItem":
+ """Create a STAC item from the item dict returned from API."""
+ properties = item.get("properties", {})
+
+ # Parse bbox.
+ bbox: Bbox | None = None
+ if "bbox" in item:
+ if len(item["bbox"]) != 4:
+ raise NotImplementedError(
+ f"got bbox with {len(item['bbox'])} coordinates but only 4 coordinates is implemented"
+ )
+ bbox = tuple(item["bbox"])
+
+ # Parse assets.
+ assets: dict[str, StacAsset] = {}
+ for name, asset in item.get("assets", {}).items():
+ assets[name] = StacAsset(
+ href=asset["href"],
+ title=asset.get("title"),
+ type=asset.get("type"),
+ roles=asset.get("roles"),
+ )
+
+ # Parse time range.
+ time_range: tuple[datetime, datetime] | None = None
+ if "start_datetime" in properties and "end_datetime" in properties:
+ time_range = (
+ datetime.fromisoformat(properties["start_datetime"]),
+ datetime.fromisoformat(properties["end_datetime"]),
+ )
+ elif "datetime" in properties:
+ ts = datetime.fromisoformat(properties["datetime"])
+ time_range = (ts, ts)
+
+ return cls(
+ id=item["id"],
+ properties=properties,
+ collection=item.get("collection"),
+ bbox=bbox,
+ geometry=item.get("geometry"),
+ assets=assets,
+ time_range=time_range,
+ )
+
+
+ class StacClient:
+ """Limited functionality client for STAC APIs."""
+
+ def __init__(self, endpoint: str):
+ """Create a new StacClient.
+
+ Args:
+ endpoint: the STAC endpoint (base URL)
+ """
+ self.endpoint = endpoint
+ self.session = requests.Session()
+
+ def search(
+ self,
+ collections: list[str] | None = None,
+ bbox: Bbox | None = None,
+ intersects: dict[str, Any] | None = None,
+ date_time: datetime | tuple[datetime, datetime] | None = None,
+ ids: list[str] | None = None,
+ limit: int | None = None,
+ query: dict[str, Any] | None = None,
+ ) -> list[StacItem]:
+ """Execute a STAC item search.
+
+ We use the JSON POST API. Pagination is handled so the returned items are
+ concatenated across all available pages.
+
+ Args:
+ collections: only search within the provided collection(s).
+ bbox: only return features intersecting the provided bounding box.
+ intersects: only return features intersecting this GeoJSON geometry.
+ date_time: only return features that have a temporal property intersecting
+ the provided time range or timestamp.
+ ids: only return the provided item IDs.
+ limit: number of items per page. We will read all the pages.
+ query: query dict, if STAC query extension is supported by this API. See
+ https://github.com/stac-api-extensions/query.
+
+ Returns:
+ list of matching STAC items.
+ """
+ # Build JSON request data.
+ request_data: dict[str, Any] = {}
+ if collections is not None:
+ request_data["collections"] = collections
+ if bbox is not None:
+ request_data["bbox"] = bbox
+ if intersects is not None:
+ request_data["intersects"] = intersects
+ if date_time is not None:
+ if isinstance(date_time, tuple):
+ start_time = date_time[0].isoformat().replace("+00:00", "Z")
+ end_time = date_time[1].isoformat().replace("+00:00", "Z")
+ request_data["datetime"] = f"{start_time}/{end_time}"
+ else:
+ request_data["datetime"] = date_time.isoformat().replace("+00:00", "Z")
+ if ids is not None:
+ request_data["ids"] = ids
+ if limit is not None:
+ request_data["limit"] = limit
+ if query is not None:
+ request_data["query"] = query
+
+ # Handle pagination.
+ cur_url = self.endpoint + "/search"
+ items: list[StacItem] = []
+ while True:
+ logger.debug("Reading STAC items from %s", cur_url)
+ response = self.session.post(url=cur_url, json=request_data)
+ response.raise_for_status()
+ data = response.json()
+ for item_dict in data["features"]:
+ items.append(StacItem.from_dict(item_dict))
+
+ next_link = None
+ next_request_data: dict[str, Any] = {}
+ for link in data.get("links", []):
+ if "rel" not in link or link["rel"] != "next":
+ continue
+ assert link["method"] == "POST"
+ next_link = link["href"]
+ next_request_data = link["body"]
+ break
+
+ if next_link is None:
+ break
+
+ cur_url = next_link
+ request_data = next_request_data
+
+ return items
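A hedged usage sketch for the new StacClient; the endpoint, collection, and search parameters below are examples, not defaults shipped with rslearn:

```python
from datetime import datetime, timezone

from rslearn.utils.stac import StacClient

# Example public STAC API and collection; substitute your own endpoint.
client = StacClient("https://earth-search.aws.element84.com/v1")
items = client.search(
    collections=["sentinel-2-l2a"],
    bbox=(-122.5, 37.6, -122.3, 37.9),
    date_time=(
        datetime(2024, 6, 1, tzinfo=timezone.utc),
        datetime(2024, 6, 30, tzinfo=timezone.utc),
    ),
    limit=100,
)
for item in items[:3]:
    print(item.id, item.time_range)
```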
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: rslearn
- Version: 0.0.21
+ Version: 0.0.23
  Summary: A library for developing remote sensing datasets and models
  Author: OlmoEarth Team
  License: Apache License
@@ -218,11 +218,13 @@ Requires-Dist: fsspec>=2025.10.0
  Requires-Dist: jsonargparse>=4.35.0
  Requires-Dist: lightning>=2.5.1.post0
  Requires-Dist: Pillow>=11.3
+ Requires-Dist: pydantic>=2
  Requires-Dist: pyproj>=3.7
  Requires-Dist: python-dateutil>=2.9
  Requires-Dist: pytimeparse>=1.1
  Requires-Dist: rasterio>=1.4
  Requires-Dist: shapely>=2.1
+ Requires-Dist: soilgrids>=0.1.4
  Requires-Dist: torch>=2.7.0
  Requires-Dist: torchvision>=0.22.0
  Requires-Dist: tqdm>=4.67
@@ -317,6 +319,7 @@ for how to setup these data sources.
  - Xyz (Slippy) Tiles (e.g., Mapbox tiles)
  - Planet Labs (PlanetScope, SkySat)
  - ESA WorldCover 2021
+ - ISRIC SoilGrids (WCS)

  rslearn can also be used to easily mosaic, crop, and re-project any sets of local
  raster and vector files you may have.