rslearn 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +30 -23
- rslearn/data_sources/__init__.py +2 -0
- rslearn/data_sources/aws_landsat.py +44 -161
- rslearn/data_sources/aws_open_data.py +2 -4
- rslearn/data_sources/aws_sentinel1.py +1 -3
- rslearn/data_sources/aws_sentinel2_element84.py +54 -165
- rslearn/data_sources/climate_data_store.py +1 -3
- rslearn/data_sources/copernicus.py +1 -2
- rslearn/data_sources/data_source.py +1 -1
- rslearn/data_sources/direct_materialize_data_source.py +336 -0
- rslearn/data_sources/earthdaily.py +52 -155
- rslearn/data_sources/earthdatahub.py +425 -0
- rslearn/data_sources/eurocrops.py +1 -2
- rslearn/data_sources/gcp_public_data.py +1 -2
- rslearn/data_sources/google_earth_engine.py +1 -2
- rslearn/data_sources/hf_srtm.py +595 -0
- rslearn/data_sources/local_files.py +3 -3
- rslearn/data_sources/openstreetmap.py +1 -1
- rslearn/data_sources/planet.py +1 -2
- rslearn/data_sources/planet_basemap.py +1 -2
- rslearn/data_sources/planetary_computer.py +183 -186
- rslearn/data_sources/soilgrids.py +3 -3
- rslearn/data_sources/stac.py +1 -2
- rslearn/data_sources/usda_cdl.py +1 -3
- rslearn/data_sources/usgs_landsat.py +7 -254
- rslearn/data_sources/utils.py +204 -64
- rslearn/data_sources/worldcereal.py +1 -1
- rslearn/data_sources/worldcover.py +1 -1
- rslearn/data_sources/worldpop.py +1 -1
- rslearn/data_sources/xyz_tiles.py +5 -9
- rslearn/dataset/materialize.py +5 -1
- rslearn/models/clay/clay.py +3 -3
- rslearn/models/concatenate_features.py +6 -1
- rslearn/models/detr/detr.py +4 -1
- rslearn/models/dinov3.py +0 -1
- rslearn/models/olmoearth_pretrain/model.py +3 -1
- rslearn/models/pooling_decoder.py +1 -1
- rslearn/models/prithvi.py +0 -1
- rslearn/models/simple_time_series.py +97 -35
- rslearn/train/{all_patches_dataset.py → all_crops_dataset.py} +120 -117
- rslearn/train/data_module.py +32 -27
- rslearn/train/dataset.py +260 -117
- rslearn/train/dataset_index.py +156 -0
- rslearn/train/lightning_module.py +1 -1
- rslearn/train/model_context.py +19 -3
- rslearn/train/prediction_writer.py +69 -41
- rslearn/train/tasks/classification.py +1 -1
- rslearn/train/tasks/detection.py +5 -5
- rslearn/train/tasks/per_pixel_regression.py +13 -13
- rslearn/train/tasks/regression.py +1 -1
- rslearn/train/tasks/segmentation.py +26 -13
- rslearn/train/transforms/concatenate.py +17 -27
- rslearn/train/transforms/crop.py +8 -19
- rslearn/train/transforms/flip.py +4 -10
- rslearn/train/transforms/mask.py +9 -15
- rslearn/train/transforms/normalize.py +31 -82
- rslearn/train/transforms/pad.py +7 -13
- rslearn/train/transforms/resize.py +5 -22
- rslearn/train/transforms/select_bands.py +16 -36
- rslearn/train/transforms/sentinel1.py +4 -16
- rslearn/utils/__init__.py +2 -0
- rslearn/utils/geometry.py +21 -0
- rslearn/utils/m2m_api.py +251 -0
- rslearn/utils/retry_session.py +43 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/METADATA +6 -3
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/RECORD +71 -66
- rslearn/data_sources/earthdata_srtm.py +0 -282
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/WHEEL +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.25.dist-info → rslearn-0.0.27.dist-info}/top_level.txt +0 -0
rslearn/train/transforms/mask.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
"""Mask transform."""
|
|
2
2
|
|
|
3
|
-
import torch
|
|
4
|
-
|
|
5
3
|
from rslearn.train.model_context import RasterImage
|
|
6
4
|
from rslearn.train.transforms.transform import Transform, read_selector
|
|
7
5
|
|
|
@@ -32,9 +30,7 @@ class Mask(Transform):
|
|
|
32
30
|
self.mask_selector = mask_selector
|
|
33
31
|
self.mask_value = mask_value
|
|
34
32
|
|
|
35
|
-
def apply_image(
|
|
36
|
-
self, image: torch.Tensor | RasterImage, mask: torch.Tensor | RasterImage
|
|
37
|
-
) -> torch.Tensor | RasterImage:
|
|
33
|
+
def apply_image(self, image: RasterImage, mask: RasterImage) -> RasterImage:
|
|
38
34
|
"""Apply the mask on the image.
|
|
39
35
|
|
|
40
36
|
Args:
|
|
@@ -44,21 +40,19 @@ class Mask(Transform):
|
|
|
44
40
|
Returns:
|
|
45
41
|
masked image
|
|
46
42
|
"""
|
|
47
|
-
#
|
|
48
|
-
|
|
49
|
-
mask = mask.image
|
|
43
|
+
# Extract the mask tensor (CTHW format)
|
|
44
|
+
mask_tensor = mask.image
|
|
50
45
|
|
|
51
|
-
|
|
52
|
-
|
|
46
|
+
# Tile the mask to have same number of bands (C dimension) as the image.
|
|
47
|
+
if image.shape[0] != mask_tensor.shape[0]:
|
|
48
|
+
if mask_tensor.shape[0] != 1:
|
|
53
49
|
raise ValueError(
|
|
54
50
|
"expected mask to either have same bands as image, or one band"
|
|
55
51
|
)
|
|
56
|
-
|
|
52
|
+
# Repeat along C dimension, keep T, H, W the same
|
|
53
|
+
mask_tensor = mask_tensor.repeat(image.shape[0], 1, 1, 1)
|
|
57
54
|
|
|
58
|
-
|
|
59
|
-
image[mask == 0] = self.mask_value
|
|
60
|
-
else:
|
|
61
|
-
image.image[mask == 0] = self.mask_value
|
|
55
|
+
image.image[mask_tensor == 0] = self.mask_value
|
|
62
56
|
return image
|
|
63
57
|
|
|
64
58
|
def forward(self, input_dict: dict, target_dict: dict) -> tuple[dict, dict]:
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Normalization transforms."""
|
|
2
2
|
|
|
3
|
+
import warnings
|
|
3
4
|
from typing import Any
|
|
4
5
|
|
|
5
6
|
import torch
|
|
@@ -35,14 +36,17 @@ class Normalize(Transform):
|
|
|
35
36
|
bands: optionally restrict the normalization to these band indices. If set,
|
|
36
37
|
mean and std must either be one value, or have length equal to the
|
|
37
38
|
number of band indices passed here.
|
|
38
|
-
num_bands:
|
|
39
|
-
in a time series. If set, then the bands list is repeated for each
|
|
40
|
-
image, e.g. if bands=[2] then we apply normalization on images[2],
|
|
41
|
-
images[2+num_bands], images[2+num_bands*2], etc. Or if the bands list
|
|
42
|
-
is not set, then we apply the mean and std on each image in the time
|
|
43
|
-
series.
|
|
39
|
+
num_bands: deprecated, no longer used. Will be removed after 2026-04-01.
|
|
44
40
|
"""
|
|
45
41
|
super().__init__()
|
|
42
|
+
|
|
43
|
+
if num_bands is not None:
|
|
44
|
+
warnings.warn(
|
|
45
|
+
"num_bands is deprecated and no longer used. "
|
|
46
|
+
"It will be removed after 2026-04-01.",
|
|
47
|
+
FutureWarning,
|
|
48
|
+
)
|
|
49
|
+
|
|
46
50
|
self.mean = torch.tensor(mean)
|
|
47
51
|
self.std = torch.tensor(std)
|
|
48
52
|
|
|
@@ -55,92 +59,37 @@ class Normalize(Transform):
|
|
|
55
59
|
|
|
56
60
|
self.selectors = selectors
|
|
57
61
|
self.bands = torch.tensor(bands) if bands is not None else None
|
|
58
|
-
self.num_bands = num_bands
|
|
59
62
|
|
|
60
|
-
def apply_image(
|
|
61
|
-
self, image: torch.Tensor | RasterImage
|
|
62
|
-
) -> torch.Tensor | RasterImage:
|
|
63
|
+
def apply_image(self, image: RasterImage) -> RasterImage:
|
|
63
64
|
"""Normalize the specified image.
|
|
64
65
|
|
|
65
66
|
Args:
|
|
66
67
|
image: the image to transform.
|
|
67
68
|
"""
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
#
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
if len(self.mean.shape) == 0:
|
|
77
|
-
return self.mean, self.std
|
|
78
|
-
if num_bands is None:
|
|
79
|
-
return self.mean, self.std
|
|
80
|
-
num_images = image_channels // num_bands
|
|
81
|
-
if is_raster_image:
|
|
82
|
-
# add an extra T dimension, CTHW
|
|
83
|
-
return self.mean.repeat(num_images)[
|
|
84
|
-
:, None, None, None
|
|
85
|
-
], self.std.repeat(num_images)[:, None, None, None]
|
|
86
|
-
else:
|
|
87
|
-
# add an extra T dimension, CTHW
|
|
88
|
-
return self.mean.repeat(num_images)[:, None, None], self.std.repeat(
|
|
89
|
-
num_images
|
|
90
|
-
)[:, None, None]
|
|
69
|
+
# Get mean/std with singleton dims for broadcasting over CTHW.
|
|
70
|
+
if len(self.mean.shape) == 0:
|
|
71
|
+
# Scalar - broadcasts naturally.
|
|
72
|
+
mean, std = self.mean, self.std
|
|
73
|
+
else:
|
|
74
|
+
# Vector of length C - add singleton dims for T, H, W.
|
|
75
|
+
mean = self.mean[:, None, None, None]
|
|
76
|
+
std = self.std[:, None, None, None]
|
|
91
77
|
|
|
92
78
|
if self.bands is not None:
|
|
93
|
-
#
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
[
|
|
101
|
-
band_indices + image_idx * self.num_bands
|
|
102
|
-
for image_idx in range(num_images)
|
|
103
|
-
],
|
|
104
|
-
dim=0,
|
|
79
|
+
# Normalize only specific band indices.
|
|
80
|
+
image.image[self.bands] = (image.image[self.bands] - mean) / std
|
|
81
|
+
if self.valid_min is not None:
|
|
82
|
+
image.image[self.bands] = torch.clamp(
|
|
83
|
+
image.image[self.bands],
|
|
84
|
+
min=self.valid_min,
|
|
85
|
+
max=self.valid_max,
|
|
105
86
|
)
|
|
106
|
-
|
|
107
|
-
# We use len(self.bands) here because that is how many bands per timestep
|
|
108
|
-
# we are actually processing with the mean/std.
|
|
109
|
-
mean, std = _repeat_mean_and_std(
|
|
110
|
-
image_channels=len(band_indices),
|
|
111
|
-
num_bands=len(self.bands),
|
|
112
|
-
is_raster_image=isinstance(image, RasterImage),
|
|
113
|
-
)
|
|
114
|
-
if isinstance(image, torch.Tensor):
|
|
115
|
-
image[band_indices] = (image[band_indices] - mean) / std
|
|
116
|
-
if self.valid_min is not None:
|
|
117
|
-
image[band_indices] = torch.clamp(
|
|
118
|
-
image[band_indices], min=self.valid_min, max=self.valid_max
|
|
119
|
-
)
|
|
120
|
-
else:
|
|
121
|
-
image.image[band_indices] = (image.image[band_indices] - mean) / std
|
|
122
|
-
if self.valid_min is not None:
|
|
123
|
-
image.image[band_indices] = torch.clamp(
|
|
124
|
-
image.image[band_indices],
|
|
125
|
-
min=self.valid_min,
|
|
126
|
-
max=self.valid_max,
|
|
127
|
-
)
|
|
128
87
|
else:
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
if isinstance(image, torch.Tensor):
|
|
135
|
-
image = (image - mean) / std
|
|
136
|
-
if self.valid_min is not None:
|
|
137
|
-
image = torch.clamp(image, min=self.valid_min, max=self.valid_max)
|
|
138
|
-
else:
|
|
139
|
-
image.image = (image.image - mean) / std
|
|
140
|
-
if self.valid_min is not None:
|
|
141
|
-
image.image = torch.clamp(
|
|
142
|
-
image.image, min=self.valid_min, max=self.valid_max
|
|
143
|
-
)
|
|
88
|
+
image.image = (image.image - mean) / std
|
|
89
|
+
if self.valid_min is not None:
|
|
90
|
+
image.image = torch.clamp(
|
|
91
|
+
image.image, min=self.valid_min, max=self.valid_max
|
|
92
|
+
)
|
|
144
93
|
return image
|
|
145
94
|
|
|
146
95
|
def forward(
|
rslearn/train/transforms/pad.py
CHANGED
|
@@ -50,9 +50,7 @@ class Pad(Transform):
|
|
|
50
50
|
"""
|
|
51
51
|
return {"size": torch.randint(low=self.size[0], high=self.size[1], size=())}
|
|
52
52
|
|
|
53
|
-
def apply_image(
|
|
54
|
-
self, image: RasterImage | torch.Tensor, state: dict[str, bool]
|
|
55
|
-
) -> RasterImage | torch.Tensor:
|
|
53
|
+
def apply_image(self, image: RasterImage, state: dict[str, bool]) -> RasterImage:
|
|
56
54
|
"""Apply the sampled state on the specified image.
|
|
57
55
|
|
|
58
56
|
Args:
|
|
@@ -105,16 +103,12 @@ class Pad(Transform):
|
|
|
105
103
|
horizontal_pad = (horizontal_half, horizontal_extra - horizontal_half)
|
|
106
104
|
vertical_pad = (vertical_half, vertical_extra - vertical_half)
|
|
107
105
|
|
|
108
|
-
|
|
109
|
-
image.image
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
image.image
|
|
113
|
-
|
|
114
|
-
)
|
|
115
|
-
else:
|
|
116
|
-
image = apply_padding(image, True, horizontal_pad[0], horizontal_pad[1])
|
|
117
|
-
image = apply_padding(image, False, vertical_pad[0], vertical_pad[1])
|
|
106
|
+
image.image = apply_padding(
|
|
107
|
+
image.image, True, horizontal_pad[0], horizontal_pad[1]
|
|
108
|
+
)
|
|
109
|
+
image.image = apply_padding(
|
|
110
|
+
image.image, False, vertical_pad[0], vertical_pad[1]
|
|
111
|
+
)
|
|
118
112
|
return image
|
|
119
113
|
|
|
120
114
|
def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
|
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
|
|
3
3
|
from typing import Any
|
|
4
4
|
|
|
5
|
-
import torch
|
|
6
5
|
import torchvision
|
|
7
6
|
from torchvision.transforms import InterpolationMode
|
|
8
7
|
|
|
@@ -40,32 +39,16 @@ class Resize(Transform):
|
|
|
40
39
|
self.selectors = selectors
|
|
41
40
|
self.interpolation = INTERPOLATION_MODES[interpolation]
|
|
42
41
|
|
|
43
|
-
def apply_resize(
|
|
44
|
-
self, image: torch.Tensor | RasterImage
|
|
45
|
-
) -> torch.Tensor | RasterImage:
|
|
42
|
+
def apply_resize(self, image: RasterImage) -> RasterImage:
|
|
46
43
|
"""Apply resizing on the specified image.
|
|
47
44
|
|
|
48
|
-
If the image is 2D, it is unsqueezed to 3D and then squeezed
|
|
49
|
-
back after resizing.
|
|
50
|
-
|
|
51
45
|
Args:
|
|
52
46
|
image: the image to transform.
|
|
53
47
|
"""
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
image, self.target_size, self.interpolation
|
|
59
|
-
)
|
|
60
|
-
return result.squeeze(0) # (1, H, W) -> (H, W)
|
|
61
|
-
return torchvision.transforms.functional.resize(
|
|
62
|
-
image, self.target_size, self.interpolation
|
|
63
|
-
)
|
|
64
|
-
else:
|
|
65
|
-
image.image = torchvision.transforms.functional.resize(
|
|
66
|
-
image.image, self.target_size, self.interpolation
|
|
67
|
-
)
|
|
68
|
-
return image
|
|
48
|
+
image.image = torchvision.transforms.functional.resize(
|
|
49
|
+
image.image, self.target_size, self.interpolation
|
|
50
|
+
)
|
|
51
|
+
return image
|
|
69
52
|
|
|
70
53
|
def forward(
|
|
71
54
|
self, input_dict: dict[str, Any], target_dict: dict[str, Any]
|
|
@@ -1,9 +1,8 @@
|
|
|
1
1
|
"""The SelectBands transform."""
|
|
2
2
|
|
|
3
|
+
import warnings
|
|
3
4
|
from typing import Any
|
|
4
5
|
|
|
5
|
-
from rslearn.train.model_context import RasterImage
|
|
6
|
-
|
|
7
6
|
from .transform import Transform, read_selector, write_selector
|
|
8
7
|
|
|
9
8
|
|
|
@@ -17,60 +16,41 @@ class SelectBands(Transform):
|
|
|
17
16
|
output_selector: str = "image",
|
|
18
17
|
num_bands_per_timestep: int | None = None,
|
|
19
18
|
):
|
|
20
|
-
"""Initialize a new
|
|
19
|
+
"""Initialize a new SelectBands.
|
|
21
20
|
|
|
22
21
|
Args:
|
|
23
|
-
band_indices: the bands to select.
|
|
22
|
+
band_indices: the bands to select from the channel dimension.
|
|
24
23
|
input_selector: the selector to read the input image.
|
|
25
24
|
output_selector: the output selector under which to save the output image.
|
|
26
|
-
num_bands_per_timestep:
|
|
27
|
-
|
|
28
|
-
band_indices are selected for each image in the time series.
|
|
25
|
+
num_bands_per_timestep: deprecated, no longer used. Will be removed after
|
|
26
|
+
2026-04-01.
|
|
29
27
|
"""
|
|
30
28
|
super().__init__()
|
|
29
|
+
|
|
30
|
+
if num_bands_per_timestep is not None:
|
|
31
|
+
warnings.warn(
|
|
32
|
+
"num_bands_per_timestep is deprecated and no longer used. "
|
|
33
|
+
"It will be removed after 2026-04-01.",
|
|
34
|
+
FutureWarning,
|
|
35
|
+
)
|
|
36
|
+
|
|
31
37
|
self.input_selector = input_selector
|
|
32
38
|
self.output_selector = output_selector
|
|
33
39
|
self.band_indices = band_indices
|
|
34
|
-
self.num_bands_per_timestep = num_bands_per_timestep
|
|
35
40
|
|
|
36
41
|
def forward(
|
|
37
42
|
self, input_dict: dict[str, Any], target_dict: dict[str, Any]
|
|
38
43
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
39
|
-
"""Apply
|
|
44
|
+
"""Apply band selection over the inputs and targets.
|
|
40
45
|
|
|
41
46
|
Args:
|
|
42
47
|
input_dict: the input
|
|
43
48
|
target_dict: the target
|
|
44
49
|
|
|
45
50
|
Returns:
|
|
46
|
-
|
|
51
|
+
(input_dicts, target_dicts) tuple with selected bands
|
|
47
52
|
"""
|
|
48
53
|
image = read_selector(input_dict, target_dict, self.input_selector)
|
|
49
|
-
|
|
50
|
-
self.num_bands_per_timestep
|
|
51
|
-
if self.num_bands_per_timestep is not None
|
|
52
|
-
else image.shape[0]
|
|
53
|
-
)
|
|
54
|
-
if isinstance(image, RasterImage):
|
|
55
|
-
assert num_bands_per_timestep == image.shape[0], (
|
|
56
|
-
"Expect a seperate dimension for timesteps in RasterImages."
|
|
57
|
-
)
|
|
58
|
-
|
|
59
|
-
if image.shape[0] % num_bands_per_timestep != 0:
|
|
60
|
-
raise ValueError(
|
|
61
|
-
f"channel dimension {image.shape[0]} is not multiple of bands per timestep {num_bands_per_timestep}"
|
|
62
|
-
)
|
|
63
|
-
|
|
64
|
-
# Copy the band indices for each timestep in the input.
|
|
65
|
-
wanted_bands: list[int] = []
|
|
66
|
-
for start_channel_idx in range(0, image.shape[0], num_bands_per_timestep):
|
|
67
|
-
wanted_bands.extend(
|
|
68
|
-
[(start_channel_idx + band_idx) for band_idx in self.band_indices]
|
|
69
|
-
)
|
|
70
|
-
|
|
71
|
-
if isinstance(image, RasterImage):
|
|
72
|
-
image.image = image.image[wanted_bands]
|
|
73
|
-
else:
|
|
74
|
-
image = image[wanted_bands]
|
|
54
|
+
image.image = image.image[self.band_indices]
|
|
75
55
|
write_selector(input_dict, target_dict, self.output_selector, image)
|
|
76
56
|
return input_dict, target_dict
|
|
@@ -33,31 +33,19 @@ class Sentinel1ToDecibels(Transform):
|
|
|
33
33
|
self.from_decibels = from_decibels
|
|
34
34
|
self.epsilon = epsilon
|
|
35
35
|
|
|
36
|
-
def apply_image(
|
|
37
|
-
self, image: torch.Tensor | RasterImage
|
|
38
|
-
) -> torch.Tensor | RasterImage:
|
|
36
|
+
def apply_image(self, image: RasterImage) -> RasterImage:
|
|
39
37
|
"""Normalize the specified image.
|
|
40
38
|
|
|
41
39
|
Args:
|
|
42
40
|
image: the image to transform.
|
|
43
41
|
"""
|
|
44
|
-
if isinstance(image, torch.Tensor):
|
|
45
|
-
image_to_process = image
|
|
46
|
-
else:
|
|
47
|
-
image_to_process = image.image
|
|
48
42
|
if self.from_decibels:
|
|
49
43
|
# Decibels to linear scale.
|
|
50
|
-
|
|
44
|
+
image.image = torch.pow(10.0, image.image / 10.0)
|
|
51
45
|
else:
|
|
52
46
|
# Linear scale to decibels.
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
)
|
|
56
|
-
if isinstance(image, torch.Tensor):
|
|
57
|
-
return image_to_process
|
|
58
|
-
else:
|
|
59
|
-
image.image = image_to_process
|
|
60
|
-
return image
|
|
47
|
+
image.image = 10 * torch.log10(torch.clamp(image.image, min=self.epsilon))
|
|
48
|
+
return image
|
|
61
49
|
|
|
62
50
|
def forward(
|
|
63
51
|
self, input_dict: dict[str, Any], target_dict: dict[str, Any]
|
rslearn/utils/__init__.py
CHANGED
|
@@ -7,6 +7,7 @@ from .geometry import (
|
|
|
7
7
|
PixelBounds,
|
|
8
8
|
Projection,
|
|
9
9
|
STGeometry,
|
|
10
|
+
get_global_raster_bounds,
|
|
10
11
|
is_same_resolution,
|
|
11
12
|
shp_intersects,
|
|
12
13
|
)
|
|
@@ -23,6 +24,7 @@ __all__ = (
|
|
|
23
24
|
"Projection",
|
|
24
25
|
"STGeometry",
|
|
25
26
|
"daterange",
|
|
27
|
+
"get_global_raster_bounds",
|
|
26
28
|
"get_utm_ups_crs",
|
|
27
29
|
"is_same_resolution",
|
|
28
30
|
"logger",
|
rslearn/utils/geometry.py
CHANGED
|
@@ -116,6 +116,27 @@ class Projection:
|
|
|
116
116
|
WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)
|
|
117
117
|
|
|
118
118
|
|
|
119
|
+
def get_global_raster_bounds(projection: Projection) -> PixelBounds:
|
|
120
|
+
"""Get very large pixel bounds for a global raster in the given projection.
|
|
121
|
+
|
|
122
|
+
This is useful for data sources that cover the entire world and don't want to
|
|
123
|
+
compute exact bounds in arbitrary projections (which can fail for projections
|
|
124
|
+
like UTM that only cover part of the world).
|
|
125
|
+
|
|
126
|
+
Args:
|
|
127
|
+
projection: the projection to get bounds in.
|
|
128
|
+
|
|
129
|
+
Returns:
|
|
130
|
+
Pixel bounds that will intersect with any reasonable window. We assume that the
|
|
131
|
+
absolute value of CRS coordinates is at most 2^32, and adjust it based on the
|
|
132
|
+
resolution in the Projection in case very fine-grained resolutions are used.
|
|
133
|
+
"""
|
|
134
|
+
crs_bound = 2**32
|
|
135
|
+
pixel_bound_x = int(crs_bound / abs(projection.x_resolution))
|
|
136
|
+
pixel_bound_y = int(crs_bound / abs(projection.y_resolution))
|
|
137
|
+
return (-pixel_bound_x, -pixel_bound_y, pixel_bound_x, pixel_bound_y)
|
|
138
|
+
|
|
139
|
+
|
|
119
140
|
class ResolutionFactor:
|
|
120
141
|
"""Multiplier for the resolution in a Projection.
|
|
121
142
|
|