rslearn 0.0.26__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. rslearn/data_sources/__init__.py +2 -0
  2. rslearn/data_sources/aws_landsat.py +44 -161
  3. rslearn/data_sources/aws_open_data.py +2 -4
  4. rslearn/data_sources/aws_sentinel1.py +1 -3
  5. rslearn/data_sources/aws_sentinel2_element84.py +54 -165
  6. rslearn/data_sources/climate_data_store.py +1 -3
  7. rslearn/data_sources/copernicus.py +1 -2
  8. rslearn/data_sources/data_source.py +1 -1
  9. rslearn/data_sources/direct_materialize_data_source.py +336 -0
  10. rslearn/data_sources/earthdaily.py +52 -155
  11. rslearn/data_sources/earthdatahub.py +425 -0
  12. rslearn/data_sources/eurocrops.py +1 -2
  13. rslearn/data_sources/gcp_public_data.py +1 -2
  14. rslearn/data_sources/google_earth_engine.py +1 -2
  15. rslearn/data_sources/hf_srtm.py +595 -0
  16. rslearn/data_sources/local_files.py +1 -1
  17. rslearn/data_sources/openstreetmap.py +1 -1
  18. rslearn/data_sources/planet.py +1 -2
  19. rslearn/data_sources/planet_basemap.py +1 -2
  20. rslearn/data_sources/planetary_computer.py +183 -186
  21. rslearn/data_sources/soilgrids.py +3 -3
  22. rslearn/data_sources/stac.py +1 -2
  23. rslearn/data_sources/usda_cdl.py +1 -3
  24. rslearn/data_sources/usgs_landsat.py +7 -254
  25. rslearn/data_sources/worldcereal.py +1 -1
  26. rslearn/data_sources/worldcover.py +1 -1
  27. rslearn/data_sources/worldpop.py +1 -1
  28. rslearn/data_sources/xyz_tiles.py +5 -9
  29. rslearn/models/concatenate_features.py +6 -1
  30. rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
  31. rslearn/train/data_module.py +27 -27
  32. rslearn/train/dataset.py +109 -62
  33. rslearn/train/lightning_module.py +1 -1
  34. rslearn/train/model_context.py +3 -3
  35. rslearn/train/prediction_writer.py +69 -41
  36. rslearn/train/tasks/classification.py +1 -1
  37. rslearn/train/tasks/detection.py +5 -5
  38. rslearn/train/tasks/regression.py +1 -1
  39. rslearn/utils/__init__.py +2 -0
  40. rslearn/utils/geometry.py +21 -0
  41. rslearn/utils/m2m_api.py +251 -0
  42. rslearn/utils/retry_session.py +43 -0
  43. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
  44. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/RECORD +49 -45
  45. rslearn/data_sources/earthdata_srtm.py +0 -282
  46. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
  47. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
  48. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
  49. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
  50. {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
rslearn/train/dataset.py CHANGED
@@ -8,6 +8,7 @@ import random
 import tempfile
 import time
 import uuid
+import warnings
 from datetime import datetime
 from enum import StrEnum
 from typing import Any
@@ -456,11 +457,15 @@ class SplitConfig:
         num_patches: int | None = None,
         transforms: list[torch.nn.Module] | None = None,
         sampler: SamplerFactory | None = None,
+        crop_size: int | tuple[int, int] | None = None,
+        overlap_pixels: int | None = None,
+        load_all_crops: bool | None = None,
+        skip_targets: bool | None = None,
+        output_layer_name_skip_inference_if_exists: str | None = None,
+        # Deprecated parameters (for backwards compatibility)
         patch_size: int | tuple[int, int] | None = None,
         overlap_ratio: float | None = None,
         load_all_patches: bool | None = None,
-        skip_targets: bool | None = None,
-        output_layer_name_skip_inference_if_exists: str | None = None,
     ) -> None:
         """Initialize a new SplitConfig.
 
@@ -475,19 +480,69 @@ class SplitConfig:
             num_patches: limit this split to this many patches
             transforms: transforms to apply
             sampler: SamplerFactory for this split
-            patch_size: an optional square size or (width, height) tuple. If set, read
+            crop_size: an optional square size or (width, height) tuple. If set, read
                 crops of this size rather than entire windows.
-            overlap_ratio: an optional float between 0 and 1. If set, read patches with
-                this ratio of overlap.
-            load_all_patches: with patch_size set, rather than sampling a random patch
-                for each window, read all patches as separate sequential items in the
+            overlap_pixels: the number of pixels shared between adjacent crops during
+                sliding window inference.
+            load_all_crops: with crop_size set, rather than sampling a random crop
+                for each window, read all crops as separate sequential items in the
                 dataset.
             skip_targets: whether to skip targets when loading inputs
            output_layer_name_skip_inference_if_exists: optional name of the output layer used during prediction.
                If set, windows that already
                have this layer completed will be skipped (useful for resuming
                partial inference runs).
+            patch_size: deprecated, use crop_size instead
+            overlap_ratio: deprecated, use overlap_pixels instead
+            load_all_patches: deprecated, use load_all_crops instead
        """
+        # Handle deprecated load_all_patches parameter
+        if load_all_patches is not None:
+            warnings.warn(
+                "load_all_patches is deprecated, use load_all_crops instead",
+                FutureWarning,
+                stacklevel=2,
+            )
+            if load_all_crops is not None:
+                raise ValueError(
+                    "Cannot specify both load_all_patches and load_all_crops"
+                )
+            load_all_crops = load_all_patches
+        # Handle deprecated patch_size parameter
+        if patch_size is not None:
+            warnings.warn(
+                "patch_size is deprecated, use crop_size instead",
+                FutureWarning,
+                stacklevel=2,
+            )
+            if crop_size is not None:
+                raise ValueError("Cannot specify both patch_size and crop_size")
+            crop_size = patch_size
+
+        # Normalize crop_size to tuple[int, int] | None
+        self.crop_size: tuple[int, int] | None = None
+        if crop_size is not None:
+            if isinstance(crop_size, int):
+                self.crop_size = (crop_size, crop_size)
+            else:
+                self.crop_size = crop_size
+
+        # Handle deprecated overlap_ratio parameter
+        if overlap_ratio is not None:
+            warnings.warn(
+                "overlap_ratio is deprecated, use overlap_pixels instead",
+                FutureWarning,
+                stacklevel=2,
+            )
+            if overlap_pixels is not None:
+                raise ValueError("Cannot specify both overlap_ratio and overlap_pixels")
+            if self.crop_size is None:
+                raise ValueError("overlap_ratio requires crop_size to be set")
+            overlap_pixels = round(self.crop_size[0] * overlap_ratio)
+
+        if overlap_pixels is not None and overlap_pixels < 0:
+            raise ValueError("overlap_pixels must be non-negative")
+
         self.groups = groups
         self.names = names
         self.tags = tags
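As a usage note for the deprecation shims above, the following is a minimal sketch of how the old arguments map onto the new ones. It assumes the SplitConfig constructor shown in this diff, importable as rslearn.train.dataset.SplitConfig, with all parameters being optional keywords.

    import warnings

    from rslearn.train.dataset import SplitConfig

    # New-style configuration: the overlap is now expressed in pixels.
    new_style = SplitConfig(crop_size=256, overlap_pixels=64)

    # Old-style arguments still work but emit FutureWarning; overlap_ratio is converted
    # internally via round(crop_size[0] * overlap_ratio), i.e. round(256 * 0.25) == 64.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        old_style = SplitConfig(patch_size=256, overlap_ratio=0.25)

    assert old_style.crop_size == (256, 256)
    assert old_style.overlap_pixels == new_style.overlap_pixels == 64
    assert any(issubclass(w.category, FutureWarning) for w in caught)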
@@ -495,19 +550,15 @@ class SplitConfig:
         self.num_patches = num_patches
         self.transforms = transforms
         self.sampler = sampler
-        self.patch_size = patch_size
         self.skip_targets = skip_targets
         self.output_layer_name_skip_inference_if_exists = (
             output_layer_name_skip_inference_if_exists
         )
 
-        # Note that load_all_patches are handled by the RslearnDataModule rather than
-        # the ModelDataset.
-        self.load_all_patches = load_all_patches
-        self.overlap_ratio = overlap_ratio
-
-        if self.overlap_ratio is not None and not (0 < self.overlap_ratio < 1):
-            raise ValueError("overlap_ratio must be between 0 and 1 (exclusive)")
+        # Note that load_all_crops is handled by the RslearnDataModule rather than the
+        # ModelDataset.
+        self.load_all_crops = load_all_crops
+        self.overlap_pixels = overlap_pixels
 
     def update(self, other: "SplitConfig") -> "SplitConfig":
         """Override settings in this SplitConfig with those in another.
@@ -523,9 +574,9 @@ class SplitConfig:
             num_patches=self.num_patches,
             transforms=self.transforms,
             sampler=self.sampler,
-            patch_size=self.patch_size,
-            overlap_ratio=self.overlap_ratio,
-            load_all_patches=self.load_all_patches,
+            crop_size=self.crop_size,
+            overlap_pixels=self.overlap_pixels,
+            load_all_crops=self.load_all_crops,
             skip_targets=self.skip_targets,
             output_layer_name_skip_inference_if_exists=self.output_layer_name_skip_inference_if_exists,
         )
@@ -543,12 +594,12 @@ class SplitConfig:
             result.transforms = other.transforms
         if other.sampler:
             result.sampler = other.sampler
-        if other.patch_size:
-            result.patch_size = other.patch_size
-        if other.overlap_ratio is not None:
-            result.overlap_ratio = other.overlap_ratio
-        if other.load_all_patches is not None:
-            result.load_all_patches = other.load_all_patches
+        if other.crop_size:
+            result.crop_size = other.crop_size
+        if other.overlap_pixels is not None:
+            result.overlap_pixels = other.overlap_pixels
+        if other.load_all_crops is not None:
+            result.load_all_crops = other.load_all_crops
         if other.skip_targets is not None:
             result.skip_targets = other.skip_targets
         if other.output_layer_name_skip_inference_if_exists is not None:
@@ -557,21 +608,17 @@ class SplitConfig:
         )
         return result
 
-    def get_patch_size(self) -> tuple[int, int] | None:
-        """Get patch size normalized to int tuple."""
-        if self.patch_size is None:
-            return None
-        if isinstance(self.patch_size, int):
-            return (self.patch_size, self.patch_size)
-        return self.patch_size
+    def get_crop_size(self) -> tuple[int, int] | None:
+        """Get crop size as tuple."""
+        return self.crop_size
 
-    def get_overlap_ratio(self) -> float:
-        """Get the overlap ratio (default 0)."""
-        return self.overlap_ratio if self.overlap_ratio is not None else 0.0
+    def get_overlap_pixels(self) -> int:
+        """Get the overlap pixels (default 0)."""
+        return self.overlap_pixels if self.overlap_pixels is not None else 0
 
-    def get_load_all_patches(self) -> bool:
+    def get_load_all_crops(self) -> bool:
         """Returns whether loading all patches is enabled (default False)."""
-        return True if self.load_all_patches is True else False
+        return True if self.load_all_crops is True else False
 
     def get_skip_targets(self) -> bool:
         """Returns whether skip_targets is enabled (default False)."""
@@ -677,13 +724,13 @@ class ModelDataset(torch.utils.data.Dataset):
         else:
             self.transforms = rslearn.train.transforms.transform.Identity()
 
-        # Get normalized patch size from the SplitConfig.
-        # But if load all patches is enabled, this is handled by AllPatchesDataset, so
+        # Get normalized crop size from the SplitConfig.
+        # But if load all patches is enabled, this is handled by AllCropsDataset, so
         # here we instead load the entire windows.
-        if split_config.get_load_all_patches():
-            self.patch_size = None
+        if split_config.get_load_all_crops():
+            self.crop_size = None
         else:
-            self.patch_size = split_config.get_patch_size()
+            self.crop_size = split_config.get_crop_size()
 
         # If targets are not needed, remove them from the inputs.
         if split_config.get_skip_targets():
@@ -904,8 +951,8 @@ class ModelDataset(torch.utils.data.Dataset):
     def get_dataset_examples(self) -> list[Window]:
         """Get a list of examples in the dataset.
 
-        If load_all_patches is False, this is a list of Windows. Otherwise, this is a
-        list of (window, patch_bounds, (patch_idx, # patches)) tuples.
+        If load_all_crops is False, this is a list of Windows. Otherwise, this is a
+        list of (window, crop_bounds, (crop_idx, # patches)) tuples.
         """
         if self.dataset_examples is None:
             logger.debug(
@@ -941,34 +988,34 @@
         rng = random.Random(idx if self.fix_patch_pick else None)
 
         # Select bounds to read.
-        if self.patch_size:
+        if self.crop_size:
             window = example
 
-            def get_patch_range(n_patch: int, n_window: int) -> list[int]:
-                if n_patch > n_window:
+            def get_crop_range(n_crop: int, n_window: int) -> list[int]:
+                if n_crop > n_window:
                     # Select arbitrary range containing the entire window.
-                    # Basically arbitrarily padding the window to get to patch size.
-                    start = rng.randint(n_window - n_patch, 0)
-                    return [start, start + n_patch]
+                    # Basically arbitrarily padding the window to get to crop size.
+                    start = rng.randint(n_window - n_crop, 0)
+                    return [start, start + n_crop]
 
                 else:
-                    # Select arbitrary patch within the window.
-                    start = rng.randint(0, n_window - n_patch)
-                    return [start, start + n_patch]
+                    # Select arbitrary crop within the window.
+                    start = rng.randint(0, n_window - n_crop)
+                    return [start, start + n_crop]
 
             window_size = (
                 window.bounds[2] - window.bounds[0],
                 window.bounds[3] - window.bounds[1],
             )
-            patch_ranges = [
-                get_patch_range(self.patch_size[0], window_size[0]),
-                get_patch_range(self.patch_size[1], window_size[1]),
+            crop_ranges = [
+                get_crop_range(self.crop_size[0], window_size[0]),
+                get_crop_range(self.crop_size[1], window_size[1]),
             ]
             bounds = (
-                window.bounds[0] + patch_ranges[0][0],
-                window.bounds[1] + patch_ranges[1][0],
-                window.bounds[0] + patch_ranges[0][1],
-                window.bounds[1] + patch_ranges[1][1],
+                window.bounds[0] + crop_ranges[0][0],
+                window.bounds[1] + crop_ranges[1][0],
+                window.bounds[0] + crop_ranges[0][1],
+                window.bounds[1] + crop_ranges[1][1],
             )
 
         else:
@@ -990,9 +1037,9 @@
             window_group=window.group,
             window_name=window.name,
             window_bounds=window.bounds,
-            patch_bounds=bounds,
-            patch_idx=0,
-            num_patches_in_window=1,
+            crop_bounds=bounds,
+            crop_idx=0,
+            num_crops_in_window=1,
             time_range=window.time_range,
             projection=window.projection,
             dataset_source=self.name,
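The crop sampling shown above can be illustrated in isolation. This is a standalone sketch of the same arithmetic (plain Python, not the rslearn API): when the crop is larger than the window, the start offset is allowed to be negative so the window is effectively padded up to the crop size.

    import random

    def get_crop_range(n_crop: int, n_window: int, rng: random.Random) -> list[int]:
        if n_crop > n_window:
            # Crop larger than window: pick a (possibly negative) start so the
            # window is arbitrarily padded up to the crop size.
            start = rng.randint(n_window - n_crop, 0)
        else:
            # Crop fits: pick a start fully inside the window.
            start = rng.randint(0, n_window - n_crop)
        return [start, start + n_crop]

    rng = random.Random(0)
    # A 256-pixel crop within a 1000-pixel window starts somewhere in [0, 744].
    assert 0 <= get_crop_range(256, 1000, rng)[0] <= 744
    # A 256-pixel crop over a 100-pixel window starts somewhere in [-156, 0].
    assert -156 <= get_crop_range(256, 100, rng)[0] <= 0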
rslearn/train/lightning_module.py CHANGED
@@ -365,7 +365,7 @@ class RslearnLightningModule(L.LightningModule):
         for image_suffix, image in images.items():
             out_fname = os.path.join(
                 self.visualize_dir,
-                f"{metadata.window_name}_{metadata.patch_bounds[0]}_{metadata.patch_bounds[1]}_{image_suffix}.png",
+                f"{metadata.window_name}_{metadata.crop_bounds[0]}_{metadata.crop_bounds[1]}_{image_suffix}.png",
             )
             Image.fromarray(image).save(out_fname)
 
rslearn/train/model_context.py CHANGED
@@ -67,9 +67,9 @@ class SampleMetadata:
     window_group: str
     window_name: str
     window_bounds: PixelBounds
-    patch_bounds: PixelBounds
-    patch_idx: int
-    num_patches_in_window: int
+    crop_bounds: PixelBounds
+    crop_idx: int
+    num_crops_in_window: int
     time_range: tuple[datetime, datetime] | None
     projection: Projection
 
rslearn/train/prediction_writer.py CHANGED
@@ -1,6 +1,7 @@
 """rslearn PredictionWriter implementation."""
 
 import json
+import warnings
 from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 from pathlib import Path
@@ -39,20 +40,20 @@ logger = get_logger(__name__)
 
 
 @dataclass
-class PendingPatchOutput:
-    """A patch output that hasn't been merged yet."""
+class PendingCropOutput:
+    """A crop output that hasn't been merged yet."""
 
     bounds: PixelBounds
     output: Any
 
 
-class PatchPredictionMerger:
-    """Base class for merging predictions from multiple patches."""
+class CropPredictionMerger:
+    """Base class for merging predictions from multiple crops."""
 
     def merge(
         self,
         window: Window,
-        outputs: Sequence[PendingPatchOutput],
+        outputs: Sequence[PendingCropOutput],
         layer_config: LayerConfig,
     ) -> Any:
         """Merge the outputs.
@@ -68,39 +69,60 @@ class PatchPredictionMerger:
         raise NotImplementedError
 
 
-class VectorMerger(PatchPredictionMerger):
+class VectorMerger(CropPredictionMerger):
     """Merger for vector data that simply concatenates the features."""
 
     def merge(
         self,
         window: Window,
-        outputs: Sequence[PendingPatchOutput],
+        outputs: Sequence[PendingCropOutput],
         layer_config: LayerConfig,
     ) -> list[Feature]:
         """Concatenate the vector features."""
         return [feat for output in outputs for feat in output.output]
 
 
-class RasterMerger(PatchPredictionMerger):
+class RasterMerger(CropPredictionMerger):
     """Merger for raster data that copies the rasters to the output."""
 
-    def __init__(self, padding: int | None = None, downsample_factor: int = 1):
+    def __init__(
+        self,
+        overlap_pixels: int | None = None,
+        downsample_factor: int = 1,
+        # Deprecated parameter (for backwards compatibility)
+        padding: int | None = None,
+    ):
         """Create a new RasterMerger.
 
         Args:
-            padding: the padding around the individual patch outputs to remove. This is
-                typically used when leveraging overlapping patches. Portions of outputs
-                at the border of the window will still be retained.
+            overlap_pixels: the number of pixels shared between adjacent crops during
+                sliding window inference. Half of this overlap is removed from each
+                crop during merging (except at window boundaries where the full crop
+                is retained).
             downsample_factor: the factor by which the rasters output by the task are
                 lower in resolution relative to the window resolution.
+            padding: deprecated, use overlap_pixels instead. The old padding value
+                equals overlap_pixels // 2.
         """
-        self.padding = padding
+        # Handle deprecated padding parameter
+        if padding is not None:
+            warnings.warn(
+                "padding is deprecated, use overlap_pixels instead. "
+                "Note: overlap_pixels = padding * 2",
+                FutureWarning,
+                stacklevel=2,
+            )
+            if overlap_pixels is not None:
+                raise ValueError("Cannot specify both padding and overlap_pixels")
+            overlap_pixels = padding * 2
+
+        self.overlap_pixels = overlap_pixels
         self.downsample_factor = downsample_factor
 
     def merge(
         self,
         window: Window,
-        outputs: Sequence[PendingPatchOutput],
+        outputs: Sequence[PendingCropOutput],
         layer_config: LayerConfig,
     ) -> npt.NDArray:
         """Merge the raster outputs."""
@@ -114,6 +136,12 @@ class RasterMerger(PatchPredictionMerger):
             dtype=layer_config.band_sets[0].dtype.get_numpy_dtype(),
         )
 
+        # Compute how many pixels to trim from each side.
+        # We remove half of the overlap from each side (not at window boundaries).
+        trim_pixels = (
+            self.overlap_pixels // 2 if self.overlap_pixels is not None else None
+        )
+
         # Ensure the outputs are sorted by height then width.
         # This way when we merge we can be sure that outputs that are lower or further
         # to the right will overwrite earlier outputs.
@@ -123,18 +151,18 @@
         for output in sorted_outputs:
             # So now we just need to compute the src_offset to copy.
             # If the output is not on the left or top boundary, then we should apply
-            # the padding (if set).
+            # the trim (if set).
             src = output.output
             src_offset = (
                 output.bounds[0] // self.downsample_factor,
                 output.bounds[1] // self.downsample_factor,
             )
-            if self.padding is not None and output.bounds[0] != window.bounds[0]:
-                src = src[:, :, self.padding :]
-                src_offset = (src_offset[0] + self.padding, src_offset[1])
-            if self.padding is not None and output.bounds[1] != window.bounds[1]:
-                src = src[:, self.padding :, :]
-                src_offset = (src_offset[0], src_offset[1] + self.padding)
+            if trim_pixels is not None and output.bounds[0] != window.bounds[0]:
+                src = src[:, :, trim_pixels:]
+                src_offset = (src_offset[0] + trim_pixels, src_offset[1])
+            if trim_pixels is not None and output.bounds[1] != window.bounds[1]:
+                src = src[:, trim_pixels:, :]
+                src_offset = (src_offset[0], src_offset[1] + trim_pixels)
 
             copy_spatial_array(
                 src=src,
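The trimming rule applied above can be checked numerically. This sketch uses NumPy only (it is not the rslearn API): trim_pixels = overlap_pixels // 2 columns are dropped from the left of a crop unless the crop touches the window's left boundary, and the paste offset shifts by the same amount.

    import numpy as np

    overlap_pixels = 32
    trim_pixels = overlap_pixels // 2  # 16

    window_x0 = 0
    crop = np.ones((1, 128, 128))  # (bands, height, width)
    crop_x0 = 96                   # crop does not start at the window's left edge

    src = crop
    src_offset_x = crop_x0
    if crop_x0 != window_x0:
        src = src[:, :, trim_pixels:]         # drop the 16 leftmost columns
        src_offset_x = crop_x0 + trim_pixels  # shift the paste position accordingly

    assert src.shape == (1, 128, 112)
    assert src_offset_x == 112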
@@ -162,7 +190,7 @@ class RslearnWriter(BasePredictionWriter):
         output_layer: str,
         path_options: dict[str, Any] | None = None,
         selector: list[str] | None = None,
-        merger: PatchPredictionMerger | None = None,
+        merger: CropPredictionMerger | None = None,
         output_path: str | Path | None = None,
         layer_config: LayerConfig | None = None,
         storage_config: StorageConfig | None = None,
@@ -175,7 +203,7 @@
             path_options: additional options for path to pass to fsspec
             selector: keys to access the desired output in the output dict if needed.
                 e.g ["key1", "key2"] gets output["key1"]["key2"]
-            merger: merger to use to merge outputs from overlapped patches.
+            merger: merger to use to merge outputs from overlapped crops.
             output_path: optional custom path for writing predictions. If provided,
                 predictions will be written to this path instead of deriving from dataset path.
             layer_config: optional layer configuration. If provided, this config will be
@@ -217,9 +245,9 @@
             self.merger = VectorMerger()
 
         # Map from window name to pending data to write.
-        # This is used when windows are split up into patches, so the data from all the
-        # patches of each window need to be reconstituted.
-        self.pending_outputs: dict[str, list[PendingPatchOutput]] = {}
+        # This is used when windows are split up into crops, so the data from all the
+        # crops of each window need to be reconstituted.
+        self.pending_outputs: dict[str, list[PendingCropOutput]] = {}
 
     def _get_layer_config_and_dataset_storage(
         self,
@@ -327,7 +355,7 @@
                 will be processed by the task to obtain a vector (list[Feature]) or
                 raster (npt.NDArray) output.
             metadatas: corresponding list of metadatas from the batch describing the
-                patches that were processed.
+                crops that were processed.
         """
         # Process the predictions into outputs that can be written.
         outputs: list = [
@@ -349,17 +377,17 @@
             )
             self.process_output(
                 window,
-                metadata.patch_idx,
-                metadata.num_patches_in_window,
-                metadata.patch_bounds,
+                metadata.crop_idx,
+                metadata.num_crops_in_window,
+                metadata.crop_bounds,
                 output,
             )
 
     def process_output(
         self,
         window: Window,
-        patch_idx: int,
-        num_patches: int,
+        crop_idx: int,
+        num_crops: int,
         cur_bounds: PixelBounds,
         output: npt.NDArray | list[Feature],
     ) -> None:
@@ -367,28 +395,28 @@
 
         Args:
             window: the window that the output pertains to.
-            patch_idx: the index of this patch for the window.
-            num_patches: the total number of patches to be processed for the window.
-            cur_bounds: the bounds of the current patch.
+            crop_idx: the index of this crop for the window.
+            num_crops: the total number of crops to be processed for the window.
+            cur_bounds: the bounds of the current crop.
             output: the output data.
         """
-        # Incorporate the output into our list of pending patch outputs.
+        # Incorporate the output into our list of pending crop outputs.
         if window.name not in self.pending_outputs:
             self.pending_outputs[window.name] = []
-        self.pending_outputs[window.name].append(PendingPatchOutput(cur_bounds, output))
+        self.pending_outputs[window.name].append(PendingCropOutput(cur_bounds, output))
         logger.debug(
-            f"Stored PendingPatchOutput for patch #{patch_idx}/{num_patches} at window {window.name}"
+            f"Stored PendingCropOutput for crop #{crop_idx}/{num_crops} at window {window.name}"
         )
 
-        if patch_idx < num_patches - 1:
+        if crop_idx < num_crops - 1:
            return
 
-        # This is the last patch so it's time to write it.
+        # This is the last crop so it's time to write it.
         # First get the pending output and clear it.
         pending_output = self.pending_outputs[window.name]
         del self.pending_outputs[window.name]
 
-        # Merge outputs from overlapped patches if merger is set.
+        # Merge outputs from overlapped crops if merger is set.
         logger.debug(f"Merging and writing for window {window.name}")
         merged_output = self.merger.merge(window, pending_output, self.layer_config)
 
rslearn/train/tasks/classification.py CHANGED
@@ -201,7 +201,7 @@ class ClassificationTask(BasicTask):
         feature = Feature(
             STGeometry(
                 metadata.projection,
-                shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1]),
+                shapely.Point(metadata.crop_bounds[0], metadata.crop_bounds[1]),
                 None,
             ),
             {
rslearn/train/tasks/detection.py CHANGED
@@ -128,7 +128,7 @@ class DetectionTask(BasicTask):
         if not load_targets:
             return {}, {}
 
-        bounds = metadata.patch_bounds
+        bounds = metadata.crop_bounds
 
         boxes = []
         class_labels = []
@@ -244,10 +244,10 @@
         features = []
         for box, class_id, score in zip(boxes, class_ids, scores):
             shp = shapely.box(
-                metadata.patch_bounds[0] + float(box[0]),
-                metadata.patch_bounds[1] + float(box[1]),
-                metadata.patch_bounds[0] + float(box[2]),
-                metadata.patch_bounds[1] + float(box[3]),
+                metadata.crop_bounds[0] + float(box[0]),
+                metadata.crop_bounds[1] + float(box[1]),
+                metadata.crop_bounds[0] + float(box[2]),
+                metadata.crop_bounds[1] + float(box[3]),
             )
             geom = STGeometry(metadata.projection, shp, None)
             properties: dict[str, Any] = {
rslearn/train/tasks/regression.py CHANGED
@@ -130,7 +130,7 @@ class RegressionTask(BasicTask):
         feature = Feature(
             STGeometry(
                 metadata.projection,
-                shapely.Point(metadata.patch_bounds[0], metadata.patch_bounds[1]),
+                shapely.Point(metadata.crop_bounds[0], metadata.crop_bounds[1]),
                 None,
             ),
             {
rslearn/utils/__init__.py CHANGED
@@ -7,6 +7,7 @@ from .geometry import (
     PixelBounds,
     Projection,
     STGeometry,
+    get_global_raster_bounds,
     is_same_resolution,
     shp_intersects,
 )
@@ -23,6 +24,7 @@ __all__ = (
     "Projection",
     "STGeometry",
     "daterange",
+    "get_global_raster_bounds",
     "get_utm_ups_crs",
     "is_same_resolution",
     "logger",
rslearn/utils/geometry.py CHANGED
@@ -116,6 +116,27 @@ class Projection:
 WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)
 
 
+def get_global_raster_bounds(projection: Projection) -> PixelBounds:
+    """Get very large pixel bounds for a global raster in the given projection.
+
+    This is useful for data sources that cover the entire world and don't want to
+    compute exact bounds in arbitrary projections (which can fail for projections
+    like UTM that only cover part of the world).
+
+    Args:
+        projection: the projection to get bounds in.
+
+    Returns:
+        Pixel bounds that will intersect with any reasonable window. We assume that the
+        absolute value of CRS coordinates is at most 2^32, and adjust it based on the
+        resolution in the Projection in case very fine-grained resolutions are used.
+    """
+    crs_bound = 2**32
+    pixel_bound_x = int(crs_bound / abs(projection.x_resolution))
+    pixel_bound_y = int(crs_bound / abs(projection.y_resolution))
+    return (-pixel_bound_x, -pixel_bound_y, pixel_bound_x, pixel_bound_y)
+
+
 class ResolutionFactor:
     """Multiplier for the resolution in a Projection.
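A short usage sketch for the new helper, assuming it is exported from rslearn.utils (per the __init__.py hunk above) and that WGS84_PROJECTION remains importable from rslearn.utils.geometry; with a resolution of 1 CRS unit per pixel the bounds come out to plus/minus 2**32.

    from rslearn.utils import get_global_raster_bounds
    from rslearn.utils.geometry import WGS84_PROJECTION

    bounds = get_global_raster_bounds(WGS84_PROJECTION)
    assert bounds == (-(2**32), -(2**32), 2**32, 2**32)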