rslearn 0.0.19__py3-none-any.whl → 0.0.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/models/anysat.py +35 -33
- rslearn/models/clip.py +5 -2
- rslearn/models/croma.py +11 -3
- rslearn/models/dinov3.py +2 -1
- rslearn/models/faster_rcnn.py +2 -1
- rslearn/models/galileo/galileo.py +58 -31
- rslearn/models/module_wrapper.py +6 -1
- rslearn/models/molmo.py +4 -2
- rslearn/models/olmoearth_pretrain/model.py +93 -29
- rslearn/models/olmoearth_pretrain/norm.py +5 -3
- rslearn/models/panopticon.py +3 -1
- rslearn/models/presto/presto.py +45 -15
- rslearn/models/prithvi.py +9 -7
- rslearn/models/sam2_enc.py +3 -1
- rslearn/models/satlaspretrain.py +4 -1
- rslearn/models/simple_time_series.py +36 -16
- rslearn/models/ssl4eo_s12.py +19 -14
- rslearn/models/swin.py +3 -1
- rslearn/models/terramind.py +5 -4
- rslearn/train/all_patches_dataset.py +34 -14
- rslearn/train/dataset.py +66 -10
- rslearn/train/model_context.py +35 -1
- rslearn/train/tasks/classification.py +8 -2
- rslearn/train/tasks/detection.py +3 -2
- rslearn/train/tasks/multi_task.py +2 -3
- rslearn/train/tasks/per_pixel_regression.py +14 -5
- rslearn/train/tasks/regression.py +8 -2
- rslearn/train/tasks/segmentation.py +13 -4
- rslearn/train/tasks/task.py +2 -2
- rslearn/train/transforms/concatenate.py +45 -5
- rslearn/train/transforms/crop.py +22 -8
- rslearn/train/transforms/flip.py +13 -5
- rslearn/train/transforms/mask.py +11 -2
- rslearn/train/transforms/normalize.py +46 -15
- rslearn/train/transforms/pad.py +15 -3
- rslearn/train/transforms/resize.py +18 -9
- rslearn/train/transforms/select_bands.py +11 -2
- rslearn/train/transforms/sentinel1.py +18 -3
- {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/METADATA +1 -1
- {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/RECORD +45 -45
- {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/WHEEL +0 -0
- {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/top_level.txt +0 -0
rslearn/train/tasks/classification.py
CHANGED

@@ -16,7 +16,12 @@ from torchmetrics.classification import (
 )

 from rslearn.models.component import FeatureVector, Predictor
-from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
+from rslearn.train.model_context import (
+    ModelContext,
+    ModelOutput,
+    RasterImage,
+    SampleMetadata,
+)
 from rslearn.utils import Feature, STGeometry

 from .task import BasicTask

@@ -99,7 +104,7 @@ class ClassificationTask(BasicTask):

     def process_inputs(
         self,
-        raw_inputs: dict[str, torch.Tensor | list[Feature]],
+        raw_inputs: dict[str, RasterImage | list[Feature]],
         metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:

@@ -118,6 +123,7 @@ class ClassificationTask(BasicTask):
             return {}, {}

         data = raw_inputs["targets"]
+        assert isinstance(data, list)
         for feat in data:
             if feat.properties is None:
                 continue
rslearn/train/tasks/detection.py
CHANGED

@@ -12,7 +12,7 @@ import torchmetrics.classification
 import torchvision
 from torchmetrics import Metric, MetricCollection

-from rslearn.train.model_context import SampleMetadata
+from rslearn.train.model_context import RasterImage, SampleMetadata
 from rslearn.utils import Feature, STGeometry

 from .task import BasicTask

@@ -127,7 +127,7 @@ class DetectionTask(BasicTask):

     def process_inputs(
         self,
-        raw_inputs: dict[str, torch.Tensor | list[Feature]],
+        raw_inputs: dict[str, RasterImage | list[Feature]],
         metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:

@@ -152,6 +152,7 @@ class DetectionTask(BasicTask):
             valid = 1

         data = raw_inputs["targets"]
+        assert isinstance(data, list)
         for feat in data:
             if feat.properties is None:
                 continue
rslearn/train/tasks/multi_task.py
CHANGED

@@ -3,10 +3,9 @@
 from typing import Any

 import numpy.typing as npt
-import torch
 from torchmetrics import Metric, MetricCollection

-from rslearn.train.model_context import SampleMetadata
+from rslearn.train.model_context import RasterImage, SampleMetadata
 from rslearn.utils import Feature

 from .task import Task

@@ -30,7 +29,7 @@ class MultiTask(Task):

     def process_inputs(
         self,
-        raw_inputs: dict[str, torch.Tensor | list[Feature]],
+        raw_inputs: dict[str, RasterImage | list[Feature]],
         metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
rslearn/train/tasks/per_pixel_regression.py
CHANGED

@@ -9,7 +9,12 @@ import torchmetrics
 from torchmetrics import Metric, MetricCollection

 from rslearn.models.component import FeatureMaps, Predictor
-from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
+from rslearn.train.model_context import (
+    ModelContext,
+    ModelOutput,
+    RasterImage,
+    SampleMetadata,
+)
 from rslearn.utils.feature import Feature

 from .task import BasicTask

@@ -42,7 +47,7 @@ class PerPixelRegressionTask(BasicTask):

     def process_inputs(
         self,
-        raw_inputs: dict[str, torch.Tensor | list[Feature]],
+        raw_inputs: dict[str, RasterImage | list[Feature]],
         metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:

@@ -60,11 +65,15 @@ class PerPixelRegressionTask(BasicTask):
         if not load_targets:
             return {}, {}

-        assert raw_inputs["targets"]
-        labels = raw_inputs["targets"][0, :, :].float() * self.scale_factor
+        assert isinstance(raw_inputs["targets"], RasterImage)
+        assert raw_inputs["targets"].image.shape[0] == 1
+        assert raw_inputs["targets"].image.shape[1] == 1
+        labels = raw_inputs["targets"].image[0, 0, :, :].float() * self.scale_factor

         if self.nodata_value is not None:
-            valid = (raw_inputs["targets"][0, :, :] != self.nodata_value).float()
+            valid = (
+                raw_inputs["targets"].image[0, 0, :, :] != self.nodata_value
+            ).float()
         else:
             valid = torch.ones(labels.shape, dtype=torch.float32)
rslearn/train/tasks/regression.py
CHANGED

@@ -11,7 +11,12 @@ from PIL import Image, ImageDraw
 from torchmetrics import Metric, MetricCollection

 from rslearn.models.component import FeatureVector, Predictor
-from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
+from rslearn.train.model_context import (
+    ModelContext,
+    ModelOutput,
+    RasterImage,
+    SampleMetadata,
+)
 from rslearn.utils.feature import Feature
 from rslearn.utils.geometry import STGeometry

@@ -63,7 +68,7 @@ class RegressionTask(BasicTask):

     def process_inputs(
         self,
-        raw_inputs: dict[str, torch.Tensor | list[Feature]],
+        raw_inputs: dict[str, RasterImage | list[Feature]],
         metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:

@@ -82,6 +87,7 @@ class RegressionTask(BasicTask):
             return {}, {}

         data = raw_inputs["targets"]
+        assert isinstance(data, list)
         for feat in data:
             if feat.properties is None or self.filters is None:
                 continue
rslearn/train/tasks/segmentation.py
CHANGED

@@ -1,5 +1,6 @@
 """Segmentation task."""

+from collections.abc import Mapping
 from typing import Any

 import numpy as np

@@ -9,7 +10,13 @@ import torchmetrics.classification
 from torchmetrics import Metric, MetricCollection

 from rslearn.models.component import FeatureMaps, Predictor
-from rslearn.train.model_context import ModelContext, ModelOutput, SampleMetadata
+from rslearn.train.model_context import (
+    ModelContext,
+    ModelOutput,
+    RasterImage,
+    SampleMetadata,
+)
+from rslearn.utils import Feature

 from .task import BasicTask

@@ -108,7 +115,7 @@ class SegmentationTask(BasicTask):

     def process_inputs(
         self,
-        raw_inputs: dict[str, torch.Tensor | list[Feature]],
+        raw_inputs: Mapping[str, RasterImage | list[Feature]],
         metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:

@@ -126,8 +133,10 @@ class SegmentationTask(BasicTask):
         if not load_targets:
             return {}, {}

-        assert raw_inputs["targets"]
-        labels = raw_inputs["targets"][0, :, :].long()
+        assert isinstance(raw_inputs["targets"], RasterImage)
+        assert raw_inputs["targets"].image.shape[0] == 1
+        assert raw_inputs["targets"].image.shape[1] == 1
+        labels = raw_inputs["targets"].image[0, 0, :, :].long()

         if self.class_id_mapping is not None:
             new_labels = labels.clone()
rslearn/train/tasks/task.py
CHANGED

@@ -7,7 +7,7 @@ import numpy.typing as npt
 import torch
 from torchmetrics import MetricCollection

-from rslearn.train.model_context import SampleMetadata
+from rslearn.train.model_context import RasterImage, SampleMetadata
 from rslearn.utils import Feature


@@ -21,7 +21,7 @@ class Task:

     def process_inputs(
         self,
-        raw_inputs: dict[str, torch.Tensor | list[Feature]],
+        raw_inputs: dict[str, RasterImage | list[Feature]],
         metadata: SampleMetadata,
         load_targets: bool = True,
     ) -> tuple[dict[str, Any], dict[str, Any]]:
rslearn/train/transforms/concatenate.py
CHANGED

@@ -1,12 +1,23 @@
 """Concatenate bands across multiple image inputs."""

+from datetime import datetime
+from enum import Enum
 from typing import Any

 import torch

+from rslearn.train.model_context import RasterImage
+
 from .transform import Transform, read_selector, write_selector


+class ConcatenateDim(Enum):
+    """Enum for concatenation dimensions."""
+
+    CHANNEL = 0
+    TIME = 1
+
+
 class Concatenate(Transform):
     """Concatenate bands across multiple image inputs."""

@@ -14,6 +25,7 @@ class Concatenate(Transform):
         self,
         selections: dict[str, list[int]],
         output_selector: str,
+        concatenate_dim: ConcatenateDim | int = ConcatenateDim.TIME,
     ):
         """Initialize a new Concatenate.

@@ -21,10 +33,16 @@ class Concatenate(Transform):
             selections: map from selector to list of band indices in that input to
                 retain, or empty list to use all bands.
             output_selector: the output selector under which to save the concatenate image.
+            concatenate_dim: the dimension against which to concatenate the inputs
         """
         super().__init__()
         self.selections = selections
         self.output_selector = output_selector
+        self.concatenate_dim = (
+            concatenate_dim.value
+            if isinstance(concatenate_dim, ConcatenateDim)
+            else concatenate_dim
+        )

     def forward(
         self, input_dict: dict[str, Any], target_dict: dict[str, Any]

@@ -36,14 +54,36 @@ class Concatenate(Transform):
             target_dict: the target

         Returns:
-            (input_dicts, target_dicts) tuple
+            concatenated (input_dicts, target_dicts) tuple. If one of the
+            specified inputs is a RasterImage, a RasterImage will be returned.
+            Otherwise it will be a torch.Tensor.
         """
         images = []
+        return_raster_image: bool = False
+        timestamps: list[tuple[datetime, datetime]] | None = None
         for selector, wanted_bands in self.selections.items():
             image = read_selector(input_dict, target_dict, selector)
-            if wanted_bands:
-                image = image[wanted_bands, :, :]
-            images.append(image)
-        result = torch.concatenate(images, dim=0)
+            if isinstance(image, torch.Tensor):
+                if wanted_bands:
+                    image = image[wanted_bands, :, :]
+                images.append(image)
+            elif isinstance(image, RasterImage):
+                return_raster_image = True
+                if wanted_bands:
+                    images.append(image.image[wanted_bands, :, :])
+                else:
+                    images.append(image.image)
+                if timestamps is None:
+                    if image.timestamps is not None:
+                        # assume all concatenated modalities have the same
+                        # number of timestamps
+                        timestamps = image.timestamps
+        if return_raster_image:
+            result = RasterImage(
+                torch.concatenate(images, dim=self.concatenate_dim),
+                timestamps=timestamps,
+            )
+        else:
+            result = torch.concatenate(images, dim=self.concatenate_dim)
         write_selector(input_dict, target_dict, self.output_selector, result)
         return input_dict, target_dict
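As a usage illustration, here is a hedged sketch of the new concatenate_dim option stacking two single-timestep RasterImage inputs along the time axis; the selector names and shapes are made up:

    import torch

    from rslearn.train.model_context import RasterImage
    from rslearn.train.transforms.concatenate import Concatenate, ConcatenateDim

    # Two hypothetical 3-band, 1-timestep, 32x32 inputs (CTHW layout).
    input_dict = {
        "sentinel2_a": RasterImage(torch.zeros(3, 1, 32, 32)),
        "sentinel2_b": RasterImage(torch.zeros(3, 1, 32, 32)),
    }
    transform = Concatenate(
        selections={"sentinel2_a": [], "sentinel2_b": []},  # empty = all bands
        output_selector="sentinel2",
        concatenate_dim=ConcatenateDim.TIME,  # dim 1 of the CTHW tensor
    )
    input_dict, _ = transform.forward(input_dict, {})
    # The stacked output is a RasterImage with two timesteps.
    assert input_dict["sentinel2"].image.shape == (3, 2, 32, 32)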
rslearn/train/transforms/crop.py
CHANGED

@@ -5,6 +5,8 @@ from typing import Any
 import torch
 import torchvision

+from rslearn.train.model_context import RasterImage
+
 from .transform import Transform, read_selector


@@ -69,7 +71,9 @@ class Crop(Transform):
             "remove_from_top": remove_from_top,
         }

-    def apply_image(self, image: torch.Tensor, state: dict[str, Any]) -> torch.Tensor:
+    def apply_image(
+        self, image: RasterImage | torch.Tensor, state: dict[str, Any]
+    ) -> RasterImage | torch.Tensor:
         """Apply the sampled state on the specified image.

         Args:

@@ -80,13 +84,23 @@ class Crop(Transform):
         crop_size = state["crop_size"] * image.shape[-1] // image_shape[1]
         remove_from_left = state["remove_from_left"] * image.shape[-1] // image_shape[1]
         remove_from_top = state["remove_from_top"] * image.shape[-2] // image_shape[0]
-        return torchvision.transforms.functional.crop(
-            image,
-            top=remove_from_top,
-            left=remove_from_left,
-            height=crop_size,
-            width=crop_size,
-        )
+        if isinstance(image, RasterImage):
+            image.image = torchvision.transforms.functional.crop(
+                image.image,
+                top=remove_from_top,
+                left=remove_from_left,
+                height=crop_size,
+                width=crop_size,
+            )
+        else:
+            image = torchvision.transforms.functional.crop(
+                image,
+                top=remove_from_top,
+                left=remove_from_left,
+                height=crop_size,
+                width=crop_size,
+            )
+        return image

     def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
         """Apply the sampled state on the specified image.
rslearn/train/transforms/flip.py
CHANGED

@@ -4,6 +4,8 @@ from typing import Any

 import torch

+from rslearn.train.model_context import RasterImage
+
 from .transform import Transform


@@ -48,17 +50,23 @@ class Flip(Transform):
             "vertical": vertical,
         }

-    def apply_image(self, image: torch.Tensor, state: dict[str, bool]) -> torch.Tensor:
+    def apply_image(self, image: RasterImage, state: dict[str, bool]) -> RasterImage:
         """Apply the sampled state on the specified image.

         Args:
             image: the image to transform.
             state: the sampled state.
         """
-        if state["horizontal"]:
-            image = torch.flip(image, dims=[-1])
-        if state["vertical"]:
-            image = torch.flip(image, dims=[-2])
+        if isinstance(image, RasterImage):
+            if state["horizontal"]:
+                image.image = torch.flip(image.image, dims=[-1])
+            if state["vertical"]:
+                image.image = torch.flip(image.image, dims=[-2])
+        elif isinstance(image, torch.Tensor):
+            if state["horizontal"]:
+                image = torch.flip(image, dims=[-1])
+            if state["vertical"]:
+                image = torch.flip(image, dims=[-2])
         return image

     def apply_boxes(
rslearn/train/transforms/mask.py
CHANGED

@@ -2,6 +2,7 @@

 import torch

+from rslearn.train.model_context import RasterImage
 from rslearn.train.transforms.transform import Transform, read_selector


@@ -31,7 +32,9 @@ class Mask(Transform):
         self.mask_selector = mask_selector
         self.mask_value = mask_value

-    def apply_image(self, image: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    def apply_image(
+        self, image: torch.Tensor | RasterImage, mask: torch.Tensor | RasterImage
+    ) -> torch.Tensor | RasterImage:
         """Apply the mask on the image.

         Args:

@@ -42,6 +45,9 @@ class Mask(Transform):
             masked image
         """
         # Tile the mask to have same number of bands as the image.
+        if isinstance(mask, RasterImage):
+            mask = mask.image
+
         if image.shape[0] != mask.shape[0]:
             if mask.shape[0] != 1:
                 raise ValueError(

@@ -49,7 +55,10 @@ class Mask(Transform):
                 )
             mask = mask.repeat(image.shape[0], 1, 1)

-        image[mask == 0] = self.mask_value
+        if isinstance(image, torch.Tensor):
+            image[mask == 0] = self.mask_value
+        else:
+            image.image[mask == 0] = self.mask_value
         return image

     def forward(self, input_dict: dict, target_dict: dict) -> tuple[dict, dict]:
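Before masking, a RasterImage mask is unwrapped to its tensor, and a single-band mask is tiled across the image's bands. A small self-contained sketch of that tiling step, with illustrative shapes and mask_value:

    import torch

    # One-band mask tiled to match a 3-band image before masking.
    image = torch.rand(3, 16, 16)
    mask = torch.randint(0, 2, (1, 16, 16))
    mask = mask.repeat(image.shape[0], 1, 1)  # (1, 16, 16) -> (3, 16, 16)
    image[mask == 0] = 0.0  # mask_value = 0.0 here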
rslearn/train/transforms/normalize.py
CHANGED

@@ -4,6 +4,8 @@ from typing import Any

 import torch

+from rslearn.train.model_context import RasterImage
+
 from .transform import Transform


@@ -55,7 +57,9 @@ class Normalize(Transform):
         self.bands = torch.tensor(bands) if bands is not None else None
         self.num_bands = num_bands

-    def apply_image(self, image: torch.Tensor) -> torch.Tensor:
+    def apply_image(
+        self, image: torch.Tensor | RasterImage
+    ) -> torch.Tensor | RasterImage:
         """Normalize the specified image.

         Args:

@@ -63,7 +67,7 @@ class Normalize(Transform):
         """

         def _repeat_mean_and_std(
-            image_channels: int, num_bands: int | None
+            image_channels: int, num_bands: int | None, is_raster_image: bool
         ) -> tuple[torch.Tensor, torch.Tensor]:
             """Get mean and std tensor that are suitable for applying on the image."""
             # We only need to repeat the tensor if both of these are true:

@@ -74,9 +78,16 @@
             if num_bands is None:
                 return self.mean, self.std
             num_images = image_channels // num_bands
-            return self.mean.repeat(num_images)[:, None, None], self.std.repeat(
-                num_images
-            )[:, None, None]
+            if is_raster_image:
+                # add an extra T dimension, CTHW
+                return self.mean.repeat(num_images)[
+                    :, None, None, None
+                ], self.std.repeat(num_images)[:, None, None, None]
+            else:
+                # add an extra T dimension, CTHW
+                return self.mean.repeat(num_images)[:, None, None], self.std.repeat(
+                    num_images
+                )[:, None, None]

         if self.bands is not None:
             # User has provided band indices to normalize.

@@ -96,20 +107,40 @@
             # We use len(self.bands) here because that is how many bands per timestep
             # we are actually processing with the mean/std.
             mean, std = _repeat_mean_and_std(
-                image_channels=len(band_indices), num_bands=len(self.bands)
+                image_channels=len(band_indices),
+                num_bands=len(self.bands),
+                is_raster_image=isinstance(image, RasterImage),
             )
-            image[band_indices] = (image[band_indices] - mean) / std
-            if self.valid_min is not None:
-                image[band_indices] = torch.clamp(
-                    image[band_indices], min=self.valid_min, max=self.valid_max
-                )
+            if isinstance(image, torch.Tensor):
+                image[band_indices] = (image[band_indices] - mean) / std
+                if self.valid_min is not None:
+                    image[band_indices] = torch.clamp(
+                        image[band_indices], min=self.valid_min, max=self.valid_max
+                    )
+            else:
+                image.image[band_indices] = (image.image[band_indices] - mean) / std
+                if self.valid_min is not None:
+                    image.image[band_indices] = torch.clamp(
+                        image.image[band_indices],
+                        min=self.valid_min,
+                        max=self.valid_max,
+                    )
         else:
             mean, std = _repeat_mean_and_std(
-                image_channels=image.shape[0], num_bands=self.num_bands
+                image_channels=image.shape[0],
+                num_bands=self.num_bands,
+                is_raster_image=isinstance(image, RasterImage),
             )
-            image = (image - mean) / std
-            if self.valid_min is not None:
-                image = torch.clamp(image, min=self.valid_min, max=self.valid_max)
+            if isinstance(image, torch.Tensor):
+                image = (image - mean) / std
+                if self.valid_min is not None:
+                    image = torch.clamp(image, min=self.valid_min, max=self.valid_max)
+            else:
+                image.image = (image.image - mean) / std
+                if self.valid_min is not None:
+                    image.image = torch.clamp(
+                        image.image, min=self.valid_min, max=self.valid_max
+                    )
         return image

     def forward(
rslearn/train/transforms/pad.py
CHANGED

@@ -5,6 +5,8 @@ from typing import Any
 import torch
 import torchvision

+from rslearn.train.model_context import RasterImage
+
 from .transform import Transform


@@ -48,7 +50,9 @@ class Pad(Transform):
         """
         return {"size": torch.randint(low=self.size[0], high=self.size[1], size=())}

-    def apply_image(self, image: torch.Tensor, state: dict[str, bool]) -> torch.Tensor:
+    def apply_image(
+        self, image: RasterImage | torch.Tensor, state: dict[str, bool]
+    ) -> RasterImage | torch.Tensor:
         """Apply the sampled state on the specified image.

         Args:

@@ -101,8 +105,16 @@
         horizontal_pad = (horizontal_half, horizontal_extra - horizontal_half)
         vertical_pad = (vertical_half, vertical_extra - vertical_half)

-        image = apply_padding(image, True, horizontal_pad[0], horizontal_pad[1])
-        image = apply_padding(image, False, vertical_pad[0], vertical_pad[1])
+        if isinstance(image, RasterImage):
+            image.image = apply_padding(
+                image.image, True, horizontal_pad[0], horizontal_pad[1]
+            )
+            image.image = apply_padding(
+                image.image, False, vertical_pad[0], vertical_pad[1]
+            )
+        else:
+            image = apply_padding(image, True, horizontal_pad[0], horizontal_pad[1])
+            image = apply_padding(image, False, vertical_pad[0], vertical_pad[1])
         return image

     def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
rslearn/train/transforms/resize.py
CHANGED

@@ -6,6 +6,8 @@ import torch
 import torchvision
 from torchvision.transforms import InterpolationMode

+from rslearn.train.model_context import RasterImage
+
 from .transform import Transform

 INTERPOLATION_MODES = {

@@ -38,7 +40,9 @@ class Resize(Transform):
         self.selectors = selectors
         self.interpolation = INTERPOLATION_MODES[interpolation]

-    def apply_resize(self, image: torch.Tensor) -> torch.Tensor:
+    def apply_resize(
+        self, image: torch.Tensor | RasterImage
+    ) -> torch.Tensor | RasterImage:
         """Apply resizing on the specified image.

         If the image is 2D, it is unsqueezed to 3D and then squeezed

@@ -47,16 +51,21 @@ class Resize(Transform):
         Args:
             image: the image to transform.
         """
-        if image.dim() == 2:
-            image = image.unsqueeze(0)  # (H, W) -> (1, H, W)
-            result = torchvision.transforms.functional.resize(
+        if isinstance(image, torch.Tensor):
+            if image.dim() == 2:
+                image = image.unsqueeze(0)  # (H, W) -> (1, H, W)
+                result = torchvision.transforms.functional.resize(
+                    image, self.target_size, self.interpolation
+                )
+                return result.squeeze(0)  # (1, H, W) -> (H, W)
+            return torchvision.transforms.functional.resize(
                 image, self.target_size, self.interpolation
             )
-            return result.squeeze(0)  # (1, H, W) -> (H, W)
-        else:
-            return torchvision.transforms.functional.resize(
-                image, self.target_size, self.interpolation
-            )
+        else:
+            image.image = torchvision.transforms.functional.resize(
+                image.image, self.target_size, self.interpolation
+            )
+            return image

     def forward(
         self, input_dict: dict[str, Any], target_dict: dict[str, Any]
rslearn/train/transforms/select_bands.py
CHANGED

@@ -2,6 +2,8 @@

 from typing import Any

+from rslearn.train.model_context import RasterImage
+
 from .transform import Transform, read_selector, write_selector


@@ -49,6 +51,10 @@ class SelectBands(Transform):
             if self.num_bands_per_timestep is not None
             else image.shape[0]
         )
+        if isinstance(image, RasterImage):
+            assert num_bands_per_timestep == image.shape[0], (
+                "Expect a seperate dimension for timesteps in RasterImages."
+            )

         if image.shape[0] % num_bands_per_timestep != 0:
             raise ValueError(

@@ -62,6 +68,9 @@ class SelectBands(Transform):
             [(start_channel_idx + band_idx) for band_idx in self.band_indices]
         )

-        image = image[wanted_bands]
-        write_selector(input_dict, target_dict, self.output_selector, image)
+        if isinstance(image, RasterImage):
+            image.image = image.image[wanted_bands]
+        else:
+            image = image[wanted_bands]
+        write_selector(input_dict, target_dict, self.output_selector, image)
         return input_dict, target_dict
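With a RasterImage, the new assert pins dim 0 to exactly the per-timestep bands (time lives in its own dim 1), so selecting bands reduces to first-axis indexing of the wrapped tensor. A sketch with made-up band indices:

    import torch

    from rslearn.train.model_context import RasterImage

    # 4 bands x 2 timesteps x 8x8 pixels; keep bands 0 and 2 at all timesteps.
    image = RasterImage(torch.zeros(4, 2, 8, 8))
    wanted_bands = torch.tensor([0, 2])
    image.image = image.image[wanted_bands]
    assert image.image.shape == (2, 2, 8, 8)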