rslearn 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +2 -9
- rslearn/config/dataset.py +15 -16
- rslearn/dataset/dataset.py +28 -22
- rslearn/lightning_cli.py +22 -11
- rslearn/main.py +1 -1
- rslearn/models/anysat.py +35 -33
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clip.py +5 -2
- rslearn/models/component.py +12 -0
- rslearn/models/croma.py +11 -3
- rslearn/models/dinov3.py +2 -1
- rslearn/models/faster_rcnn.py +2 -1
- rslearn/models/galileo/galileo.py +58 -31
- rslearn/models/module_wrapper.py +6 -1
- rslearn/models/molmo.py +4 -2
- rslearn/models/olmoearth_pretrain/model.py +206 -51
- rslearn/models/olmoearth_pretrain/norm.py +5 -3
- rslearn/models/panopticon.py +3 -1
- rslearn/models/presto/presto.py +45 -15
- rslearn/models/prithvi.py +9 -7
- rslearn/models/sam2_enc.py +3 -1
- rslearn/models/satlaspretrain.py +4 -1
- rslearn/models/simple_time_series.py +43 -17
- rslearn/models/ssl4eo_s12.py +19 -14
- rslearn/models/swin.py +3 -1
- rslearn/models/terramind.py +5 -4
- rslearn/train/all_patches_dataset.py +96 -28
- rslearn/train/dataset.py +102 -53
- rslearn/train/model_context.py +35 -1
- rslearn/train/scheduler.py +15 -0
- rslearn/train/tasks/classification.py +8 -2
- rslearn/train/tasks/detection.py +3 -2
- rslearn/train/tasks/multi_task.py +2 -3
- rslearn/train/tasks/per_pixel_regression.py +14 -5
- rslearn/train/tasks/regression.py +8 -2
- rslearn/train/tasks/segmentation.py +13 -4
- rslearn/train/tasks/task.py +2 -2
- rslearn/train/transforms/concatenate.py +45 -5
- rslearn/train/transforms/crop.py +22 -8
- rslearn/train/transforms/flip.py +13 -5
- rslearn/train/transforms/mask.py +11 -2
- rslearn/train/transforms/normalize.py +46 -15
- rslearn/train/transforms/pad.py +15 -3
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +11 -2
- rslearn/train/transforms/sentinel1.py +18 -3
- rslearn/utils/geometry.py +73 -0
- rslearn/utils/jsonargparse.py +66 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/METADATA +1 -1
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/RECORD +55 -53
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/WHEEL +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/top_level.txt +0 -0
rslearn/train/transforms/crop.py
CHANGED
@@ -5,6 +5,8 @@ from typing import Any
 import torch
 import torchvision
 
+from rslearn.train.model_context import RasterImage
+
 from .transform import Transform, read_selector
 
 
@@ -69,7 +71,9 @@ class Crop(Transform):
             "remove_from_top": remove_from_top,
         }
 
-    def apply_image(
+    def apply_image(
+        self, image: RasterImage | torch.Tensor, state: dict[str, Any]
+    ) -> RasterImage | torch.Tensor:
         """Apply the sampled state on the specified image.
 
         Args:
@@ -80,13 +84,23 @@ class Crop(Transform):
         crop_size = state["crop_size"] * image.shape[-1] // image_shape[1]
         remove_from_left = state["remove_from_left"] * image.shape[-1] // image_shape[1]
         remove_from_top = state["remove_from_top"] * image.shape[-2] // image_shape[0]
-
-        image
-
-
-
-
-
+        if isinstance(image, RasterImage):
+            image.image = torchvision.transforms.functional.crop(
+                image.image,
+                top=remove_from_top,
+                left=remove_from_left,
+                height=crop_size,
+                width=crop_size,
+            )
+        else:
+            image = torchvision.transforms.functional.crop(
+                image,
+                top=remove_from_top,
+                left=remove_from_left,
+                height=crop_size,
+                width=crop_size,
+            )
+        return image
 
     def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
         """Apply the sampled state on the specified image.
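Note: the isinstance(image, RasterImage) branch above is the pattern this release applies across all the transforms below: RasterImage stores its pixel data in an .image tensor attribute, so transforms mutate that attribute in place, while plain tensors are replaced and returned. A minimal sketch of the dispatch, assuming only the .image attribute visible in this diff (the rest of the RasterImage API is not shown here):

    import torch
    import torchvision

    def crop_any(image, top: int, left: int, size: int):
        # Unwrap RasterImage (which stores pixels in .image) or use the tensor directly.
        target = image.image if hasattr(image, "image") else image
        cropped = torchvision.transforms.functional.crop(
            target, top=top, left=left, height=size, width=size
        )
        if hasattr(image, "image"):
            image.image = cropped  # mutate the wrapper in place, as Crop.apply_image does
            return image
        return cropped  # plain tensors are returned directly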
rslearn/train/transforms/flip.py
CHANGED
@@ -4,6 +4,8 @@ from typing import Any
 
 import torch
 
+from rslearn.train.model_context import RasterImage
+
 from .transform import Transform
 
 
@@ -48,17 +50,23 @@ class Flip(Transform):
             "vertical": vertical,
         }
 
-    def apply_image(self, image:
+    def apply_image(self, image: RasterImage, state: dict[str, bool]) -> RasterImage:
         """Apply the sampled state on the specified image.
 
         Args:
             image: the image to transform.
             state: the sampled state.
         """
-        if
-
-
-
+        if isinstance(image, RasterImage):
+            if state["horizontal"]:
+                image.image = torch.flip(image.image, dims=[-1])
+            if state["vertical"]:
+                image.image = torch.flip(image.image, dims=[-2])
+        elif isinstance(image, torch.Tensor):
+            if state["horizontal"]:
+                image = torch.flip(image, dims=[-1])
+            if state["vertical"]:
+                image = torch.flip(image, dims=[-2])
         return image
 
     def apply_boxes(
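The negative dims passed to torch.flip are what let the same call handle both CHW tensors and the CTHW tensors stored in RasterImage.image: dims=[-1] always flips width and dims=[-2] always flips height, regardless of how many leading axes there are. A quick illustration:

    import torch

    x = torch.tensor([[[1, 2, 3],
                       [4, 5, 6]]])      # CHW with C=1, H=2, W=3
    torch.flip(x, dims=[-1])             # horizontal: [[[3, 2, 1], [6, 5, 4]]]
    torch.flip(x, dims=[-2])             # vertical:   [[[4, 5, 6], [1, 2, 3]]]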
rslearn/train/transforms/mask.py
CHANGED
@@ -2,6 +2,7 @@
 
 import torch
 
+from rslearn.train.model_context import RasterImage
 from rslearn.train.transforms.transform import Transform, read_selector
 
 
@@ -31,7 +32,9 @@ class Mask(Transform):
         self.mask_selector = mask_selector
         self.mask_value = mask_value
 
-    def apply_image(
+    def apply_image(
+        self, image: torch.Tensor | RasterImage, mask: torch.Tensor | RasterImage
+    ) -> torch.Tensor | RasterImage:
         """Apply the mask on the image.
 
         Args:
@@ -42,6 +45,9 @@ class Mask(Transform):
            masked image
         """
         # Tile the mask to have same number of bands as the image.
+        if isinstance(mask, RasterImage):
+            mask = mask.image
+
         if image.shape[0] != mask.shape[0]:
             if mask.shape[0] != 1:
                 raise ValueError(
@@ -49,7 +55,10 @@ class Mask(Transform):
                 )
             mask = mask.repeat(image.shape[0], 1, 1)
 
-        image
+        if isinstance(image, torch.Tensor):
+            image[mask == 0] = self.mask_value
+        else:
+            image.image[mask == 0] = self.mask_value
         return image
 
     def forward(self, input_dict: dict, target_dict: dict) -> tuple[dict, dict]:
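For reference, the tiling-plus-indexing logic in apply_image amounts to broadcasting a single-band mask across all image bands and overwriting masked pixels. A small sketch (the mask_value of -1.0 is illustrative, not a default taken from this diff):

    import torch

    image = torch.full((3, 2, 2), 7.0)        # 3-band image
    mask = torch.tensor([[[1, 0], [0, 1]]])   # single-band mask
    mask = mask.repeat(image.shape[0], 1, 1)  # tile to match the band count
    image[mask == 0] = -1.0                   # apply mask_value where mask is zero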
rslearn/train/transforms/normalize.py
CHANGED

@@ -4,6 +4,8 @@ from typing import Any
 
 import torch
 
+from rslearn.train.model_context import RasterImage
+
 from .transform import Transform
 
 
@@ -55,7 +57,9 @@ class Normalize(Transform):
         self.bands = torch.tensor(bands) if bands is not None else None
         self.num_bands = num_bands
 
-    def apply_image(
+    def apply_image(
+        self, image: torch.Tensor | RasterImage
+    ) -> torch.Tensor | RasterImage:
         """Normalize the specified image.
 
         Args:
@@ -63,7 +67,7 @@ class Normalize(Transform):
         """
 
         def _repeat_mean_and_std(
-            image_channels: int, num_bands: int | None
+            image_channels: int, num_bands: int | None, is_raster_image: bool
         ) -> tuple[torch.Tensor, torch.Tensor]:
             """Get mean and std tensor that are suitable for applying on the image."""
             # We only need to repeat the tensor if both of these are true:
@@ -74,9 +78,16 @@ class Normalize(Transform):
             if num_bands is None:
                 return self.mean, self.std
             num_images = image_channels // num_bands
-
-
-
+            if is_raster_image:
+                # add an extra T dimension, CTHW
+                return self.mean.repeat(num_images)[
+                    :, None, None, None
+                ], self.std.repeat(num_images)[:, None, None, None]
+            else:
+                # add an extra T dimension, CTHW
+                return self.mean.repeat(num_images)[:, None, None], self.std.repeat(
+                    num_images
+                )[:, None, None]
 
         if self.bands is not None:
             # User has provided band indices to normalize.
@@ -96,20 +107,40 @@ class Normalize(Transform):
             # We use len(self.bands) here because that is how many bands per timestep
             # we are actually processing with the mean/std.
             mean, std = _repeat_mean_and_std(
-                image_channels=len(band_indices),
+                image_channels=len(band_indices),
+                num_bands=len(self.bands),
+                is_raster_image=isinstance(image, RasterImage),
             )
-
-
-
-            image[band_indices]
-
+            if isinstance(image, torch.Tensor):
+                image[band_indices] = (image[band_indices] - mean) / std
+                if self.valid_min is not None:
+                    image[band_indices] = torch.clamp(
+                        image[band_indices], min=self.valid_min, max=self.valid_max
+                    )
+            else:
+                image.image[band_indices] = (image.image[band_indices] - mean) / std
+                if self.valid_min is not None:
+                    image.image[band_indices] = torch.clamp(
+                        image.image[band_indices],
+                        min=self.valid_min,
+                        max=self.valid_max,
+                    )
         else:
             mean, std = _repeat_mean_and_std(
-                image_channels=image.shape[0],
+                image_channels=image.shape[0],
+                num_bands=self.num_bands,
+                is_raster_image=isinstance(image, RasterImage),
             )
-
-
-
+            if isinstance(image, torch.Tensor):
+                image = (image - mean) / std
+                if self.valid_min is not None:
+                    image = torch.clamp(image, min=self.valid_min, max=self.valid_max)
+            else:
+                image.image = (image.image - mean) / std
+                if self.valid_min is not None:
+                    image.image = torch.clamp(
+                        image.image, min=self.valid_min, max=self.valid_max
+                    )
         return image
 
     def forward(
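The new is_raster_image flag exists because RasterImage stores a CTHW tensor, so the per-band mean/std need one more trailing singleton axis than for a plain CHW tensor in order to broadcast. A shape-only sketch:

    import torch

    mean = torch.tensor([10.0, 20.0])       # per-band means, C=2
    chw = torch.zeros(2, 4, 4)              # plain tensor input
    cthw = torch.zeros(2, 3, 4, 4)          # RasterImage.image with T=3

    chw - mean[:, None, None]               # broadcasts over H, W
    cthw - mean[:, None, None, None]        # broadcasts over T, H, W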
rslearn/train/transforms/pad.py
CHANGED
@@ -5,6 +5,8 @@ from typing import Any
 import torch
 import torchvision
 
+from rslearn.train.model_context import RasterImage
+
 from .transform import Transform
 
 
@@ -48,7 +50,9 @@ class Pad(Transform):
         """
         return {"size": torch.randint(low=self.size[0], high=self.size[1], size=())}
 
-    def apply_image(
+    def apply_image(
+        self, image: RasterImage | torch.Tensor, state: dict[str, bool]
+    ) -> RasterImage | torch.Tensor:
         """Apply the sampled state on the specified image.
 
         Args:
@@ -101,8 +105,16 @@ class Pad(Transform):
         horizontal_pad = (horizontal_half, horizontal_extra - horizontal_half)
         vertical_pad = (vertical_half, vertical_extra - vertical_half)
 
-
-
+        if isinstance(image, RasterImage):
+            image.image = apply_padding(
+                image.image, True, horizontal_pad[0], horizontal_pad[1]
+            )
+            image.image = apply_padding(
+                image.image, False, vertical_pad[0], vertical_pad[1]
+            )
+        else:
+            image = apply_padding(image, True, horizontal_pad[0], horizontal_pad[1])
+            image = apply_padding(image, False, vertical_pad[0], vertical_pad[1])
         return image
 
     def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
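The apply_padding helper itself is unchanged and not shown in this hunk. The context lines split each extra amount into a (before, after) pair; assuming the *_half values are computed as extra // 2 (that code sits outside the hunk), odd extras put the spare pixel at the end:

    # Assumed reconstruction of the pad split; only the (half, extra - half)
    # pairing is visible in the diff.
    def split_pad(extra: int) -> tuple[int, int]:
        half = extra // 2
        return (half, extra - half)

    split_pad(5)  # (2, 3)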
rslearn/train/transforms/resize.py
ADDED

@@ -0,0 +1,83 @@
+"""Resize transform."""
+
+from typing import Any
+
+import torch
+import torchvision
+from torchvision.transforms import InterpolationMode
+
+from rslearn.train.model_context import RasterImage
+
+from .transform import Transform
+
+INTERPOLATION_MODES = {
+    "nearest": InterpolationMode.NEAREST,
+    "nearest_exact": InterpolationMode.NEAREST_EXACT,
+    "bilinear": InterpolationMode.BILINEAR,
+    "bicubic": InterpolationMode.BICUBIC,
+}
+
+
+class Resize(Transform):
+    """Resizes inputs to a target size."""
+
+    def __init__(
+        self,
+        target_size: tuple[int, int],
+        selectors: list[str] = [],
+        interpolation: str = "nearest",
+    ):
+        """Initialize a resize transform.
+
+        Args:
+            target_size: the (height, width) to resize to.
+            selectors: items to transform.
+            interpolation: the interpolation mode to use for resizing.
+                Must be one of "nearest", "nearest_exact", "bilinear", or "bicubic".
+        """
+        super().__init__()
+        self.target_size = target_size
+        self.selectors = selectors
+        self.interpolation = INTERPOLATION_MODES[interpolation]
+
+    def apply_resize(
+        self, image: torch.Tensor | RasterImage
+    ) -> torch.Tensor | RasterImage:
+        """Apply resizing on the specified image.
+
+        If the image is 2D, it is unsqueezed to 3D and then squeezed
+        back after resizing.
+
+        Args:
+            image: the image to transform.
+        """
+        if isinstance(image, torch.Tensor):
+            if image.dim() == 2:
+                image = image.unsqueeze(0)  # (H, W) -> (1, H, W)
+                result = torchvision.transforms.functional.resize(
+                    image, self.target_size, self.interpolation
+                )
+                return result.squeeze(0)  # (1, H, W) -> (H, W)
+            return torchvision.transforms.functional.resize(
+                image, self.target_size, self.interpolation
+            )
+        else:
+            image.image = torchvision.transforms.functional.resize(
+                image.image, self.target_size, self.interpolation
+            )
+            return image
+
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Apply transform over the inputs and targets.
+
+        Args:
+            input_dict: the input
+            target_dict: the target
+
+        Returns:
+            transformed (input_dicts, target_dicts) tuple
+        """
+        self.apply_fn(self.apply_resize, input_dict, target_dict, self.selectors)
+        return input_dict, target_dict
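A minimal usage sketch for the new transform; the "image" selector string and the dict contents are assumptions about how inputs are keyed, not something this diff shows:

    from rslearn.train.transforms.resize import Resize

    resize = Resize(
        target_size=(256, 256),
        selectors=["image"],       # hypothetical selector key
        interpolation="bilinear",
    )
    input_dict, target_dict = resize.forward(input_dict, target_dict)

2D inputs are unsqueezed to (1, H, W) before torchvision's resize and squeezed back afterward, so single-band label rasters keep their original rank.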
rslearn/train/transforms/select_bands.py
CHANGED

@@ -2,6 +2,8 @@
 
 from typing import Any
 
+from rslearn.train.model_context import RasterImage
+
 from .transform import Transform, read_selector, write_selector
 
 
@@ -49,6 +51,10 @@ class SelectBands(Transform):
             if self.num_bands_per_timestep is not None
             else image.shape[0]
         )
+        if isinstance(image, RasterImage):
+            assert num_bands_per_timestep == image.shape[0], (
+                "Expect a seperate dimension for timesteps in RasterImages."
+            )
 
         if image.shape[0] % num_bands_per_timestep != 0:
             raise ValueError(
@@ -62,6 +68,9 @@ class SelectBands(Transform):
                 [(start_channel_idx + band_idx) for band_idx in self.band_indices]
             )
 
-
-
+        if isinstance(image, RasterImage):
+            image.image = image.image[wanted_bands]
+        else:
+            image = image[wanted_bands]
+        write_selector(input_dict, target_dict, self.output_selector, image)
         return input_dict, target_dict
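The wanted_bands construction visible in the context line repeats the selected band indices once per timestep offset for stacked CHW inputs. The surrounding loop is not shown in this hunk, so the following is an assumed reconstruction of the indexing it produces:

    band_indices = [0, 2]          # bands to keep within each timestep
    num_bands_per_timestep = 4
    num_timesteps = 2
    wanted_bands = []
    for t in range(num_timesteps):
        start_channel_idx = t * num_bands_per_timestep
        wanted_bands.extend(start_channel_idx + b for b in band_indices)
    # wanted_bands == [0, 2, 4, 6]; image[wanted_bands] then selects those channels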
rslearn/train/transforms/sentinel1.py
CHANGED

@@ -4,6 +4,8 @@ from typing import Any
 
 import torch
 
+from rslearn.train.model_context import RasterImage
+
 from .transform import Transform
 
 
@@ -31,18 +33,31 @@ class Sentinel1ToDecibels(Transform):
         self.from_decibels = from_decibels
         self.epsilon = epsilon
 
-    def apply_image(
+    def apply_image(
+        self, image: torch.Tensor | RasterImage
+    ) -> torch.Tensor | RasterImage:
         """Normalize the specified image.
 
         Args:
             image: the image to transform.
         """
+        if isinstance(image, torch.Tensor):
+            image_to_process = image
+        else:
+            image_to_process = image.image
         if self.from_decibels:
             # Decibels to linear scale.
-
+            image_to_process = torch.pow(10.0, image_to_process / 10.0)
         else:
             # Linear scale to decibels.
-
+            image_to_process = 10 * torch.log10(
+                torch.clamp(image_to_process, min=self.epsilon)
+            )
+        if isinstance(image, torch.Tensor):
+            return image_to_process
+        else:
+            image.image = image_to_process
+            return image
 
     def forward(
         self, input_dict: dict[str, Any], target_dict: dict[str, Any]
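The two branches are the standard decibel conversions for Sentinel-1 backscatter, with epsilon guarding log10 against zeros; they round-trip exactly for values above epsilon:

    import torch

    eps = 1e-6                      # illustrative; the default epsilon is not shown in this hunk
    linear = torch.tensor([0.5, 1.0, 2.0])
    db = 10 * torch.log10(torch.clamp(linear, min=eps))  # linear -> decibels
    back = torch.pow(10.0, db / 10.0)                    # decibels -> linear
    assert torch.allclose(back, linear)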
rslearn/utils/geometry.py
CHANGED
@@ -116,6 +116,79 @@ class Projection:
 WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)
 
 
+class ResolutionFactor:
+    """Multiplier for the resolution in a Projection.
+
+    The multiplier is either an integer x, or the inverse of an integer (1/x).
+
+    Factors greater than 1 increase the projection_units/pixel resolution, increasing
+    the resolution (more pixels per projection unit). Factors less than 1 make it coarser
+    (less pixels).
+    """
+
+    def __init__(self, numerator: int = 1, denominator: int = 1):
+        """Create a new ResolutionFactor.
+
+        Args:
+            numerator: the numerator of the fraction.
+            denominator: the denominator of the fraction. If set, numerator must be 1.
+        """
+        if numerator != 1 and denominator != 1:
+            raise ValueError("one of numerator or denominator must be 1")
+        if not isinstance(numerator, int) or not isinstance(denominator, int):
+            raise ValueError("numerator and denominator must be integers")
+        if numerator < 1 or denominator < 1:
+            raise ValueError("numerator and denominator must be >= 1")
+        self.numerator = numerator
+        self.denominator = denominator
+
+    def multiply_projection(self, projection: Projection) -> Projection:
+        """Multiply the projection by this factor."""
+        if self.denominator > 1:
+            return Projection(
+                projection.crs,
+                projection.x_resolution * self.denominator,
+                projection.y_resolution * self.denominator,
+            )
+        else:
+            return Projection(
+                projection.crs,
+                projection.x_resolution // self.numerator,
+                projection.y_resolution // self.numerator,
+            )
+
+    def multiply_bounds(self, bounds: PixelBounds) -> PixelBounds:
+        """Multiply the bounds by this factor.
+
+        When coarsening, the width and height of the given bounds must be a multiple of
+        the denominator.
+        """
+        if self.denominator > 1:
+            # Verify the width and height are multiples of the denominator.
+            # Otherwise the new width and height is not an integer.
+            width = bounds[2] - bounds[0]
+            height = bounds[3] - bounds[1]
+            if width % self.denominator != 0 or height % self.denominator != 0:
+                raise ValueError(
+                    f"width {width} or height {height} is not a multiple of the resolution factor {self.denominator}"
+                )
+            # TODO: an offset could be introduced by bounds not being a multiple
+            # of the denominator -> will need to decide how to handle that.
+            return (
+                bounds[0] // self.denominator,
+                bounds[1] // self.denominator,
+                bounds[2] // self.denominator,
+                bounds[3] // self.denominator,
+            )
+        else:
+            return (
+                bounds[0] * self.numerator,
+                bounds[1] * self.numerator,
+                bounds[2] * self.numerator,
+                bounds[3] * self.numerator,
+            )
+
+
 class STGeometry:
     """A spatiotemporal geometry.
 
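A worked example of the two directions, using illustrative resolutions (Projection takes crs, x_resolution, y_resolution, as in the WGS84_PROJECTION context line above):

    from rasterio.crs import CRS
    from rslearn.utils.geometry import Projection, ResolutionFactor

    proj = Projection(CRS.from_epsg(3857), 10, -10)   # 10 units/pixel, illustrative

    half = ResolutionFactor(denominator=2)            # factor 1/2: coarsen
    half.multiply_projection(proj)                    # resolutions become 20, -20
    half.multiply_bounds((0, 0, 256, 256))            # (0, 0, 128, 128)

    double = ResolutionFactor(numerator=2)            # factor 2: finer
    double.multiply_bounds((0, 0, 256, 256))          # (0, 0, 512, 512)

    half.multiply_bounds((0, 0, 255, 255))            # raises ValueError: 255 is not a multiple of 2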
rslearn/utils/jsonargparse.py
CHANGED
@@ -8,6 +8,7 @@ from rasterio.crs import CRS
 from upath import UPath
 
 from rslearn.config.dataset import LayerConfig
+from rslearn.utils.geometry import ResolutionFactor
 
 if TYPE_CHECKING:
     from rslearn.data_sources.data_source import DataSourceContext
@@ -91,6 +92,68 @@ def data_source_context_deserializer(v: dict[str, Any]) -> "DataSourceContext":
     )
 
 
+def resolution_factor_serializer(v: ResolutionFactor) -> str:
+    """Serialize ResolutionFactor for jsonargparse.
+
+    Args:
+        v: the ResolutionFactor object.
+
+    Returns:
+        the ResolutionFactor encoded to string
+    """
+    if hasattr(v, "init_args"):
+        init_args = v.init_args
+        return f"{init_args.numerator}/{init_args.denominator}"
+
+    return f"{v.numerator}/{v.denominator}"
+
+
+def resolution_factor_deserializer(v: int | str | dict) -> ResolutionFactor:
+    """Deserialize ResolutionFactor for jsonargparse.
+
+    Args:
+        v: the encoded ResolutionFactor.
+
+    Returns:
+        the decoded ResolutionFactor object
+    """
+    # Handle already-instantiated ResolutionFactor
+    if isinstance(v, ResolutionFactor):
+        return v
+
+    # Handle Namespace from class_path syntax (used during config save/validation)
+    if hasattr(v, "init_args"):
+        init_args = v.init_args
+        return ResolutionFactor(
+            numerator=init_args.numerator,
+            denominator=init_args.denominator,
+        )
+
+    # Handle dict from class_path syntax in YAML config
+    if isinstance(v, dict) and "init_args" in v:
+        init_args = v["init_args"]
+        return ResolutionFactor(
+            numerator=init_args.get("numerator", 1),
+            denominator=init_args.get("denominator", 1),
+        )
+
+    if isinstance(v, int):
+        return ResolutionFactor(numerator=v)
+    elif isinstance(v, str):
+        parts = v.split("/")
+        if len(parts) == 1:
+            return ResolutionFactor(numerator=int(parts[0]))
+        elif len(parts) == 2:
+            return ResolutionFactor(
+                numerator=int(parts[0]),
+                denominator=int(parts[1]),
+            )
+        else:
+            raise ValueError("expected resolution factor to be of the form x or 1/x")
+    else:
+        raise ValueError("expected resolution factor to be str or int")
+
+
 def init_jsonargparse() -> None:
     """Initialize custom jsonargparse serializers."""
     global INITIALIZED
@@ -100,6 +163,9 @@ def init_jsonargparse() -> None:
     jsonargparse.typing.register_type(
         datetime, datetime_serializer, datetime_deserializer
     )
+    jsonargparse.typing.register_type(
+        ResolutionFactor, resolution_factor_serializer, resolution_factor_deserializer
+    )
 
     from rslearn.data_sources.data_source import DataSourceContext
 
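The deserializer accepts the shorthand forms alongside jsonargparse's class_path machinery, so a config field typed as ResolutionFactor can be written as a bare integer or an "x/y" string. A sketch of the accepted inputs:

    from rslearn.utils.geometry import ResolutionFactor
    from rslearn.utils.jsonargparse import (
        resolution_factor_deserializer,
        resolution_factor_serializer,
    )

    resolution_factor_deserializer(2)        # ResolutionFactor(numerator=2)
    resolution_factor_deserializer("1/4")    # ResolutionFactor(numerator=1, denominator=4)
    resolution_factor_deserializer({"init_args": {"denominator": 4}})  # class_path dict form
    resolution_factor_serializer(ResolutionFactor(numerator=2))        # "2/1"
    resolution_factor_deserializer("1/2/3")  # raises ValueError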