rslearn 0.0.26__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/data_sources/__init__.py +2 -0
- rslearn/data_sources/aws_landsat.py +44 -161
- rslearn/data_sources/aws_open_data.py +2 -4
- rslearn/data_sources/aws_sentinel1.py +1 -3
- rslearn/data_sources/aws_sentinel2_element84.py +54 -165
- rslearn/data_sources/climate_data_store.py +1 -3
- rslearn/data_sources/copernicus.py +1 -2
- rslearn/data_sources/data_source.py +1 -1
- rslearn/data_sources/direct_materialize_data_source.py +336 -0
- rslearn/data_sources/earthdaily.py +52 -155
- rslearn/data_sources/earthdatahub.py +425 -0
- rslearn/data_sources/eurocrops.py +1 -2
- rslearn/data_sources/gcp_public_data.py +1 -2
- rslearn/data_sources/google_earth_engine.py +1 -2
- rslearn/data_sources/hf_srtm.py +595 -0
- rslearn/data_sources/local_files.py +1 -1
- rslearn/data_sources/openstreetmap.py +1 -1
- rslearn/data_sources/planet.py +1 -2
- rslearn/data_sources/planet_basemap.py +1 -2
- rslearn/data_sources/planetary_computer.py +183 -186
- rslearn/data_sources/soilgrids.py +3 -3
- rslearn/data_sources/stac.py +1 -2
- rslearn/data_sources/usda_cdl.py +1 -3
- rslearn/data_sources/usgs_landsat.py +7 -254
- rslearn/data_sources/worldcereal.py +1 -1
- rslearn/data_sources/worldcover.py +1 -1
- rslearn/data_sources/worldpop.py +1 -1
- rslearn/data_sources/xyz_tiles.py +5 -9
- rslearn/models/concatenate_features.py +6 -1
- rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
- rslearn/train/data_module.py +27 -27
- rslearn/train/dataset.py +109 -62
- rslearn/train/lightning_module.py +1 -1
- rslearn/train/model_context.py +3 -3
- rslearn/train/prediction_writer.py +69 -41
- rslearn/train/tasks/classification.py +1 -1
- rslearn/train/tasks/detection.py +5 -5
- rslearn/train/tasks/regression.py +1 -1
- rslearn/utils/__init__.py +2 -0
- rslearn/utils/geometry.py +21 -0
- rslearn/utils/m2m_api.py +251 -0
- rslearn/utils/retry_session.py +43 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/RECORD +49 -45
- rslearn/data_sources/earthdata_srtm.py +0 -282
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.26.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
rslearn/train/dataset.py
CHANGED
|
@@ -8,6 +8,7 @@ import random
|
|
|
8
8
|
import tempfile
|
|
9
9
|
import time
|
|
10
10
|
import uuid
|
|
11
|
+
import warnings
|
|
11
12
|
from datetime import datetime
|
|
12
13
|
from enum import StrEnum
|
|
13
14
|
from typing import Any
|
|
@@ -456,11 +457,15 @@ class SplitConfig:
|
|
|
456
457
|
num_patches: int | None = None,
|
|
457
458
|
transforms: list[torch.nn.Module] | None = None,
|
|
458
459
|
sampler: SamplerFactory | None = None,
|
|
460
|
+
crop_size: int | tuple[int, int] | None = None,
|
|
461
|
+
overlap_pixels: int | None = None,
|
|
462
|
+
load_all_crops: bool | None = None,
|
|
463
|
+
skip_targets: bool | None = None,
|
|
464
|
+
output_layer_name_skip_inference_if_exists: str | None = None,
|
|
465
|
+
# Deprecated parameters (for backwards compatibility)
|
|
459
466
|
patch_size: int | tuple[int, int] | None = None,
|
|
460
467
|
overlap_ratio: float | None = None,
|
|
461
468
|
load_all_patches: bool | None = None,
|
|
462
|
-
skip_targets: bool | None = None,
|
|
463
|
-
output_layer_name_skip_inference_if_exists: str | None = None,
|
|
464
469
|
) -> None:
|
|
465
470
|
"""Initialize a new SplitConfig.
|
|
466
471
|
|
|
@@ -475,19 +480,69 @@ class SplitConfig:
|
|
|
475
480
|
num_patches: limit this split to this many patches
|
|
476
481
|
transforms: transforms to apply
|
|
477
482
|
sampler: SamplerFactory for this split
|
|
478
|
-
|
|
483
|
+
crop_size: an optional square size or (width, height) tuple. If set, read
|
|
479
484
|
crops of this size rather than entire windows.
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
for each window, read all
|
|
485
|
+
overlap_pixels: the number of pixels shared between adjacent crops during
|
|
486
|
+
sliding window inference.
|
|
487
|
+
load_all_crops: with crop_size set, rather than sampling a random crop
|
|
488
|
+
for each window, read all crops as separate sequential items in the
|
|
484
489
|
dataset.
|
|
485
490
|
skip_targets: whether to skip targets when loading inputs
|
|
486
491
|
output_layer_name_skip_inference_if_exists: optional name of the output layer used during prediction.
|
|
487
492
|
If set, windows that already
|
|
488
493
|
have this layer completed will be skipped (useful for resuming
|
|
489
494
|
partial inference runs).
|
|
495
|
+
patch_size: deprecated, use crop_size instead
|
|
496
|
+
overlap_ratio: deprecated, use overlap_pixels instead
|
|
497
|
+
load_all_patches: deprecated, use load_all_crops instead
|
|
490
498
|
"""
|
|
499
|
+
# Handle deprecated load_all_patches parameter
|
|
500
|
+
if load_all_patches is not None:
|
|
501
|
+
warnings.warn(
|
|
502
|
+
"load_all_patches is deprecated, use load_all_crops instead",
|
|
503
|
+
FutureWarning,
|
|
504
|
+
stacklevel=2,
|
|
505
|
+
)
|
|
506
|
+
if load_all_crops is not None:
|
|
507
|
+
raise ValueError(
|
|
508
|
+
"Cannot specify both load_all_patches and load_all_crops"
|
|
509
|
+
)
|
|
510
|
+
load_all_crops = load_all_patches
|
|
511
|
+
# Handle deprecated patch_size parameter
|
|
512
|
+
if patch_size is not None:
|
|
513
|
+
warnings.warn(
|
|
514
|
+
"patch_size is deprecated, use crop_size instead",
|
|
515
|
+
FutureWarning,
|
|
516
|
+
stacklevel=2,
|
|
517
|
+
)
|
|
518
|
+
if crop_size is not None:
|
|
519
|
+
raise ValueError("Cannot specify both patch_size and crop_size")
|
|
520
|
+
crop_size = patch_size
|
|
521
|
+
|
|
522
|
+
# Normalize crop_size to tuple[int, int] | None
|
|
523
|
+
self.crop_size: tuple[int, int] | None = None
|
|
524
|
+
if crop_size is not None:
|
|
525
|
+
if isinstance(crop_size, int):
|
|
526
|
+
self.crop_size = (crop_size, crop_size)
|
|
527
|
+
else:
|
|
528
|
+
self.crop_size = crop_size
|
|
529
|
+
|
|
530
|
+
# Handle deprecated overlap_ratio parameter
|
|
531
|
+
if overlap_ratio is not None:
|
|
532
|
+
warnings.warn(
|
|
533
|
+
"overlap_ratio is deprecated, use overlap_pixels instead",
|
|
534
|
+
FutureWarning,
|
|
535
|
+
stacklevel=2,
|
|
536
|
+
)
|
|
537
|
+
if overlap_pixels is not None:
|
|
538
|
+
raise ValueError("Cannot specify both overlap_ratio and overlap_pixels")
|
|
539
|
+
if self.crop_size is None:
|
|
540
|
+
raise ValueError("overlap_ratio requires crop_size to be set")
|
|
541
|
+
overlap_pixels = round(self.crop_size[0] * overlap_ratio)
|
|
542
|
+
|
|
543
|
+
if overlap_pixels is not None and overlap_pixels < 0:
|
|
544
|
+
raise ValueError("overlap_pixels must be non-negative")
|
|
545
|
+
|
|
491
546
|
self.groups = groups
|
|
492
547
|
self.names = names
|
|
493
548
|
self.tags = tags
|
|
@@ -495,19 +550,15 @@ class SplitConfig:
|
|
|
495
550
|
self.num_patches = num_patches
|
|
496
551
|
self.transforms = transforms
|
|
497
552
|
self.sampler = sampler
|
|
498
|
-
self.patch_size = patch_size
|
|
499
553
|
self.skip_targets = skip_targets
|
|
500
554
|
self.output_layer_name_skip_inference_if_exists = (
|
|
501
555
|
output_layer_name_skip_inference_if_exists
|
|
502
556
|
)
|
|
503
557
|
|
|
504
|
-
# Note that
|
|
505
|
-
#
|
|
506
|
-
self.
|
|
507
|
-
self.
|
|
508
|
-
|
|
509
|
-
if self.overlap_ratio is not None and not (0 < self.overlap_ratio < 1):
|
|
510
|
-
raise ValueError("overlap_ratio must be between 0 and 1 (exclusive)")
|
|
558
|
+
# Note that load_all_crops is handled by the RslearnDataModule rather than the
|
|
559
|
+
# ModelDataset.
|
|
560
|
+
self.load_all_crops = load_all_crops
|
|
561
|
+
self.overlap_pixels = overlap_pixels
|
|
511
562
|
|
|
512
563
|
def update(self, other: "SplitConfig") -> "SplitConfig":
|
|
513
564
|
"""Override settings in this SplitConfig with those in another.
|
|
@@ -523,9 +574,9 @@ class SplitConfig:
|
|
|
523
574
|
num_patches=self.num_patches,
|
|
524
575
|
transforms=self.transforms,
|
|
525
576
|
sampler=self.sampler,
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
577
|
+
crop_size=self.crop_size,
|
|
578
|
+
overlap_pixels=self.overlap_pixels,
|
|
579
|
+
load_all_crops=self.load_all_crops,
|
|
529
580
|
skip_targets=self.skip_targets,
|
|
530
581
|
output_layer_name_skip_inference_if_exists=self.output_layer_name_skip_inference_if_exists,
|
|
531
582
|
)
|
|
@@ -543,12 +594,12 @@ class SplitConfig:
|
|
|
543
594
|
result.transforms = other.transforms
|
|
544
595
|
if other.sampler:
|
|
545
596
|
result.sampler = other.sampler
|
|
546
|
-
if other.
|
|
547
|
-
result.
|
|
548
|
-
if other.
|
|
549
|
-
result.
|
|
550
|
-
if other.
|
|
551
|
-
result.
|
|
597
|
+
if other.crop_size:
|
|
598
|
+
result.crop_size = other.crop_size
|
|
599
|
+
if other.overlap_pixels is not None:
|
|
600
|
+
result.overlap_pixels = other.overlap_pixels
|
|
601
|
+
if other.load_all_crops is not None:
|
|
602
|
+
result.load_all_crops = other.load_all_crops
|
|
552
603
|
if other.skip_targets is not None:
|
|
553
604
|
result.skip_targets = other.skip_targets
|
|
554
605
|
if other.output_layer_name_skip_inference_if_exists is not None:
|
|
@@ -557,21 +608,17 @@ class SplitConfig:
|
|
|
557
608
|
)
|
|
558
609
|
return result
|
|
559
610
|
|
|
560
|
-
def
|
|
561
|
-
"""Get
|
|
562
|
-
|
|
563
|
-
return None
|
|
564
|
-
if isinstance(self.patch_size, int):
|
|
565
|
-
return (self.patch_size, self.patch_size)
|
|
566
|
-
return self.patch_size
|
|
611
|
+
def get_crop_size(self) -> tuple[int, int] | None:
|
|
612
|
+
"""Get crop size as tuple."""
|
|
613
|
+
return self.crop_size
|
|
567
614
|
|
|
568
|
-
def
|
|
569
|
-
"""Get the overlap
|
|
570
|
-
return self.
|
|
615
|
+
def get_overlap_pixels(self) -> int:
|
|
616
|
+
"""Get the overlap pixels (default 0)."""
|
|
617
|
+
return self.overlap_pixels if self.overlap_pixels is not None else 0
|
|
571
618
|
|
|
572
|
-
def
|
|
619
|
+
def get_load_all_crops(self) -> bool:
|
|
573
620
|
"""Returns whether loading all patches is enabled (default False)."""
|
|
574
|
-
return True if self.
|
|
621
|
+
return True if self.load_all_crops is True else False
|
|
575
622
|
|
|
576
623
|
def get_skip_targets(self) -> bool:
|
|
577
624
|
"""Returns whether skip_targets is enabled (default False)."""
|
|
@@ -677,13 +724,13 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
677
724
|
else:
|
|
678
725
|
self.transforms = rslearn.train.transforms.transform.Identity()
|
|
679
726
|
|
|
680
|
-
# Get normalized
|
|
681
|
-
# But if load all patches is enabled, this is handled by
|
|
727
|
+
# Get normalized crop size from the SplitConfig.
|
|
728
|
+
# But if load all patches is enabled, this is handled by AllCropsDataset, so
|
|
682
729
|
# here we instead load the entire windows.
|
|
683
|
-
if split_config.
|
|
684
|
-
self.
|
|
730
|
+
if split_config.get_load_all_crops():
|
|
731
|
+
self.crop_size = None
|
|
685
732
|
else:
|
|
686
|
-
self.
|
|
733
|
+
self.crop_size = split_config.get_crop_size()
|
|
687
734
|
|
|
688
735
|
# If targets are not needed, remove them from the inputs.
|
|
689
736
|
if split_config.get_skip_targets():
|
|
@@ -904,8 +951,8 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
904
951
|
def get_dataset_examples(self) -> list[Window]:
|
|
905
952
|
"""Get a list of examples in the dataset.
|
|
906
953
|
|
|
907
|
-
If
|
|
908
|
-
list of (window,
|
|
954
|
+
If load_all_crops is False, this is a list of Windows. Otherwise, this is a
|
|
955
|
+
list of (window, crop_bounds, (crop_idx, # patches)) tuples.
|
|
909
956
|
"""
|
|
910
957
|
if self.dataset_examples is None:
|
|
911
958
|
logger.debug(
|
|
@@ -941,34 +988,34 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
941
988
|
rng = random.Random(idx if self.fix_patch_pick else None)
|
|
942
989
|
|
|
943
990
|
# Select bounds to read.
|
|
944
|
-
if self.
|
|
991
|
+
if self.crop_size:
|
|
945
992
|
window = example
|
|
946
993
|
|
|
947
|
-
def
|
|
948
|
-
if
|
|
994
|
+
def get_crop_range(n_crop: int, n_window: int) -> list[int]:
|
|
995
|
+
if n_crop > n_window:
|
|
949
996
|
# Select arbitrary range containing the entire window.
|
|
950
|
-
# Basically arbitrarily padding the window to get to
|
|
951
|
-
start = rng.randint(n_window -
|
|
952
|
-
return [start, start +
|
|
997
|
+
# Basically arbitrarily padding the window to get to crop size.
|
|
998
|
+
start = rng.randint(n_window - n_crop, 0)
|
|
999
|
+
return [start, start + n_crop]
|
|
953
1000
|
|
|
954
1001
|
else:
|
|
955
|
-
# Select arbitrary
|
|
956
|
-
start = rng.randint(0, n_window -
|
|
957
|
-
return [start, start +
|
|
1002
|
+
# Select arbitrary crop within the window.
|
|
1003
|
+
start = rng.randint(0, n_window - n_crop)
|
|
1004
|
+
return [start, start + n_crop]
|
|
958
1005
|
|
|
959
1006
|
window_size = (
|
|
960
1007
|
window.bounds[2] - window.bounds[0],
|
|
961
1008
|
window.bounds[3] - window.bounds[1],
|
|
962
1009
|
)
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
1010
|
+
crop_ranges = [
|
|
1011
|
+
get_crop_range(self.crop_size[0], window_size[0]),
|
|
1012
|
+
get_crop_range(self.crop_size[1], window_size[1]),
|
|
966
1013
|
]
|
|
967
1014
|
bounds = (
|
|
968
|
-
window.bounds[0] +
|
|
969
|
-
window.bounds[1] +
|
|
970
|
-
window.bounds[0] +
|
|
971
|
-
window.bounds[1] +
|
|
1015
|
+
window.bounds[0] + crop_ranges[0][0],
|
|
1016
|
+
window.bounds[1] + crop_ranges[1][0],
|
|
1017
|
+
window.bounds[0] + crop_ranges[0][1],
|
|
1018
|
+
window.bounds[1] + crop_ranges[1][1],
|
|
972
1019
|
)
|
|
973
1020
|
|
|
974
1021
|
else:
|
|
@@ -990,9 +1037,9 @@ class ModelDataset(torch.utils.data.Dataset):
|
|
|
990
1037
|
window_group=window.group,
|
|
991
1038
|
window_name=window.name,
|
|
992
1039
|
window_bounds=window.bounds,
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
1040
|
+
crop_bounds=bounds,
|
|
1041
|
+
crop_idx=0,
|
|
1042
|
+
num_crops_in_window=1,
|
|
996
1043
|
time_range=window.time_range,
|
|
997
1044
|
projection=window.projection,
|
|
998
1045
|
dataset_source=self.name,
|
|
@@ -365,7 +365,7 @@ class RslearnLightningModule(L.LightningModule):
|
|
|
365
365
|
for image_suffix, image in images.items():
|
|
366
366
|
out_fname = os.path.join(
|
|
367
367
|
self.visualize_dir,
|
|
368
|
-
f"{metadata.window_name}_{metadata.
|
|
368
|
+
f"{metadata.window_name}_{metadata.crop_bounds[0]}_{metadata.crop_bounds[1]}_{image_suffix}.png",
|
|
369
369
|
)
|
|
370
370
|
Image.fromarray(image).save(out_fname)
|
|
371
371
|
|
rslearn/train/model_context.py
CHANGED
|
@@ -67,9 +67,9 @@ class SampleMetadata:
|
|
|
67
67
|
window_group: str
|
|
68
68
|
window_name: str
|
|
69
69
|
window_bounds: PixelBounds
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
70
|
+
crop_bounds: PixelBounds
|
|
71
|
+
crop_idx: int
|
|
72
|
+
num_crops_in_window: int
|
|
73
73
|
time_range: tuple[datetime, datetime] | None
|
|
74
74
|
projection: Projection
|
|
75
75
|
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""rslearn PredictionWriter implementation."""
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
+
import warnings
|
|
4
5
|
from collections.abc import Iterable, Sequence
|
|
5
6
|
from dataclasses import dataclass
|
|
6
7
|
from pathlib import Path
|
|
@@ -39,20 +40,20 @@ logger = get_logger(__name__)
|
|
|
39
40
|
|
|
40
41
|
|
|
41
42
|
@dataclass
|
|
42
|
-
class
|
|
43
|
-
"""A
|
|
43
|
+
class PendingCropOutput:
|
|
44
|
+
"""A crop output that hasn't been merged yet."""
|
|
44
45
|
|
|
45
46
|
bounds: PixelBounds
|
|
46
47
|
output: Any
|
|
47
48
|
|
|
48
49
|
|
|
49
|
-
class
|
|
50
|
-
"""Base class for merging predictions from multiple
|
|
50
|
+
class CropPredictionMerger:
|
|
51
|
+
"""Base class for merging predictions from multiple crops."""
|
|
51
52
|
|
|
52
53
|
def merge(
|
|
53
54
|
self,
|
|
54
55
|
window: Window,
|
|
55
|
-
outputs: Sequence[
|
|
56
|
+
outputs: Sequence[PendingCropOutput],
|
|
56
57
|
layer_config: LayerConfig,
|
|
57
58
|
) -> Any:
|
|
58
59
|
"""Merge the outputs.
|
|
@@ -68,39 +69,60 @@ class PatchPredictionMerger:
|
|
|
68
69
|
raise NotImplementedError
|
|
69
70
|
|
|
70
71
|
|
|
71
|
-
class VectorMerger(
|
|
72
|
+
class VectorMerger(CropPredictionMerger):
|
|
72
73
|
"""Merger for vector data that simply concatenates the features."""
|
|
73
74
|
|
|
74
75
|
def merge(
|
|
75
76
|
self,
|
|
76
77
|
window: Window,
|
|
77
|
-
outputs: Sequence[
|
|
78
|
+
outputs: Sequence[PendingCropOutput],
|
|
78
79
|
layer_config: LayerConfig,
|
|
79
80
|
) -> list[Feature]:
|
|
80
81
|
"""Concatenate the vector features."""
|
|
81
82
|
return [feat for output in outputs for feat in output.output]
|
|
82
83
|
|
|
83
84
|
|
|
84
|
-
class RasterMerger(
|
|
85
|
+
class RasterMerger(CropPredictionMerger):
|
|
85
86
|
"""Merger for raster data that copies the rasters to the output."""
|
|
86
87
|
|
|
87
|
-
def __init__(
|
|
88
|
+
def __init__(
|
|
89
|
+
self,
|
|
90
|
+
overlap_pixels: int | None = None,
|
|
91
|
+
downsample_factor: int = 1,
|
|
92
|
+
# Deprecated parameter (for backwards compatibility)
|
|
93
|
+
padding: int | None = None,
|
|
94
|
+
):
|
|
88
95
|
"""Create a new RasterMerger.
|
|
89
96
|
|
|
90
97
|
Args:
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
98
|
+
overlap_pixels: the number of pixels shared between adjacent crops during
|
|
99
|
+
sliding window inference. Half of this overlap is removed from each
|
|
100
|
+
crop during merging (except at window boundaries where the full crop
|
|
101
|
+
is retained).
|
|
94
102
|
downsample_factor: the factor by which the rasters output by the task are
|
|
95
103
|
lower in resolution relative to the window resolution.
|
|
104
|
+
padding: deprecated, use overlap_pixels instead. The old padding value
|
|
105
|
+
equals overlap_pixels // 2.
|
|
96
106
|
"""
|
|
97
|
-
|
|
107
|
+
# Handle deprecated padding parameter
|
|
108
|
+
if padding is not None:
|
|
109
|
+
warnings.warn(
|
|
110
|
+
"padding is deprecated, use overlap_pixels instead. "
|
|
111
|
+
"Note: overlap_pixels = padding * 2",
|
|
112
|
+
FutureWarning,
|
|
113
|
+
stacklevel=2,
|
|
114
|
+
)
|
|
115
|
+
if overlap_pixels is not None:
|
|
116
|
+
raise ValueError("Cannot specify both padding and overlap_pixels")
|
|
117
|
+
overlap_pixels = padding * 2
|
|
118
|
+
|
|
119
|
+
self.overlap_pixels = overlap_pixels
|
|
98
120
|
self.downsample_factor = downsample_factor
|
|
99
121
|
|
|
100
122
|
def merge(
|
|
101
123
|
self,
|
|
102
124
|
window: Window,
|
|
103
|
-
outputs: Sequence[
|
|
125
|
+
outputs: Sequence[PendingCropOutput],
|
|
104
126
|
layer_config: LayerConfig,
|
|
105
127
|
) -> npt.NDArray:
|
|
106
128
|
"""Merge the raster outputs."""
|
|
@@ -114,6 +136,12 @@ class RasterMerger(PatchPredictionMerger):
|
|
|
114
136
|
dtype=layer_config.band_sets[0].dtype.get_numpy_dtype(),
|
|
115
137
|
)
|
|
116
138
|
|
|
139
|
+
# Compute how many pixels to trim from each side.
|
|
140
|
+
# We remove half of the overlap from each side (not at window boundaries).
|
|
141
|
+
trim_pixels = (
|
|
142
|
+
self.overlap_pixels // 2 if self.overlap_pixels is not None else None
|
|
143
|
+
)
|
|
144
|
+
|
|
117
145
|
# Ensure the outputs are sorted by height then width.
|
|
118
146
|
# This way when we merge we can be sure that outputs that are lower or further
|
|
119
147
|
# to the right will overwrite earlier outputs.
|
|
@@ -123,18 +151,18 @@ class RasterMerger(PatchPredictionMerger):
|
|
|
123
151
|
for output in sorted_outputs:
|
|
124
152
|
# So now we just need to compute the src_offset to copy.
|
|
125
153
|
# If the output is not on the left or top boundary, then we should apply
|
|
126
|
-
# the
|
|
154
|
+
# the trim (if set).
|
|
127
155
|
src = output.output
|
|
128
156
|
src_offset = (
|
|
129
157
|
output.bounds[0] // self.downsample_factor,
|
|
130
158
|
output.bounds[1] // self.downsample_factor,
|
|
131
159
|
)
|
|
132
|
-
if
|
|
133
|
-
src = src[:, :,
|
|
134
|
-
src_offset = (src_offset[0] +
|
|
135
|
-
if
|
|
136
|
-
src = src[:,
|
|
137
|
-
src_offset = (src_offset[0], src_offset[1] +
|
|
160
|
+
if trim_pixels is not None and output.bounds[0] != window.bounds[0]:
|
|
161
|
+
src = src[:, :, trim_pixels:]
|
|
162
|
+
src_offset = (src_offset[0] + trim_pixels, src_offset[1])
|
|
163
|
+
if trim_pixels is not None and output.bounds[1] != window.bounds[1]:
|
|
164
|
+
src = src[:, trim_pixels:, :]
|
|
165
|
+
src_offset = (src_offset[0], src_offset[1] + trim_pixels)
|
|
138
166
|
|
|
139
167
|
copy_spatial_array(
|
|
140
168
|
src=src,
|
|
@@ -162,7 +190,7 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
162
190
|
output_layer: str,
|
|
163
191
|
path_options: dict[str, Any] | None = None,
|
|
164
192
|
selector: list[str] | None = None,
|
|
165
|
-
merger:
|
|
193
|
+
merger: CropPredictionMerger | None = None,
|
|
166
194
|
output_path: str | Path | None = None,
|
|
167
195
|
layer_config: LayerConfig | None = None,
|
|
168
196
|
storage_config: StorageConfig | None = None,
|
|
@@ -175,7 +203,7 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
175
203
|
path_options: additional options for path to pass to fsspec
|
|
176
204
|
selector: keys to access the desired output in the output dict if needed.
|
|
177
205
|
e.g ["key1", "key2"] gets output["key1"]["key2"]
|
|
178
|
-
merger: merger to use to merge outputs from overlapped
|
|
206
|
+
merger: merger to use to merge outputs from overlapped crops.
|
|
179
207
|
output_path: optional custom path for writing predictions. If provided,
|
|
180
208
|
predictions will be written to this path instead of deriving from dataset path.
|
|
181
209
|
layer_config: optional layer configuration. If provided, this config will be
|
|
@@ -217,9 +245,9 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
217
245
|
self.merger = VectorMerger()
|
|
218
246
|
|
|
219
247
|
# Map from window name to pending data to write.
|
|
220
|
-
# This is used when windows are split up into
|
|
221
|
-
#
|
|
222
|
-
self.pending_outputs: dict[str, list[
|
|
248
|
+
# This is used when windows are split up into crops, so the data from all the
|
|
249
|
+
# crops of each window need to be reconstituted.
|
|
250
|
+
self.pending_outputs: dict[str, list[PendingCropOutput]] = {}
|
|
223
251
|
|
|
224
252
|
def _get_layer_config_and_dataset_storage(
|
|
225
253
|
self,
|
|
@@ -327,7 +355,7 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
327
355
|
will be processed by the task to obtain a vector (list[Feature]) or
|
|
328
356
|
raster (npt.NDArray) output.
|
|
329
357
|
metadatas: corresponding list of metadatas from the batch describing the
|
|
330
|
-
|
|
358
|
+
crops that were processed.
|
|
331
359
|
"""
|
|
332
360
|
# Process the predictions into outputs that can be written.
|
|
333
361
|
outputs: list = [
|
|
@@ -349,17 +377,17 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
349
377
|
)
|
|
350
378
|
self.process_output(
|
|
351
379
|
window,
|
|
352
|
-
metadata.
|
|
353
|
-
metadata.
|
|
354
|
-
metadata.
|
|
380
|
+
metadata.crop_idx,
|
|
381
|
+
metadata.num_crops_in_window,
|
|
382
|
+
metadata.crop_bounds,
|
|
355
383
|
output,
|
|
356
384
|
)
|
|
357
385
|
|
|
358
386
|
def process_output(
|
|
359
387
|
self,
|
|
360
388
|
window: Window,
|
|
361
|
-
|
|
362
|
-
|
|
389
|
+
crop_idx: int,
|
|
390
|
+
num_crops: int,
|
|
363
391
|
cur_bounds: PixelBounds,
|
|
364
392
|
output: npt.NDArray | list[Feature],
|
|
365
393
|
) -> None:
|
|
@@ -367,28 +395,28 @@ class RslearnWriter(BasePredictionWriter):
|
|
|
367
395
|
|
|
368
396
|
Args:
|
|
369
397
|
window: the window that the output pertains to.
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
cur_bounds: the bounds of the current
|
|
398
|
+
crop_idx: the index of this crop for the window.
|
|
399
|
+
num_crops: the total number of crops to be processed for the window.
|
|
400
|
+
cur_bounds: the bounds of the current crop.
|
|
373
401
|
output: the output data.
|
|
374
402
|
"""
|
|
375
|
-
# Incorporate the output into our list of pending
|
|
403
|
+
# Incorporate the output into our list of pending crop outputs.
|
|
376
404
|
if window.name not in self.pending_outputs:
|
|
377
405
|
self.pending_outputs[window.name] = []
|
|
378
|
-
self.pending_outputs[window.name].append(
|
|
406
|
+
self.pending_outputs[window.name].append(PendingCropOutput(cur_bounds, output))
|
|
379
407
|
logger.debug(
|
|
380
|
-
f"Stored
|
|
408
|
+
f"Stored PendingCropOutput for crop #{crop_idx}/{num_crops} at window {window.name}"
|
|
381
409
|
)
|
|
382
410
|
|
|
383
|
-
if
|
|
411
|
+
if crop_idx < num_crops - 1:
|
|
384
412
|
return
|
|
385
413
|
|
|
386
|
-
# This is the last
|
|
414
|
+
# This is the last crop so it's time to write it.
|
|
387
415
|
# First get the pending output and clear it.
|
|
388
416
|
pending_output = self.pending_outputs[window.name]
|
|
389
417
|
del self.pending_outputs[window.name]
|
|
390
418
|
|
|
391
|
-
# Merge outputs from overlapped
|
|
419
|
+
# Merge outputs from overlapped crops if merger is set.
|
|
392
420
|
logger.debug(f"Merging and writing for window {window.name}")
|
|
393
421
|
merged_output = self.merger.merge(window, pending_output, self.layer_config)
|
|
394
422
|
|
|
@@ -201,7 +201,7 @@ class ClassificationTask(BasicTask):
|
|
|
201
201
|
feature = Feature(
|
|
202
202
|
STGeometry(
|
|
203
203
|
metadata.projection,
|
|
204
|
-
shapely.Point(metadata.
|
|
204
|
+
shapely.Point(metadata.crop_bounds[0], metadata.crop_bounds[1]),
|
|
205
205
|
None,
|
|
206
206
|
),
|
|
207
207
|
{
|
rslearn/train/tasks/detection.py
CHANGED
|
@@ -128,7 +128,7 @@ class DetectionTask(BasicTask):
|
|
|
128
128
|
if not load_targets:
|
|
129
129
|
return {}, {}
|
|
130
130
|
|
|
131
|
-
bounds = metadata.
|
|
131
|
+
bounds = metadata.crop_bounds
|
|
132
132
|
|
|
133
133
|
boxes = []
|
|
134
134
|
class_labels = []
|
|
@@ -244,10 +244,10 @@ class DetectionTask(BasicTask):
|
|
|
244
244
|
features = []
|
|
245
245
|
for box, class_id, score in zip(boxes, class_ids, scores):
|
|
246
246
|
shp = shapely.box(
|
|
247
|
-
metadata.
|
|
248
|
-
metadata.
|
|
249
|
-
metadata.
|
|
250
|
-
metadata.
|
|
247
|
+
metadata.crop_bounds[0] + float(box[0]),
|
|
248
|
+
metadata.crop_bounds[1] + float(box[1]),
|
|
249
|
+
metadata.crop_bounds[0] + float(box[2]),
|
|
250
|
+
metadata.crop_bounds[1] + float(box[3]),
|
|
251
251
|
)
|
|
252
252
|
geom = STGeometry(metadata.projection, shp, None)
|
|
253
253
|
properties: dict[str, Any] = {
|
|
@@ -130,7 +130,7 @@ class RegressionTask(BasicTask):
|
|
|
130
130
|
feature = Feature(
|
|
131
131
|
STGeometry(
|
|
132
132
|
metadata.projection,
|
|
133
|
-
shapely.Point(metadata.
|
|
133
|
+
shapely.Point(metadata.crop_bounds[0], metadata.crop_bounds[1]),
|
|
134
134
|
None,
|
|
135
135
|
),
|
|
136
136
|
{
|
rslearn/utils/__init__.py
CHANGED
|
@@ -7,6 +7,7 @@ from .geometry import (
|
|
|
7
7
|
PixelBounds,
|
|
8
8
|
Projection,
|
|
9
9
|
STGeometry,
|
|
10
|
+
get_global_raster_bounds,
|
|
10
11
|
is_same_resolution,
|
|
11
12
|
shp_intersects,
|
|
12
13
|
)
|
|
@@ -23,6 +24,7 @@ __all__ = (
|
|
|
23
24
|
"Projection",
|
|
24
25
|
"STGeometry",
|
|
25
26
|
"daterange",
|
|
27
|
+
"get_global_raster_bounds",
|
|
26
28
|
"get_utm_ups_crs",
|
|
27
29
|
"is_same_resolution",
|
|
28
30
|
"logger",
|
rslearn/utils/geometry.py
CHANGED
|
@@ -116,6 +116,27 @@ class Projection:
|
|
|
116
116
|
WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)
|
|
117
117
|
|
|
118
118
|
|
|
119
|
+
def get_global_raster_bounds(projection: Projection) -> PixelBounds:
|
|
120
|
+
"""Get very large pixel bounds for a global raster in the given projection.
|
|
121
|
+
|
|
122
|
+
This is useful for data sources that cover the entire world and don't want to
|
|
123
|
+
compute exact bounds in arbitrary projections (which can fail for projections
|
|
124
|
+
like UTM that only cover part of the world).
|
|
125
|
+
|
|
126
|
+
Args:
|
|
127
|
+
projection: the projection to get bounds in.
|
|
128
|
+
|
|
129
|
+
Returns:
|
|
130
|
+
Pixel bounds that will intersect with any reasonable window. We assume that the
|
|
131
|
+
absolute value of CRS coordinates is at most 2^32, and adjust it based on the
|
|
132
|
+
resolution in the Projection in case very fine-grained resolutions are used.
|
|
133
|
+
"""
|
|
134
|
+
crs_bound = 2**32
|
|
135
|
+
pixel_bound_x = int(crs_bound / abs(projection.x_resolution))
|
|
136
|
+
pixel_bound_y = int(crs_bound / abs(projection.y_resolution))
|
|
137
|
+
return (-pixel_bound_x, -pixel_bound_y, pixel_bound_x, pixel_bound_y)
|
|
138
|
+
|
|
139
|
+
|
|
119
140
|
class ResolutionFactor:
|
|
120
141
|
"""Multiplier for the resolution in a Projection.
|
|
121
142
|
|