rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/train/transforms/concatenate.py
CHANGED
@@ -1,8 +1,21 @@
-"""
+"""Concatenate bands across multiple image inputs."""
+
+from datetime import datetime
+from enum import Enum
+from typing import Any
 
 import torch
 
-from .
+from rslearn.train.model_context import RasterImage
+
+from .transform import Transform, read_selector, write_selector
+
+
+class ConcatenateDim(Enum):
+    """Enum for concatenation dimensions."""
+
+    CHANNEL = 0
+    TIME = 1
 
 
 class Concatenate(Transform):
@@ -12,6 +25,7 @@ class Concatenate(Transform):
         self,
         selections: dict[str, list[int]],
         output_selector: str,
+        concatenate_dim: ConcatenateDim | int = ConcatenateDim.TIME,
     ):
         """Initialize a new Concatenate.
 
@@ -19,12 +33,20 @@
             selections: map from selector to list of band indices in that input to
                 retain, or empty list to use all bands.
             output_selector: the output selector under which to save the concatenate image.
+            concatenate_dim: the dimension against which to concatenate the inputs
         """
         super().__init__()
         self.selections = selections
         self.output_selector = output_selector
+        self.concatenate_dim = (
+            concatenate_dim.value
+            if isinstance(concatenate_dim, ConcatenateDim)
+            else concatenate_dim
+        )
 
-    def forward(
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Apply concatenation over the inputs and targets.
 
         Args:
@@ -32,14 +54,36 @@
             target_dict: the target
 
         Returns:
-
+            concatenated (input_dicts, target_dicts) tuple. If one of the
+            specified inputs is a RasterImage, a RasterImage will be returned.
+            Otherwise it will be a torch.Tensor.
         """
         images = []
+        return_raster_image: bool = False
+        timestamps: list[tuple[datetime, datetime]] | None = None
         for selector, wanted_bands in self.selections.items():
-            image =
-            if
-
-
-
-
+            image = read_selector(input_dict, target_dict, selector)
+            if isinstance(image, torch.Tensor):
+                if wanted_bands:
+                    image = image[wanted_bands, :, :]
+                images.append(image)
+            elif isinstance(image, RasterImage):
+                return_raster_image = True
+                if wanted_bands:
+                    images.append(image.image[wanted_bands, :, :])
+                else:
+                    images.append(image.image)
+                if timestamps is None:
+                    if image.timestamps is not None:
+                        # assume all concatenated modalities have the same
+                        # number of timestamps
+                        timestamps = image.timestamps
+        if return_raster_image:
+            result = RasterImage(
+                torch.concatenate(images, dim=self.concatenate_dim),
+                timestamps=timestamps,
+            )
+        else:
+            result = torch.concatenate(images, dim=self.concatenate_dim)
+        write_selector(input_dict, target_dict, self.output_selector, result)
         return input_dict, target_dict
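To make the new concatenate_dim option concrete, here is a minimal usage sketch on plain tensors. It is not part of the diff: the "sentinel2"/"sentinel1" selector names are hypothetical, and it assumes Transform is a torch.nn.Module (so calling the transform invokes forward) and that read_selector/write_selector resolve plain keys against input_dict.

import torch

from rslearn.train.transforms.concatenate import Concatenate, ConcatenateDim

# Keep all bands of both inputs ([] selects every band) and stack them on
# the channel dimension under the "image" selector.
transform = Concatenate(
    selections={"sentinel2": [], "sentinel1": []},
    output_selector="image",
    concatenate_dim=ConcatenateDim.CHANNEL,
)
input_dict = {
    "sentinel2": torch.zeros(12, 64, 64),  # (bands, H, W)
    "sentinel1": torch.zeros(2, 64, 64),
}
input_dict, _ = transform(input_dict, {})
assert input_dict["image"].shape == (14, 64, 64)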
rslearn/train/transforms/crop.py
CHANGED
@@ -5,7 +5,9 @@ from typing import Any
 import torch
 import torchvision
 
-from .
+from rslearn.train.model_context import RasterImage
+
+from .transform import Transform, read_selector
 
 
 class Crop(Transform):
@@ -69,7 +71,9 @@ class Crop(Transform):
             "remove_from_top": remove_from_top,
         }
 
-    def apply_image(
+    def apply_image(
+        self, image: RasterImage | torch.Tensor, state: dict[str, Any]
+    ) -> RasterImage | torch.Tensor:
         """Apply the sampled state on the specified image.
 
         Args:
@@ -80,13 +84,23 @@
         crop_size = state["crop_size"] * image.shape[-1] // image_shape[1]
         remove_from_left = state["remove_from_left"] * image.shape[-1] // image_shape[1]
         remove_from_top = state["remove_from_top"] * image.shape[-2] // image_shape[0]
-
-        image
-
-
-
-
-
+        if isinstance(image, RasterImage):
+            image.image = torchvision.transforms.functional.crop(
+                image.image,
+                top=remove_from_top,
+                left=remove_from_left,
+                height=crop_size,
+                width=crop_size,
+            )
+        else:
+            image = torchvision.transforms.functional.crop(
+                image,
+                top=remove_from_top,
+                left=remove_from_left,
+                height=crop_size,
+                width=crop_size,
+            )
+        return image
 
     def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
         """Apply the sampled state on the specified image.
@@ -97,7 +111,9 @@ class Crop(Transform):
         """
         raise NotImplementedError
 
-    def forward(
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Apply transform over the inputs and targets.
 
         Args:
@@ -109,13 +125,15 @@
         """
         smallest_image_shape = None
         for selector in self.image_selectors:
-            image =
+            image = read_selector(input_dict, target_dict, selector)
             if (
                 smallest_image_shape is None
                 or image.shape[-1] < smallest_image_shape[1]
             ):
                 smallest_image_shape = image.shape[-2:]
 
+        if smallest_image_shape is None:
+            raise ValueError("No image found to crop")
         state = self.sample_state(smallest_image_shape)
 
         self.apply_fn(
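For orientation, a hedged sketch of invoking the updated Crop end to end. Crop's constructor does not appear in this diff, so the crop_size and image_selectors arguments below are assumptions; only forward, apply_image, and the new read_selector/ValueError behavior are shown above.

import torch

from rslearn.train.transforms.crop import Crop

# Assumed constructor arguments (not shown in this diff).
transform = Crop(crop_size=32, image_selectors=["image"])
input_dict = {"image": torch.zeros(3, 64, 64)}
input_dict, _ = transform(input_dict, {})
# If no selector yields an image, forward now raises
# ValueError("No image found to crop") instead of failing later.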
rslearn/train/transforms/flip.py
CHANGED
@@ -1,7 +1,11 @@
 """Flip transform."""
 
+from typing import Any
+
 import torch
 
+from rslearn.train.model_context import RasterImage
+
 from .transform import Transform
 
 
@@ -46,17 +50,23 @@ class Flip(Transform):
             "vertical": vertical,
         }
 
-    def apply_image(self, image:
+    def apply_image(self, image: RasterImage, state: dict[str, bool]) -> RasterImage:
         """Apply the sampled state on the specified image.
 
         Args:
             image: the image to transform.
             state: the sampled state.
         """
-        if
-
-
-
+        if isinstance(image, RasterImage):
+            if state["horizontal"]:
+                image.image = torch.flip(image.image, dims=[-1])
+            if state["vertical"]:
+                image.image = torch.flip(image.image, dims=[-2])
+        elif isinstance(image, torch.Tensor):
+            if state["horizontal"]:
+                image = torch.flip(image, dims=[-1])
+            if state["vertical"]:
+                image = torch.flip(image, dims=[-2])
         return image
 
     def apply_boxes(
@@ -90,7 +100,9 @@ class Flip(Transform):
             )
         return boxes
 
-    def forward(
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Apply transform over the inputs and targets.
 
         Args:
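A small sketch exercising the apply_image path shown above on a plain tensor. Flip's constructor is not part of this diff, so constructing it with no arguments is an assumption; the state dict keys match the sampled state in the new code.

import torch

from rslearn.train.transforms.flip import Flip

transform = Flip()  # assumed default constructor (not shown in this diff)
image = torch.arange(16.0).reshape(1, 4, 4)
# apply_image consumes a sampled state directly, as in the diff above.
flipped = transform.apply_image(image, {"horizontal": True, "vertical": False})
assert torch.equal(flipped, torch.flip(image, dims=[-1]))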
rslearn/train/transforms/mask.py
ADDED
@@ -0,0 +1,78 @@
+"""Mask transform."""
+
+import torch
+
+from rslearn.train.model_context import RasterImage
+from rslearn.train.transforms.transform import Transform, read_selector
+
+
+class Mask(Transform):
+    """Apply a mask to one or more images.
+
+    This uses one (mask) image input to mask another (target) image input. The value of
+    the target image is set to the mask value everywhere where the mask image is 0.
+    """
+
+    def __init__(
+        self,
+        selectors: list[str] = ["image"],
+        mask_selector: str = "mask",
+        mask_value: int = 0,
+    ):
+        """Initialize a new Mask.
+
+        Args:
+            selectors: images to mask.
+            mask_selector: the selector for the mask image to apply.
+            mask_value: set each image in selectors to this value where the image
+                corresponding to the mask_selector is 0.
+        """
+        super().__init__()
+        self.selectors = selectors
+        self.mask_selector = mask_selector
+        self.mask_value = mask_value
+
+    def apply_image(
+        self, image: torch.Tensor | RasterImage, mask: torch.Tensor | RasterImage
+    ) -> torch.Tensor | RasterImage:
+        """Apply the mask on the image.
+
+        Args:
+            image: the image
+            mask: the mask
+
+        Returns:
+            masked image
+        """
+        # Tile the mask to have same number of bands as the image.
+        if isinstance(mask, RasterImage):
+            mask = mask.image
+
+        if image.shape[0] != mask.shape[0]:
+            if mask.shape[0] != 1:
+                raise ValueError(
+                    "expected mask to either have same bands as image, or one band"
+                )
+            mask = mask.repeat(image.shape[0], 1, 1)
+
+        if isinstance(image, torch.Tensor):
+            image[mask == 0] = self.mask_value
+        else:
+            image.image[mask == 0] = self.mask_value
+        return image
+
+    def forward(self, input_dict: dict, target_dict: dict) -> tuple[dict, dict]:
+        """Apply mask.
+
+        Args:
+            input_dict: the input
+            target_dict: the target
+
+        Returns:
+            normalized (input_dicts, target_dicts) tuple
+        """
+        mask = read_selector(input_dict, target_dict, self.mask_selector)
+        self.apply_fn(
+            self.apply_image, input_dict, target_dict, self.selectors, mask=mask
+        )
+        return input_dict, target_dict
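Since mask.py is entirely new in 0.0.21, a usage sketch may help; it assumes calling the transform invokes forward (Transform subclassing torch.nn.Module) and that the "image"/"mask" selectors resolve against input_dict.

import torch

from rslearn.train.transforms.mask import Mask

transform = Mask(selectors=["image"], mask_selector="mask", mask_value=-1)
input_dict = {
    "image": torch.ones(3, 4, 4),
    "mask": torch.zeros(1, 4, 4),  # single-band masks are tiled across all bands
}
input_dict["mask"][0, :2, :] = 1  # mark the top half as valid
input_dict, _ = transform(input_dict, {})
# Everywhere the mask is 0 (the bottom half), every band is set to -1.
assert (input_dict["image"][:, 2:, :] == -1).all()
assert (input_dict["image"][:, :2, :] == 1).all()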
rslearn/train/transforms/normalize.py
CHANGED
@@ -1,7 +1,11 @@
 """Normalization transforms."""
 
+from typing import Any
+
 import torch
 
+from rslearn.train.model_context import RasterImage
+
 from .transform import Transform
 
 
@@ -12,22 +16,31 @@ class Normalize(Transform):
         self,
         mean: float | list[float],
         std: float | list[float],
-        valid_range:
-
-
+        valid_range: (
+            tuple[float, float] | tuple[list[float], list[float]] | None
+        ) = None,
         selectors: list[str] = ["image"],
         bands: list[int] | None = None,
-
+        num_bands: int | None = None,
+    ) -> None:
         """Initialize a new Normalize.
 
         Result will be (input - mean) / std.
 
         Args:
             mean: a single value or one mean per channel
-            std: a single value or one std per channel
+            std: a single value or one std per channel (must match the shape of mean)
             valid_range: optionally clip to a minimum and maximum value
             selectors: image items to transform
-            bands: optionally restrict the normalization to these
+            bands: optionally restrict the normalization to these band indices. If set,
+                mean and std must either be one value, or have length equal to the
+                number of band indices passed here.
+            num_bands: the number of bands per image, to distinguish different images
+                in a time series. If set, then the bands list is repeated for each
+                image, e.g. if bands=[2] then we apply normalization on images[2],
+                images[2+num_bands], images[2+num_bands*2], etc. Or if the bands list
+                is not set, then we apply the mean and std on each image in the time
+                series.
         """
         super().__init__()
         self.mean = torch.tensor(mean)
@@ -41,27 +54,98 @@
         self.valid_max = None
 
         self.selectors = selectors
-        self.bands = bands
+        self.bands = torch.tensor(bands) if bands is not None else None
+        self.num_bands = num_bands
 
-    def apply_image(
+    def apply_image(
+        self, image: torch.Tensor | RasterImage
+    ) -> torch.Tensor | RasterImage:
         """Normalize the specified image.
 
         Args:
             image: the image to transform.
         """
-
-
-
-
-
+
+        def _repeat_mean_and_std(
+            image_channels: int, num_bands: int | None, is_raster_image: bool
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            """Get mean and std tensor that are suitable for applying on the image."""
+            # We only need to repeat the tensor if both of these are true:
+            # - The mean/std are not just one scalar.
+            # - self.num_bands is set, otherwise we treat the input as a single image.
+            if len(self.mean.shape) == 0:
+                return self.mean, self.std
+            if num_bands is None:
+                return self.mean, self.std
+            num_images = image_channels // num_bands
+            if is_raster_image:
+                # add an extra T dimension, CTHW
+                return self.mean.repeat(num_images)[
+                    :, None, None, None
+                ], self.std.repeat(num_images)[:, None, None, None]
+            else:
+                # add an extra T dimension, CTHW
+                return self.mean.repeat(num_images)[:, None, None], self.std.repeat(
+                    num_images
+                )[:, None, None]
+
+        if self.bands is not None:
+            # User has provided band indices to normalize.
+            # If num_bands is set, then we repeat these for each image in the input
+            # image time series.
+            band_indices = self.bands
+            if self.num_bands:
+                num_images = image.shape[0] // self.num_bands
+                band_indices = torch.cat(
+                    [
+                        band_indices + image_idx * self.num_bands
+                        for image_idx in range(num_images)
+                    ],
+                    dim=0,
                 )
+
+            # We use len(self.bands) here because that is how many bands per timestep
+            # we are actually processing with the mean/std.
+            mean, std = _repeat_mean_and_std(
+                image_channels=len(band_indices),
+                num_bands=len(self.bands),
+                is_raster_image=isinstance(image, RasterImage),
+            )
+            if isinstance(image, torch.Tensor):
+                image[band_indices] = (image[band_indices] - mean) / std
+                if self.valid_min is not None:
+                    image[band_indices] = torch.clamp(
+                        image[band_indices], min=self.valid_min, max=self.valid_max
+                    )
+            else:
+                image.image[band_indices] = (image.image[band_indices] - mean) / std
+                if self.valid_min is not None:
+                    image.image[band_indices] = torch.clamp(
+                        image.image[band_indices],
+                        min=self.valid_min,
+                        max=self.valid_max,
+                    )
         else:
-
-
-
+            mean, std = _repeat_mean_and_std(
+                image_channels=image.shape[0],
+                num_bands=self.num_bands,
+                is_raster_image=isinstance(image, RasterImage),
+            )
+            if isinstance(image, torch.Tensor):
+                image = (image - mean) / std
+                if self.valid_min is not None:
+                    image = torch.clamp(image, min=self.valid_min, max=self.valid_max)
+            else:
+                image.image = (image.image - mean) / std
+                if self.valid_min is not None:
+                    image.image = torch.clamp(
+                        image.image, min=self.valid_min, max=self.valid_max
+                    )
         return image
 
-    def forward(
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
         """Apply normalization over the inputs and targets.
 
         Args:
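The interaction between per-band statistics and the new num_bands argument is easiest to see on a stacked time series; a minimal sketch on plain tensors, assuming calling the transform invokes forward and that apply_fn writes the returned tensor back under the selector.

import torch

from rslearn.train.transforms.normalize import Normalize

# Two 3-band images stacked on the channel axis (6 channels total). With
# num_bands=3, the per-band mean/std repeat for each image in the series.
transform = Normalize(
    mean=[100.0, 200.0, 300.0],
    std=[10.0, 20.0, 30.0],
    num_bands=3,
)
input_dict = {"image": torch.full((6, 8, 8), 200.0)}
input_dict, _ = transform(input_dict, {})
# Channels 0 and 3 both used mean=100, std=10: (200 - 100) / 10 = 10.
assert input_dict["image"][0, 0, 0] == 10.0
assert input_dict["image"][3, 0, 0] == 10.0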
rslearn/train/transforms/pad.py
CHANGED
@@ -5,6 +5,8 @@ from typing import Any
 import torch
 import torchvision
 
+from rslearn.train.model_context import RasterImage
+
 from .transform import Transform
 
 
@@ -25,8 +27,8 @@ class Pad(Transform):
         Args:
            size: the size to pad to, or a min/max range of pad sizes. If the image is
                larger than this size, then it is cropped instead.
-           mode: "
-               "
+           mode: "topleft" (default) to only apply padding on the bottom and right
+               sides, or "center" to apply padding equally on all sides.
            image_selectors: image items to transform.
            box_selectors: boxes items to transform.
         """
@@ -48,7 +50,9 @@ class Pad(Transform):
         """
         return {"size": torch.randint(low=self.size[0], high=self.size[1], size=())}
 
-    def apply_image(
+    def apply_image(
+        self, image: RasterImage | torch.Tensor, state: dict[str, bool]
+    ) -> RasterImage | torch.Tensor:
         """Apply the sampled state on the specified image.
 
         Args:
@@ -64,11 +68,11 @@
         ) -> torch.Tensor:
             # Before/after must either be both non-negative or both negative.
             # >=0 indicates padding while <0 indicates cropping.
-            assert (before < 0 and after
+            assert (before < 0 and after <= 0) or (before >= 0 and after >= 0)
             if before > 0:
                 # Padding.
                 if horizontal:
-                    padding_tuple = (before, after)
+                    padding_tuple: tuple = (before, after)
                 else:
                     padding_tuple = (before, after, 0, 0)
             return torch.nn.functional.pad(im, padding_tuple)
@@ -101,8 +105,16 @@
         horizontal_pad = (horizontal_half, horizontal_extra - horizontal_half)
         vertical_pad = (vertical_half, vertical_extra - vertical_half)
 
-
-
+        if isinstance(image, RasterImage):
+            image.image = apply_padding(
+                image.image, True, horizontal_pad[0], horizontal_pad[1]
+            )
+            image.image = apply_padding(
+                image.image, False, vertical_pad[0], vertical_pad[1]
+            )
+        else:
+            image = apply_padding(image, True, horizontal_pad[0], horizontal_pad[1])
+            image = apply_padding(image, False, vertical_pad[0], vertical_pad[1])
         return image
 
     def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
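A sketch of the "topleft" mode from the corrected docstring. Pad's constructor is not shown in this diff, so passing a single int size (rather than a min/max range) is an assumption.

import torch

from rslearn.train.transforms.pad import Pad

transform = Pad(size=256, mode="topleft")  # assumed constructor arguments
input_dict = {"image": torch.zeros(3, 200, 200)}
input_dict, _ = transform(input_dict, {})
# "topleft" pads only the bottom and right, so original pixels keep their
# coordinates; "center" would distribute the padding on all sides.
assert input_dict["image"].shape == (3, 256, 256)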
rslearn/train/transforms/resize.py
ADDED
@@ -0,0 +1,83 @@
+"""Resize transform."""
+
+from typing import Any
+
+import torch
+import torchvision
+from torchvision.transforms import InterpolationMode
+
+from rslearn.train.model_context import RasterImage
+
+from .transform import Transform
+
+INTERPOLATION_MODES = {
+    "nearest": InterpolationMode.NEAREST,
+    "nearest_exact": InterpolationMode.NEAREST_EXACT,
+    "bilinear": InterpolationMode.BILINEAR,
+    "bicubic": InterpolationMode.BICUBIC,
+}
+
+
+class Resize(Transform):
+    """Resizes inputs to a target size."""
+
+    def __init__(
+        self,
+        target_size: tuple[int, int],
+        selectors: list[str] = [],
+        interpolation: str = "nearest",
+    ):
+        """Initialize a resize transform.
+
+        Args:
+            target_size: the (height, width) to resize to.
+            selectors: items to transform.
+            interpolation: the interpolation mode to use for resizing.
+                Must be one of "nearest", "nearest_exact", "bilinear", or "bicubic".
+        """
+        super().__init__()
+        self.target_size = target_size
+        self.selectors = selectors
+        self.interpolation = INTERPOLATION_MODES[interpolation]
+
+    def apply_resize(
+        self, image: torch.Tensor | RasterImage
+    ) -> torch.Tensor | RasterImage:
+        """Apply resizing on the specified image.
+
+        If the image is 2D, it is unsqueezed to 3D and then squeezed
+        back after resizing.
+
+        Args:
+            image: the image to transform.
+        """
+        if isinstance(image, torch.Tensor):
+            if image.dim() == 2:
+                image = image.unsqueeze(0)  # (H, W) -> (1, H, W)
+                result = torchvision.transforms.functional.resize(
+                    image, self.target_size, self.interpolation
+                )
+                return result.squeeze(0)  # (1, H, W) -> (H, W)
+            return torchvision.transforms.functional.resize(
+                image, self.target_size, self.interpolation
+            )
+        else:
+            image.image = torchvision.transforms.functional.resize(
+                image.image, self.target_size, self.interpolation
+            )
+            return image
+
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Apply transform over the inputs and targets.
+
+        Args:
+            input_dict: the input
+            target_dict: the target
+
+        Returns:
+            transformed (input_dicts, target_dicts) tuple
+        """
+        self.apply_fn(self.apply_resize, input_dict, target_dict, self.selectors)
+        return input_dict, target_dict
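resize.py is also new in this release; a usage sketch, assuming calling the transform invokes forward and that apply_fn writes the resized tensor back under the selector.

import torch

from rslearn.train.transforms.resize import Resize

transform = Resize(
    target_size=(128, 128), selectors=["image"], interpolation="bilinear"
)
input_dict = {"image": torch.zeros(3, 64, 64)}
input_dict, _ = transform(input_dict, {})
assert input_dict["image"].shape == (3, 128, 128)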