rslearn 0.0.26__py3-none-any.whl → 0.0.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. rslearn/data_sources/__init__.py +2 -0
  2. rslearn/data_sources/aws_landsat.py +44 -161
  3. rslearn/data_sources/aws_open_data.py +2 -4
  4. rslearn/data_sources/aws_sentinel1.py +1 -3
  5. rslearn/data_sources/aws_sentinel2_element84.py +54 -165
  6. rslearn/data_sources/climate_data_store.py +1 -3
  7. rslearn/data_sources/copernicus.py +1 -2
  8. rslearn/data_sources/data_source.py +1 -1
  9. rslearn/data_sources/direct_materialize_data_source.py +336 -0
  10. rslearn/data_sources/earthdaily.py +52 -155
  11. rslearn/data_sources/earthdatahub.py +425 -0
  12. rslearn/data_sources/eurocrops.py +1 -2
  13. rslearn/data_sources/gcp_public_data.py +1 -2
  14. rslearn/data_sources/google_earth_engine.py +1 -2
  15. rslearn/data_sources/hf_srtm.py +595 -0
  16. rslearn/data_sources/local_files.py +1 -1
  17. rslearn/data_sources/openstreetmap.py +1 -1
  18. rslearn/data_sources/planet.py +1 -2
  19. rslearn/data_sources/planet_basemap.py +1 -2
  20. rslearn/data_sources/planetary_computer.py +183 -186
  21. rslearn/data_sources/soilgrids.py +3 -3
  22. rslearn/data_sources/stac.py +1 -2
  23. rslearn/data_sources/usda_cdl.py +1 -3
  24. rslearn/data_sources/usgs_landsat.py +7 -254
  25. rslearn/data_sources/worldcereal.py +1 -1
  26. rslearn/data_sources/worldcover.py +1 -1
  27. rslearn/data_sources/worldpop.py +1 -1
  28. rslearn/data_sources/xyz_tiles.py +5 -9
  29. rslearn/dataset/storage/file.py +16 -12
  30. rslearn/models/concatenate_features.py +6 -1
  31. rslearn/tile_stores/default.py +4 -2
  32. rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
  33. rslearn/train/data_module.py +36 -33
  34. rslearn/train/dataset.py +159 -68
  35. rslearn/train/lightning_module.py +60 -4
  36. rslearn/train/metrics.py +162 -0
  37. rslearn/train/model_context.py +3 -3
  38. rslearn/train/prediction_writer.py +69 -41
  39. rslearn/train/tasks/classification.py +14 -1
  40. rslearn/train/tasks/detection.py +5 -5
  41. rslearn/train/tasks/per_pixel_regression.py +19 -6
  42. rslearn/train/tasks/regression.py +19 -3
  43. rslearn/train/tasks/segmentation.py +17 -0
  44. rslearn/utils/__init__.py +2 -0
  45. rslearn/utils/fsspec.py +51 -1
  46. rslearn/utils/geometry.py +21 -0
  47. rslearn/utils/m2m_api.py +251 -0
  48. rslearn/utils/retry_session.py +43 -0
  49. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/METADATA +6 -3
  50. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/RECORD +55 -50
  51. rslearn/data_sources/earthdata_srtm.py +0 -282
  52. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/WHEEL +0 -0
  53. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/entry_points.txt +0 -0
  54. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/licenses/LICENSE +0 -0
  55. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/licenses/NOTICE +0 -0
  56. {rslearn-0.0.26.dist-info → rslearn-0.0.28.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,7 @@
1
1
  """rslearn PredictionWriter implementation."""
2
2
 
3
3
  import json
4
+ import warnings
4
5
  from collections.abc import Iterable, Sequence
5
6
  from dataclasses import dataclass
6
7
  from pathlib import Path
@@ -39,20 +40,20 @@ logger = get_logger(__name__)
39
40
 
40
41
 
41
42
  @dataclass
42
- class PendingPatchOutput:
43
- """A patch output that hasn't been merged yet."""
43
+ class PendingCropOutput:
44
+ """A crop output that hasn't been merged yet."""
44
45
 
45
46
  bounds: PixelBounds
46
47
  output: Any
47
48
 
48
49
 
49
- class PatchPredictionMerger:
50
- """Base class for merging predictions from multiple patches."""
50
+ class CropPredictionMerger:
51
+ """Base class for merging predictions from multiple crops."""
51
52
 
52
53
  def merge(
53
54
  self,
54
55
  window: Window,
55
- outputs: Sequence[PendingPatchOutput],
56
+ outputs: Sequence[PendingCropOutput],
56
57
  layer_config: LayerConfig,
57
58
  ) -> Any:
58
59
  """Merge the outputs.
@@ -68,39 +69,60 @@ class PatchPredictionMerger:
68
69
  raise NotImplementedError
69
70
 
70
71
 
71
- class VectorMerger(PatchPredictionMerger):
72
+ class VectorMerger(CropPredictionMerger):
72
73
  """Merger for vector data that simply concatenates the features."""
73
74
 
74
75
  def merge(
75
76
  self,
76
77
  window: Window,
77
- outputs: Sequence[PendingPatchOutput],
78
+ outputs: Sequence[PendingCropOutput],
78
79
  layer_config: LayerConfig,
79
80
  ) -> list[Feature]:
80
81
  """Concatenate the vector features."""
81
82
  return [feat for output in outputs for feat in output.output]
82
83
 
83
84
 
84
- class RasterMerger(PatchPredictionMerger):
85
+ class RasterMerger(CropPredictionMerger):
85
86
  """Merger for raster data that copies the rasters to the output."""
86
87
 
87
- def __init__(self, padding: int | None = None, downsample_factor: int = 1):
88
+ def __init__(
89
+ self,
90
+ overlap_pixels: int | None = None,
91
+ downsample_factor: int = 1,
92
+ # Deprecated parameter (for backwards compatibility)
93
+ padding: int | None = None,
94
+ ):
88
95
  """Create a new RasterMerger.
89
96
 
90
97
  Args:
91
- padding: the padding around the individual patch outputs to remove. This is
92
- typically used when leveraging overlapping patches. Portions of outputs
93
- at the border of the window will still be retained.
98
+ overlap_pixels: the number of pixels shared between adjacent crops during
99
+ sliding window inference. Half of this overlap is removed from each
100
+ crop during merging (except at window boundaries where the full crop
101
+ is retained).
94
102
  downsample_factor: the factor by which the rasters output by the task are
95
103
  lower in resolution relative to the window resolution.
104
+ padding: deprecated, use overlap_pixels instead. The old padding value
105
+ equals overlap_pixels // 2.
96
106
  """
97
- self.padding = padding
107
+ # Handle deprecated padding parameter
108
+ if padding is not None:
109
+ warnings.warn(
110
+ "padding is deprecated, use overlap_pixels instead. "
111
+ "Note: overlap_pixels = padding * 2",
112
+ FutureWarning,
113
+ stacklevel=2,
114
+ )
115
+ if overlap_pixels is not None:
116
+ raise ValueError("Cannot specify both padding and overlap_pixels")
117
+ overlap_pixels = padding * 2
118
+
119
+ self.overlap_pixels = overlap_pixels
98
120
  self.downsample_factor = downsample_factor
99
121
 
100
122
  def merge(
101
123
  self,
102
124
  window: Window,
103
- outputs: Sequence[PendingPatchOutput],
125
+ outputs: Sequence[PendingCropOutput],
104
126
  layer_config: LayerConfig,
105
127
  ) -> npt.NDArray:
106
128
  """Merge the raster outputs."""
@@ -114,6 +136,12 @@ class RasterMerger(PatchPredictionMerger):
114
136
  dtype=layer_config.band_sets[0].dtype.get_numpy_dtype(),
115
137
  )
116
138
 
139
+ # Compute how many pixels to trim from each side.
140
+ # We remove half of the overlap from each side (not at window boundaries).
141
+ trim_pixels = (
142
+ self.overlap_pixels // 2 if self.overlap_pixels is not None else None
143
+ )
144
+
117
145
  # Ensure the outputs are sorted by height then width.
118
146
  # This way when we merge we can be sure that outputs that are lower or further
119
147
  # to the right will overwrite earlier outputs.
@@ -123,18 +151,18 @@ class RasterMerger(PatchPredictionMerger):
123
151
  for output in sorted_outputs:
124
152
  # So now we just need to compute the src_offset to copy.
125
153
  # If the output is not on the left or top boundary, then we should apply
126
- # the padding (if set).
154
+ # the trim (if set).
127
155
  src = output.output
128
156
  src_offset = (
129
157
  output.bounds[0] // self.downsample_factor,
130
158
  output.bounds[1] // self.downsample_factor,
131
159
  )
132
- if self.padding is not None and output.bounds[0] != window.bounds[0]:
133
- src = src[:, :, self.padding :]
134
- src_offset = (src_offset[0] + self.padding, src_offset[1])
135
- if self.padding is not None and output.bounds[1] != window.bounds[1]:
136
- src = src[:, self.padding :, :]
137
- src_offset = (src_offset[0], src_offset[1] + self.padding)
160
+ if trim_pixels is not None and output.bounds[0] != window.bounds[0]:
161
+ src = src[:, :, trim_pixels:]
162
+ src_offset = (src_offset[0] + trim_pixels, src_offset[1])
163
+ if trim_pixels is not None and output.bounds[1] != window.bounds[1]:
164
+ src = src[:, trim_pixels:, :]
165
+ src_offset = (src_offset[0], src_offset[1] + trim_pixels)
138
166
 
139
167
  copy_spatial_array(
140
168
  src=src,
@@ -162,7 +190,7 @@ class RslearnWriter(BasePredictionWriter):
162
190
  output_layer: str,
163
191
  path_options: dict[str, Any] | None = None,
164
192
  selector: list[str] | None = None,
165
- merger: PatchPredictionMerger | None = None,
193
+ merger: CropPredictionMerger | None = None,
166
194
  output_path: str | Path | None = None,
167
195
  layer_config: LayerConfig | None = None,
168
196
  storage_config: StorageConfig | None = None,
@@ -175,7 +203,7 @@ class RslearnWriter(BasePredictionWriter):
175
203
  path_options: additional options for path to pass to fsspec
176
204
  selector: keys to access the desired output in the output dict if needed.
177
205
  e.g ["key1", "key2"] gets output["key1"]["key2"]
178
- merger: merger to use to merge outputs from overlapped patches.
206
+ merger: merger to use to merge outputs from overlapped crops.
179
207
  output_path: optional custom path for writing predictions. If provided,
180
208
  predictions will be written to this path instead of deriving from dataset path.
181
209
  layer_config: optional layer configuration. If provided, this config will be
@@ -217,9 +245,9 @@ class RslearnWriter(BasePredictionWriter):
217
245
  self.merger = VectorMerger()
218
246
 
219
247
  # Map from window name to pending data to write.
220
- # This is used when windows are split up into patches, so the data from all the
221
- # patches of each window need to be reconstituted.
222
- self.pending_outputs: dict[str, list[PendingPatchOutput]] = {}
248
+ # This is used when windows are split up into crops, so the data from all the
249
+ # crops of each window need to be reconstituted.
250
+ self.pending_outputs: dict[str, list[PendingCropOutput]] = {}
223
251
 
224
252
  def _get_layer_config_and_dataset_storage(
225
253
  self,
@@ -327,7 +355,7 @@ class RslearnWriter(BasePredictionWriter):
327
355
  will be processed by the task to obtain a vector (list[Feature]) or
328
356
  raster (npt.NDArray) output.
329
357
  metadatas: corresponding list of metadatas from the batch describing the
330
- patches that were processed.
358
+ crops that were processed.
331
359
  """
332
360
  # Process the predictions into outputs that can be written.
333
361
  outputs: list = [
@@ -349,17 +377,17 @@ class RslearnWriter(BasePredictionWriter):
349
377
  )
350
378
  self.process_output(
351
379
  window,
352
- metadata.patch_idx,
353
- metadata.num_patches_in_window,
354
- metadata.patch_bounds,
380
+ metadata.crop_idx,
381
+ metadata.num_crops_in_window,
382
+ metadata.crop_bounds,
355
383
  output,
356
384
  )
357
385
 
358
386
  def process_output(
359
387
  self,
360
388
  window: Window,
361
- patch_idx: int,
362
- num_patches: int,
389
+ crop_idx: int,
390
+ num_crops: int,
363
391
  cur_bounds: PixelBounds,
364
392
  output: npt.NDArray | list[Feature],
365
393
  ) -> None:
@@ -367,28 +395,28 @@ class RslearnWriter(BasePredictionWriter):
367
395
 
368
396
  Args:
369
397
  window: the window that the output pertains to.
370
- patch_idx: the index of this patch for the window.
371
- num_patches: the total number of patches to be processed for the window.
372
- cur_bounds: the bounds of the current patch.
398
+ crop_idx: the index of this crop for the window.
399
+ num_crops: the total number of crops to be processed for the window.
400
+ cur_bounds: the bounds of the current crop.
373
401
  output: the output data.
374
402
  """
375
- # Incorporate the output into our list of pending patch outputs.
403
+ # Incorporate the output into our list of pending crop outputs.
376
404
  if window.name not in self.pending_outputs:
377
405
  self.pending_outputs[window.name] = []
378
- self.pending_outputs[window.name].append(PendingPatchOutput(cur_bounds, output))
406
+ self.pending_outputs[window.name].append(PendingCropOutput(cur_bounds, output))
379
407
  logger.debug(
380
- f"Stored PendingPatchOutput for patch #{patch_idx}/{num_patches} at window {window.name}"
408
+ f"Stored PendingCropOutput for crop #{crop_idx}/{num_crops} at window {window.name}"
381
409
  )
382
410
 
383
- if patch_idx < num_patches - 1:
411
+ if crop_idx < num_crops - 1:
384
412
  return
385
413
 
386
- # This is the last patch so it's time to write it.
414
+ # This is the last crop so it's time to write it.
387
415
  # First get the pending output and clear it.
388
416
  pending_output = self.pending_outputs[window.name]
389
417
  del self.pending_outputs[window.name]
390
418
 
391
- # Merge outputs from overlapped patches if merger is set.
419
+ # Merge outputs from overlapped crops if merger is set.
392
420
  logger.debug(f"Merging and writing for window {window.name}")
393
421
  merged_output = self.merger.merge(window, pending_output, self.layer_config)
394
422
 
@@ -16,6 +16,7 @@ from torchmetrics.classification import (
16
16
  )
17
17
 
18
18
  from rslearn.models.component import FeatureVector, Predictor
19
+ from rslearn.train.metrics import ConfusionMatrixMetric
19
20
  from rslearn.train.model_context import (
20
21
  ModelContext,
21
22
  ModelOutput,
@@ -44,6 +45,7 @@ class ClassificationTask(BasicTask):
44
45
  f1_metric_kwargs: dict[str, Any] = {},
45
46
  positive_class: str | None = None,
46
47
  positive_class_threshold: float = 0.5,
48
+ enable_confusion_matrix: bool = False,
47
49
  **kwargs: Any,
48
50
  ):
49
51
  """Initialize a new ClassificationTask.
@@ -69,6 +71,8 @@ class ClassificationTask(BasicTask):
69
71
  positive_class: positive class name.
70
72
  positive_class_threshold: threshold for classifying the positive class in
71
73
  binary classification (default 0.5).
74
+ enable_confusion_matrix: whether to compute confusion matrix (default false).
75
+ If true, it requires wandb to be initialized for logging.
72
76
  kwargs: other arguments to pass to BasicTask
73
77
  """
74
78
  super().__init__(**kwargs)
@@ -84,6 +88,7 @@ class ClassificationTask(BasicTask):
84
88
  self.f1_metric_kwargs = f1_metric_kwargs
85
89
  self.positive_class = positive_class
86
90
  self.positive_class_threshold = positive_class_threshold
91
+ self.enable_confusion_matrix = enable_confusion_matrix
87
92
 
88
93
  if self.positive_class_threshold != 0.5:
89
94
  # Must be binary classification
@@ -201,7 +206,7 @@ class ClassificationTask(BasicTask):
201
206
  feature = Feature(
202
207
  STGeometry(
203
208
  metadata.projection,
204
- shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1]),
209
+ shapely.Point(metadata.crop_bounds[0], metadata.crop_bounds[1]),
205
210
  None,
206
211
  ),
207
212
  {
@@ -278,6 +283,14 @@ class ClassificationTask(BasicTask):
278
283
  )
279
284
  metrics["f1"] = ClassificationMetric(MulticlassF1Score(**kwargs))
280
285
 
286
+ if self.enable_confusion_matrix:
287
+ metrics["confusion_matrix"] = ClassificationMetric(
288
+ ConfusionMatrixMetric(
289
+ num_classes=len(self.classes),
290
+ class_names=self.classes,
291
+ ),
292
+ )
293
+
281
294
  return MetricCollection(metrics)
282
295
 
283
296
 
@@ -128,7 +128,7 @@ class DetectionTask(BasicTask):
128
128
  if not load_targets:
129
129
  return {}, {}
130
130
 
131
- bounds = metadata.patch_bounds
131
+ bounds = metadata.crop_bounds
132
132
 
133
133
  boxes = []
134
134
  class_labels = []
@@ -244,10 +244,10 @@ class DetectionTask(BasicTask):
244
244
  features = []
245
245
  for box, class_id, score in zip(boxes, class_ids, scores):
246
246
  shp = shapely.box(
247
- metadata.patch_bounds[0] + float(box[0]),
248
- metadata.patch_bounds[1] + float(box[1]),
249
- metadata.patch_bounds[0] + float(box[2]),
250
- metadata.patch_bounds[1] + float(box[3]),
247
+ metadata.crop_bounds[0] + float(box[0]),
248
+ metadata.crop_bounds[1] + float(box[1]),
249
+ metadata.crop_bounds[0] + float(box[2]),
250
+ metadata.crop_bounds[1] + float(box[3]),
251
251
  )
252
252
  geom = STGeometry(metadata.projection, shp, None)
253
253
  properties: dict[str, Any] = {
@@ -149,22 +149,28 @@ class PerPixelRegressionHead(Predictor):
149
149
  """Head for per-pixel regression task."""
150
150
 
151
151
  def __init__(
152
- self, loss_mode: Literal["mse", "l1"] = "mse", use_sigmoid: bool = False
152
+ self,
153
+ loss_mode: Literal["mse", "l1", "huber"] = "mse",
154
+ use_sigmoid: bool = False,
155
+ huber_delta: float = 1.0,
153
156
  ):
154
- """Initialize a new RegressionHead.
157
+ """Initialize a new PerPixelRegressionHead.
155
158
 
156
159
  Args:
157
- loss_mode: the loss function to use, either "mse" (default) or "l1".
160
+ loss_mode: the loss function to use: "mse" (default), "l1", or "huber".
158
161
  use_sigmoid: whether to apply a sigmoid activation on the output. This
159
162
  requires targets to be between 0-1.
163
+ huber_delta: delta parameter for Huber loss (only used when
164
+ loss_mode="huber").
160
165
  """
161
166
  super().__init__()
162
167
 
163
- if loss_mode not in ["mse", "l1"]:
164
- raise ValueError("invalid loss mode")
168
+ if loss_mode not in ["mse", "l1", "huber"]:
169
+ raise ValueError(f"invalid loss mode {loss_mode}")
165
170
 
166
171
  self.loss_mode = loss_mode
167
172
  self.use_sigmoid = use_sigmoid
173
+ self.huber_delta = huber_delta
168
174
 
169
175
  def forward(
170
176
  self,
@@ -217,8 +223,15 @@ class PerPixelRegressionHead(Predictor):
217
223
  scores = torch.square(outputs - labels)
218
224
  elif self.loss_mode == "l1":
219
225
  scores = torch.abs(outputs - labels)
226
+ elif self.loss_mode == "huber":
227
+ scores = torch.nn.functional.huber_loss(
228
+ outputs,
229
+ labels,
230
+ reduction="none",
231
+ delta=self.huber_delta,
232
+ )
220
233
  else:
221
- assert False
234
+ raise ValueError(f"unknown loss mode {self.loss_mode}")
222
235
 
223
236
  # Compute average but only over valid pixels.
224
237
  mask_total = mask.sum()
@@ -130,7 +130,7 @@ class RegressionTask(BasicTask):
130
130
  feature = Feature(
131
131
  STGeometry(
132
132
  metadata.projection,
133
- shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1]),
133
+ shapely.Point(metadata.crop_bounds[0], metadata.crop_bounds[1]),
134
134
  None,
135
135
  ),
136
136
  {
@@ -196,18 +196,24 @@ class RegressionHead(Predictor):
196
196
  """Head for regression task."""
197
197
 
198
198
  def __init__(
199
- self, loss_mode: Literal["mse", "l1"] = "mse", use_sigmoid: bool = False
199
+ self,
200
+ loss_mode: Literal["mse", "l1", "huber"] = "mse",
201
+ use_sigmoid: bool = False,
202
+ huber_delta: float = 1.0,
200
203
  ):
201
204
  """Initialize a new RegressionHead.
202
205
 
203
206
  Args:
204
- loss_mode: the loss function to use, either "mse" (default) or "l1".
207
+ loss_mode: the loss function to use: "mse" (default), "l1", or "huber".
205
208
  use_sigmoid: whether to apply a sigmoid activation on the output. This
206
209
  requires targets to be between 0-1.
210
+ huber_delta: delta parameter for Huber loss (only used when
211
+ loss_mode="huber").
207
212
  """
208
213
  super().__init__()
209
214
  self.loss_mode = loss_mode
210
215
  self.use_sigmoid = use_sigmoid
216
+ self.huber_delta = huber_delta
211
217
 
212
218
  def forward(
213
219
  self,
@@ -251,6 +257,16 @@ class RegressionHead(Predictor):
251
257
  losses["regress"] = torch.mean(torch.square(outputs - labels) * mask)
252
258
  elif self.loss_mode == "l1":
253
259
  losses["regress"] = torch.mean(torch.abs(outputs - labels) * mask)
260
+ elif self.loss_mode == "huber":
261
+ losses["regress"] = torch.mean(
262
+ torch.nn.functional.huber_loss(
263
+ outputs,
264
+ labels,
265
+ reduction="none",
266
+ delta=self.huber_delta,
267
+ )
268
+ * mask
269
+ )
254
270
  else:
255
271
  raise ValueError(f"unknown loss mode {self.loss_mode}")
256
272
 
@@ -10,6 +10,7 @@ import torchmetrics.classification
10
10
  from torchmetrics import Metric, MetricCollection
11
11
 
12
12
  from rslearn.models.component import FeatureMaps, Predictor
13
+ from rslearn.train.metrics import ConfusionMatrixMetric
13
14
  from rslearn.train.model_context import (
14
15
  ModelContext,
15
16
  ModelOutput,
@@ -43,6 +44,8 @@ class SegmentationTask(BasicTask):
43
44
  other_metrics: dict[str, Metric] = {},
44
45
  output_probs: bool = False,
45
46
  output_class_idx: int | None = None,
47
+ enable_confusion_matrix: bool = False,
48
+ class_names: list[str] | None = None,
46
49
  **kwargs: Any,
47
50
  ) -> None:
48
51
  """Initialize a new SegmentationTask.
@@ -80,6 +83,10 @@ class SegmentationTask(BasicTask):
80
83
  during prediction.
81
84
  output_class_idx: if set along with output_probs, only output the probability
82
85
  for this specific class index (single-channel output).
86
+ enable_confusion_matrix: whether to compute confusion matrix (default false).
87
+ If true, it requires wandb to be initialized for logging.
88
+ class_names: optional list of class names for labeling confusion matrix axes.
89
+ If not provided, classes will be labeled as "class_0", "class_1", etc.
83
90
  kwargs: additional arguments to pass to BasicTask
84
91
  """
85
92
  super().__init__(**kwargs)
@@ -106,6 +113,8 @@ class SegmentationTask(BasicTask):
106
113
  self.other_metrics = other_metrics
107
114
  self.output_probs = output_probs
108
115
  self.output_class_idx = output_class_idx
116
+ self.enable_confusion_matrix = enable_confusion_matrix
117
+ self.class_names = class_names
109
118
 
110
119
  def process_inputs(
111
120
  self,
@@ -285,6 +294,14 @@ class SegmentationTask(BasicTask):
285
294
  if self.other_metrics:
286
295
  metrics.update(self.other_metrics)
287
296
 
297
+ if self.enable_confusion_matrix:
298
+ metrics["confusion_matrix"] = SegmentationMetric(
299
+ ConfusionMatrixMetric(
300
+ num_classes=self.num_classes,
301
+ class_names=self.class_names,
302
+ ),
303
+ )
304
+
288
305
  return MetricCollection(metrics)
289
306
 
290
307
 
rslearn/utils/__init__.py CHANGED
@@ -7,6 +7,7 @@ from .geometry import (
7
7
  PixelBounds,
8
8
  Projection,
9
9
  STGeometry,
10
+ get_global_raster_bounds,
10
11
  is_same_resolution,
11
12
  shp_intersects,
12
13
  )
@@ -23,6 +24,7 @@ __all__ = (
23
24
  "Projection",
24
25
  "STGeometry",
25
26
  "daterange",
27
+ "get_global_raster_bounds",
26
28
  "get_utm_ups_crs",
27
29
  "is_same_resolution",
28
30
  "logger",
rslearn/utils/fsspec.py CHANGED
@@ -2,7 +2,7 @@
2
2
 
3
3
  import os
4
4
  import tempfile
5
- from collections.abc import Generator
5
+ from collections.abc import Generator, Iterator
6
6
  from contextlib import contextmanager
7
7
  from typing import Any
8
8
 
@@ -16,6 +16,56 @@ from rslearn.log_utils import get_logger
16
16
  logger = get_logger(__name__)
17
17
 
18
18
 
19
def iter_nonhidden(path: UPath) -> Iterator[UPath]:
    """Iterate over non-hidden entries in a directory.

    Hidden entries are those whose basename starts with "." (e.g. ".DS_Store").

    If the path does not exist or is not a directory, nothing is yielded.

    Args:
        path: the directory to iterate.

    Yields:
        non-hidden UPath entries in the directory.
    """
    # iterdir() may be a lazy generator (it is in pathlib before Python 3.13), in
    # which case the listing error is only raised once iteration starts rather than
    # when iterdir() is called. Materialize the listing inside the try block so that
    # missing or non-directory paths are actually handled here.
    try:
        entries = list(path.iterdir())
    except (FileNotFoundError, NotADirectoryError):
        return

    for entry in entries:
        if not entry.name.startswith("."):
            yield entry
39
+
40
+
41
def iter_nonhidden_subdirs(path: UPath) -> Iterator[UPath]:
    """Iterate over the non-hidden subdirectories of a directory.

    Args:
        path: the directory to iterate.

    Yields:
        each entry produced by iter_nonhidden for which is_dir() is true.
    """
    yield from (entry for entry in iter_nonhidden(path) if entry.is_dir())
53
+
54
+
55
def iter_nonhidden_files(path: UPath) -> Iterator[UPath]:
    """Iterate over the non-hidden files of a directory.

    Args:
        path: the directory to iterate.

    Yields:
        each entry produced by iter_nonhidden for which is_file() is true.
    """
    yield from (entry for entry in iter_nonhidden(path) if entry.is_file())
67
+
68
+
19
69
  @contextmanager
20
70
  def get_upath_local(
21
71
  path: UPath, extra_paths: list[UPath] = []
rslearn/utils/geometry.py CHANGED
@@ -116,6 +116,27 @@ class Projection:
116
116
  WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)
117
117
 
118
118
 
119
def get_global_raster_bounds(projection: Projection) -> PixelBounds:
    """Return very large pixel bounds suitable for a raster that covers the globe.

    Intended for data sources with worldwide coverage that do not want to compute
    exact bounds in an arbitrary projection (which can fail for projections such as
    UTM that only cover part of the world).

    Args:
        projection: the projection whose x/y resolution is used to convert the
            assumed CRS coordinate limit into pixel units.

    Returns:
        Pixel bounds, symmetric about the origin, meant to intersect with any
        reasonable window. CRS coordinates are assumed to have absolute value at
        most 2^32; that limit is divided by the absolute resolution on each axis
        so the bounds grow accordingly when fine-grained resolutions are used.
    """
    max_abs_crs_coord = 2**32
    # Same conversion on both axes: CRS units -> pixels, truncated to an integer.
    half_width, half_height = (
        int(max_abs_crs_coord / abs(resolution))
        for resolution in (projection.x_resolution, projection.y_resolution)
    )
    return (-half_width, -half_height, half_width, half_height)
138
+
139
+
119
140
  class ResolutionFactor:
120
141
  """Multiplier for the resolution in a Projection.
121
142