rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +2 -9
- rslearn/config/__init__.py +2 -0
- rslearn/config/dataset.py +64 -20
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +34 -84
- rslearn/dataset/materialize.py +5 -5
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +26 -80
- rslearn/lightning_cli.py +22 -11
- rslearn/main.py +12 -37
- rslearn/models/anysat.py +11 -9
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +8 -9
- rslearn/models/clip.py +18 -15
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +21 -11
- rslearn/models/conv.py +15 -8
- rslearn/models/croma.py +13 -8
- rslearn/models/detr/detr.py +25 -14
- rslearn/models/dinov3.py +11 -6
- rslearn/models/faster_rcnn.py +19 -9
- rslearn/models/feature_center_crop.py +12 -9
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/galileo.py +23 -18
- rslearn/models/module_wrapper.py +26 -57
- rslearn/models/molmo.py +16 -14
- rslearn/models/multitask.py +102 -73
- rslearn/models/olmoearth_pretrain/model.py +135 -38
- rslearn/models/panopticon.py +8 -7
- rslearn/models/pick_features.py +18 -24
- rslearn/models/pooling_decoder.py +22 -14
- rslearn/models/presto/presto.py +16 -10
- rslearn/models/presto/single_file_presto.py +4 -10
- rslearn/models/prithvi.py +12 -8
- rslearn/models/resize_features.py +21 -7
- rslearn/models/sam2_enc.py +11 -9
- rslearn/models/satlaspretrain.py +15 -9
- rslearn/models/simple_time_series.py +37 -17
- rslearn/models/singletask.py +24 -17
- rslearn/models/ssl4eo_s12.py +15 -10
- rslearn/models/swin.py +22 -13
- rslearn/models/terramind.py +24 -7
- rslearn/models/trunk.py +6 -3
- rslearn/models/unet.py +18 -9
- rslearn/models/upsample.py +22 -9
- rslearn/train/all_patches_dataset.py +89 -37
- rslearn/train/dataset.py +105 -97
- rslearn/train/lightning_module.py +51 -32
- rslearn/train/model_context.py +54 -0
- rslearn/train/prediction_writer.py +111 -41
- rslearn/train/scheduler.py +15 -0
- rslearn/train/tasks/classification.py +34 -15
- rslearn/train/tasks/detection.py +24 -31
- rslearn/train/tasks/embedding.py +33 -29
- rslearn/train/tasks/multi_task.py +7 -7
- rslearn/train/tasks/per_pixel_regression.py +41 -19
- rslearn/train/tasks/regression.py +38 -21
- rslearn/train/tasks/segmentation.py +33 -15
- rslearn/train/tasks/task.py +3 -2
- rslearn/train/transforms/resize.py +74 -0
- rslearn/utils/geometry.py +73 -0
- rslearn/utils/jsonargparse.py +66 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
- rslearn/dataset/index.py +0 -173
- rslearn/models/registry.py +0 -22
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
rslearn/models/ssl4eo_s12.py
CHANGED

@@ -1,12 +1,14 @@
 """SSL4EO-S12 models."""
 
-from typing import Any
-
 import torch
 import torchvision
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
+
 
-class Ssl4eoS12(torch.nn.Module):
+class Ssl4eoS12(FeatureExtractor):
     """The SSL4EO-S12 family of pretrained models."""
 
     def __init__(

@@ -74,19 +76,22 @@ class Ssl4eoS12(torch.nn.Module):
 
     def forward(
         self,
-
-    ) ->
+        context: ModelContext,
+    ) -> FeatureMaps:
         """Compute outputs from the backbone.
 
         If output_layers is set, then the outputs are multi-scale feature maps;
         otherwise, the model is being used for classification so the outputs are class
         probabilities and the loss.
 
-
-
-            process.
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the images to process.
+
+        Returns:
+            feature maps computed by the pre-trained model.
         """
-        x = torch.stack([inp["image"] for inp in inputs], dim=0)
+        x = torch.stack([inp["image"] for inp in context.inputs], dim=0)
         x = self.model.conv1(x)
         x = self.model.bn1(x)
         x = self.model.relu(x)

@@ -97,4 +102,4 @@ class Ssl4eoS12(torch.nn.Module):
         layer3 = self.model.layer3(layer2)
         layer4 = self.model.layer4(layer3)
         all_features = [layer1, layer2, layer3, layer4]
-        return [all_features[idx] for idx in self.output_layers]
+        return FeatureMaps([all_features[idx] for idx in self.output_layers])
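A minimal sketch of the component interface used above (a hypothetical example, not code from the package): it assumes, as the updated Ssl4eoS12.forward suggests, that ModelContext exposes an inputs list of per-sample dicts and that FeatureMaps wraps a list of (B, C, H, W) tensors.

    import torch

    from rslearn.models.component import FeatureExtractor, FeatureMaps
    from rslearn.train.model_context import ModelContext


    class IdentityExtractor(FeatureExtractor):
        """Toy extractor that returns the stacked input images as one feature map."""

        def forward(self, context: ModelContext) -> FeatureMaps:
            # Stack per-sample images into a batch, mirroring Ssl4eoS12.forward above.
            x = torch.stack([inp["image"] for inp in context.inputs], dim=0)
            return FeatureMaps([x])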
rslearn/models/swin.py
CHANGED

@@ -1,7 +1,5 @@
 """Swin Transformer."""
 
-from typing import Any
-
 import torch
 import torchvision
 from torchvision.models.swin_transformer import (

@@ -13,8 +11,12 @@ from torchvision.models.swin_transformer import (
     Swin_V2_T_Weights,
 )
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps, FeatureVector
+
 
-class Swin(torch.nn.Module):
+class Swin(FeatureExtractor):
     """A Swin Transformer model.
 
     It can either be used stand-alone for classification, or as a feature extractor in

@@ -34,9 +36,12 @@ class Swin(torch.nn.Module):
         Args:
             arch: the architecture, e.g. "swin_v2_b" (default) or "swin_t"
             pretrained: set True to use ImageNet pre-trained weights
-            input_channels: number of input channels (default 3)
+            input_channels: number of input channels (default 3). If not 3, the first
+                layer is updated and will be randomly initialized even if pretrained is
+                set.
             output_layers: list of layers to output, default use as classification
-                model. For feature extraction, [1, 3, 5, 7] is
+                model (output FeatureVector). For feature extraction, [1, 3, 5, 7] is
+                recommended.
             num_outputs: number of output logits, defaults to 1000 which matches the
                 pretrained models.
         """

@@ -130,19 +135,23 @@ class Swin(torch.nn.Module):
 
     def forward(
         self,
-
-    ) ->
+        context: ModelContext,
+    ) -> FeatureVector | FeatureMaps:
         """Compute outputs from the backbone.
 
         If output_layers is set, then the outputs are multi-scale feature maps;
         otherwise, the model is being used for classification so the outputs are class
         probabilities and the loss.
 
-
-
-            process.
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the image to process.
+
+        Returns:
+            a FeatureVector if the configured output_layers is None, or a FeatureMaps
+                otherwise containing one feature map per configured output layer.
         """
-        images = torch.stack([inp["image"] for inp in inputs], dim=0)
+        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
 
         if self.output_layers:
             layer_features = []

@@ -150,7 +159,7 @@ class Swin(torch.nn.Module):
             for layer in self.model.features:
                 x = layer(x)
                 layer_features.append(x.permute(0, 3, 1, 2))
-            return [layer_features[idx] for idx in self.output_layers]
+            return FeatureMaps([layer_features[idx] for idx in self.output_layers])
 
         else:
-            return self.model(images)
+            return FeatureVector(self.model(images))
rslearn/models/terramind.py
CHANGED

@@ -8,8 +8,11 @@ import torch.nn.functional as F
 from einops import rearrange
 from terratorch.registry import BACKBONE_REGISTRY
 
+from rslearn.train.model_context import ModelContext
 from rslearn.train.transforms.transform import Transform
 
+from .component import FeatureExtractor, FeatureMaps
+
 
 # TerraMind v1 provides two sizes: base and large
 class TerramindSize(str, Enum):

@@ -85,7 +88,7 @@ PRETRAINED_BANDS = {
 }
 
 
-class Terramind(torch.nn.Module):
+class Terramind(FeatureExtractor):
     """Terramind backbones."""
 
     def __init__(

@@ -123,21 +126,25 @@ class Terramind(torch.nn.Module):
         self.modalities = modalities
         self.do_resizing = do_resizing
 
-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Forward pass for the Terramind model.
 
         Args:
-
+            context: the model context. Input dicts must include modalities as keys
+                which are defined in the self.modalities list.
 
         Returns:
-
+            a FeatureMaps with one feature map from the encoder, at 1/16 of the input
+                resolution.
         """
         model_inputs = {}
         for modality in self.modalities:
             # We assume the all the inputs include the same modalities
-            if modality not in inputs[0]:
+            if modality not in context.inputs[0]:
                 continue
-            cur = torch.stack(
+            cur = torch.stack(
+                [inp[modality] for inp in context.inputs], dim=0
+            )  # (B, C, H, W)
             if self.do_resizing and (
                 cur.shape[2] != IMAGE_SIZE or cur.shape[3] != IMAGE_SIZE
             ):

@@ -159,7 +166,17 @@ class Terramind(torch.nn.Module):
         image_features = self.model(model_inputs)[-1]
         batch_size, num_patches, _ = image_features.shape
         height, width = int(num_patches**0.5), int(num_patches**0.5)
-        return
+        return FeatureMaps(
+            [
+                rearrange(
+                    image_features,
+                    "b (h w) d -> b d h w",
+                    b=batch_size,
+                    h=height,
+                    w=width,
+                )
+            ]
+        )
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.
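The new return value reshapes the encoder's token sequence into a spatial map before wrapping it in a FeatureMaps. A small standalone illustration of that reshape with hypothetical shapes (not code from the package):

    # ViT-style tokens of shape (B, N, D) become a (B, D, H, W) map with H = W = sqrt(N).
    import torch
    from einops import rearrange

    tokens = torch.randn(2, 196, 768)  # e.g. a 14x14 patch grid with 768-dim embeddings
    b, n, d = tokens.shape
    h = w = int(n**0.5)
    fmap = rearrange(tokens, "b (h w) d -> b d h w", h=h, w=w)
    assert fmap.shape == (2, 768, 14, 14)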
rslearn/models/trunk.py
CHANGED

@@ -7,6 +7,7 @@ import torch
 
 from rslearn.log_utils import get_logger
 from rslearn.models.task_embedding import BaseTaskEmbedding
+from rslearn.train.model_context import ModelOutput
 
 logger = get_logger(__name__)
 

@@ -32,10 +33,11 @@ class DecoderTrunkLayer(torch.nn.Module, ABC):
             dict with key "outputs" (output tensor of shape (batch_size, seq_len, dim))
             and optionally other keys.
         """
+        raise NotImplementedError
 
     @abstractmethod
     def apply_auxiliary_losses(
-        self, trunk_out: dict[str, Any], outs:
+        self, trunk_out: dict[str, Any], outs: ModelOutput
     ) -> None:
         """Apply auxiliary losses in-place.
 

@@ -43,6 +45,7 @@ class DecoderTrunkLayer(torch.nn.Module, ABC):
             trunk_out: The output of the trunk.
             outs: The output of the decoders, with key "loss_dict" containing the losses.
         """
+        raise NotImplementedError
 
 
 class DecoderTrunk(torch.nn.Module):

@@ -122,7 +125,7 @@ class DecoderTrunk(torch.nn.Module):
         return out
 
     def apply_auxiliary_losses(
-        self, trunk_out: dict[str, Any], outs:
+        self, trunk_out: dict[str, Any], outs: ModelOutput
     ) -> None:
         """Apply auxiliary losses in-place.
 

@@ -130,7 +133,7 @@ class DecoderTrunk(torch.nn.Module):
 
         Args:
             trunk_out: The output of the trunk.
-            outs: The output of the decoders
+            outs: The output of the decoders.
         """
         for layer in self.layers:
             layer.apply_auxiliary_losses(trunk_out, outs)
rslearn/models/unet.py
CHANGED

@@ -5,8 +5,15 @@ from typing import Any
 import torch
 import torch.nn.functional as F
 
+from rslearn.train.model_context import ModelContext
 
-class UNetDecoder(torch.nn.Module):
+from .component import (
+    FeatureMaps,
+    IntermediateComponent,
+)
+
+
+class UNetDecoder(IntermediateComponent):
     """UNet-style decoder.
 
     It inputs multi-scale features. Starting from last (lowest resolution) feature map,

@@ -143,23 +150,25 @@ class UNetDecoder(torch.nn.Module):
                 align_corners=False,
             )
 
-    def forward(
-        self, in_features: list[torch.Tensor], inputs: list[dict[str, Any]]
-    ) -> torch.Tensor:
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
         """Compute output from multi-scale feature map.
 
         Args:
-
-
+            intermediates: the output from the previous model component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
-            output
+            output FeatureMaps consisting of one map. The embedding size is equal to the
+                configured out_channels.
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to UNetDecoder must be a FeatureMaps")
+
         # Reverse the features since we will pass them in from lowest resolution to highest.
-        in_features = list(reversed(
+        in_features = list(reversed(intermediates.feature_maps))
         cur_features = self.layers[0](in_features[0])
         for in_feat, layer in zip(in_features[1:], self.layers[1:]):
             cur_features = layer(torch.cat([cur_features, in_feat], dim=1))
         if self.original_size_to_interpolate is not None:
             cur_features = self._resize(cur_features)
-        return cur_features
+        return FeatureMaps([cur_features])
rslearn/models/upsample.py
CHANGED

@@ -1,9 +1,18 @@
 """An upsampling layer."""
 
+from typing import Any
+
 import torch
 
+from rslearn.train.model_context import ModelContext
+
+from .component import (
+    FeatureMaps,
+    IntermediateComponent,
+)
 
-class Upsample(torch.nn.Module):
+
+class Upsample(IntermediateComponent):
     """Upsamples each input feature map by the same factor."""
 
     def __init__(

@@ -20,16 +29,20 @@ class Upsample(torch.nn.Module):
         super().__init__()
         self.layer = torch.nn.Upsample(scale_factor=scale_factor, mode=mode)
 
-    def forward(
-
-    ) -> list[torch.Tensor]:
-        """Upsample each feature map.
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Upsample each feature map by scale_factor.
 
         Args:
-
-
+            intermediates: the output from the previous component, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
-            upsampled feature maps
+            upsampled feature maps.
         """
-
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to Upsample must be a FeatureMaps")
+
+        upsampled_feat_maps = [
+            self.layer(feat_map) for feat_map in intermediates.feature_maps
+        ]
+        return FeatureMaps(upsampled_feat_maps)
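Both UNetDecoder and Upsample now follow the IntermediateComponent pattern: forward receives the previous component's output plus the ModelContext, validates that it is a FeatureMaps, and returns a new FeatureMaps. A minimal sketch of that pattern with a hypothetical component (not part of the package):

    from typing import Any

    from rslearn.models.component import FeatureMaps, IntermediateComponent
    from rslearn.train.model_context import ModelContext


    class ScaleFeatures(IntermediateComponent):
        """Toy component that multiplies every feature map by a constant factor."""

        def __init__(self, factor: float = 2.0):
            super().__init__()
            self.factor = factor

        def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
            if not isinstance(intermediates, FeatureMaps):
                raise ValueError("input to ScaleFeatures must be a FeatureMaps")
            return FeatureMaps([self.factor * f for f in intermediates.feature_maps])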
rslearn/train/all_patches_dataset.py
CHANGED

@@ -2,13 +2,15 @@
 
 import itertools
 from collections.abc import Iterable, Iterator
+from dataclasses import replace
 from typing import Any
 
 import shapely
 import torch
 
 from rslearn.dataset import Window
-from rslearn.train.dataset import ModelDataset
+from rslearn.train.dataset import DataInput, ModelDataset
+from rslearn.train.model_context import SampleMetadata
 from rslearn.utils.geometry import PixelBounds, STGeometry
 
 
@@ -32,22 +34,28 @@ def get_window_patch_options(
         bottommost patches may extend beyond the provided bounds.
     """
     # We stride the patches by patch_size - overlap_size until the last patch.
+    # We handle the first patch with a special case to ensure it is always used.
     # We handle the last patch with a special case to ensure it does not exceed the
     # window bounds. Instead, it may overlap the previous patch.
-    cols = list(
+    cols = [bounds[0]] + list(
         range(
-            bounds[0],
+            bounds[0] + patch_size[0],
             bounds[2] - patch_size[0],
             patch_size[0] - overlap_size[0],
         )
-    )
-    rows = list(
+    )
+    rows = [bounds[1]] + list(
         range(
-            bounds[1],
+            bounds[1] + patch_size[1],
             bounds[3] - patch_size[1],
             patch_size[1] - overlap_size[1],
         )
-    )
+    )
+    # Add last patches only if the input is larger than one patch.
+    if bounds[2] - patch_size[0] > bounds[0]:
+        cols.append(bounds[2] - patch_size[0])
+    if bounds[3] - patch_size[1] > bounds[1]:
+        rows.append(bounds[3] - patch_size[1])
 
     patch_bounds: list[PixelBounds] = []
     for col in cols:
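A worked example of the updated tiling with hypothetical numbers: for bounds (0, 0, 1000, 1000), patch_size (256, 256), and no overlap, columns start at 0, stride by 256 while staying strictly below bounds[2] - patch_size[0] = 744, and a final column at 744 is appended so the last patch ends exactly at the window edge.

    bounds = (0, 0, 1000, 1000)
    patch_size = (256, 256)
    overlap_size = (0, 0)

    cols = [bounds[0]] + list(
        range(
            bounds[0] + patch_size[0],
            bounds[2] - patch_size[0],
            patch_size[0] - overlap_size[0],
        )
    )
    if bounds[2] - patch_size[0] > bounds[0]:
        cols.append(bounds[2] - patch_size[0])
    print(cols)  # [0, 256, 512, 744]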
@@ -60,13 +68,17 @@ def pad_slice_protect(
     raw_inputs: dict[str, Any],
     passthrough_inputs: dict[str, Any],
     patch_size: tuple[int, int],
+    inputs: dict[str, DataInput],
 ) -> tuple[dict[str, Any], dict[str, Any]]:
     """Pad tensors in-place by patch size to protect slicing near right/bottom edges.
 
+    The padding is scaled based on each input's resolution_factor.
+
     Args:
         raw_inputs: the raw inputs to pad.
         passthrough_inputs: the passthrough inputs to pad.
-        patch_size: the size of the patches to extract.
+        patch_size: the size of the patches to extract (at window resolution).
+        inputs: the DataInput definitions, used to get resolution_factor per input.
 
     Returns:
         a tuple of (raw_inputs, passthrough_inputs).

@@ -75,8 +87,14 @@ def pad_slice_protect(
         for input_name, value in list(d.items()):
             if not isinstance(value, torch.Tensor):
                 continue
+            # Get resolution scale for this input
+            rf = inputs[input_name].resolution_factor
+            scale = rf.numerator / rf.denominator
+            # Scale the padding amount
+            scaled_pad_x = int(patch_size[0] * scale)
+            scaled_pad_y = int(patch_size[1] * scale)
             d[input_name] = torch.nn.functional.pad(
-                value, pad=(0,
+                value, pad=(0, scaled_pad_x, 0, scaled_pad_y)
             )
     return raw_inputs, passthrough_inputs
 
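The scaling added above divides resolution_factor's numerator by its denominator, so it behaves like a fraction relating each input's stored resolution to the window resolution. A small illustration with hypothetical values, assuming resolution_factor behaves like fractions.Fraction (the actual DataInput definition lives in rslearn/train/dataset.py):

    from fractions import Fraction

    patch_size = (128, 128)
    resolution_factor = Fraction(1, 2)  # input stored at half the window resolution

    scale = resolution_factor.numerator / resolution_factor.denominator
    scaled_pad = (int(patch_size[0] * scale), int(patch_size[1] * scale))
    print(scaled_pad)  # (64, 64)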
@@ -121,6 +139,7 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
         self.rank = rank
         self.world_size = world_size
         self.windows = self.dataset.get_dataset_examples()
+        self.inputs = dataset.inputs
 
     def set_name(self, name: str) -> None:
         """Sets dataset name.

@@ -218,7 +237,7 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
 
     def __iter__(
         self,
-    ) -> Iterator[tuple[dict[str, Any], dict[str, Any],
+    ) -> Iterator[tuple[dict[str, Any], dict[str, Any], SampleMetadata]]:
         """Iterate over all patches in each element of the underlying ModelDataset."""
         # Iterate over the window IDs until we have returned enough samples.
         window_ids, num_samples_needed = self._get_worker_iteration_data()

@@ -229,12 +248,14 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
             raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(
                 window_id
             )
-            bounds = metadata
+            bounds = metadata.patch_bounds
 
             # For simplicity, pad tensors by patch size to ensure that any patch bounds
             # extending outside the window bounds will not have issues when we slice
-            # the tensors later.
-            pad_slice_protect(
+            # the tensors later. Padding is scaled per-input based on resolution_factor.
+            pad_slice_protect(
+                raw_inputs, passthrough_inputs, self.patch_size, self.inputs
+            )
 
             # Now iterate over the patches and extract/yield the crops.
             # Note that, in case user is leveraging RslearnWriter, it is important that

@@ -244,7 +265,7 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
             )
             for patch_idx, patch_bounds in enumerate(patches):
                 cur_geom = STGeometry(
-                    metadata
+                    metadata.projection, shapely.box(*patch_bounds), None
                 )
                 start_offset = (
                     patch_bounds[0] - bounds[0],

@@ -256,15 +277,28 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
                 )
 
                 # Define a helper function to handle each input dict.
+                # Crop coordinates are scaled based on each input's resolution_factor.
                 def crop_input_dict(d: dict[str, Any]) -> dict[str, Any]:
                     cropped = {}
                     for input_name, value in d.items():
                         if isinstance(value, torch.Tensor):
-                            #
+                            # Get resolution scale for this input
+                            rf = self.inputs[input_name].resolution_factor
+                            scale = rf.numerator / rf.denominator
+                            # Scale the crop coordinates
+                            scaled_start = (
+                                int(start_offset[0] * scale),
+                                int(start_offset[1] * scale),
+                            )
+                            scaled_end = (
+                                int(end_offset[0] * scale),
+                                int(end_offset[1] * scale),
+                            )
+                            # Crop the CHW tensor with scaled coordinates.
                             cropped[input_name] = value[
                                 :,
-
-
+                                scaled_start[1] : scaled_end[1],
+                                scaled_start[0] : scaled_end[0],
                             ].clone()
                         elif isinstance(value, list):
                             cropped[input_name] = [

@@ -282,10 +316,12 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
                 cur_passthrough_inputs = crop_input_dict(passthrough_inputs)
 
                 # Adjust the metadata as well.
-                cur_metadata =
-
-
-
+                cur_metadata = replace(
+                    metadata,
+                    patch_bounds=patch_bounds,
+                    patch_idx=patch_idx,
+                    num_patches_in_window=len(patches),
+                )
 
                 # Now we can compute input and target dicts via the task.
                 input_dict, target_dict = self.dataset.task.process_inputs(

@@ -297,7 +333,6 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
                 input_dict, target_dict = self.dataset.transforms(
                     input_dict, target_dict
                 )
-                input_dict["dataset_source"] = self.dataset.name
 
                 if num_samples_returned < num_samples_needed:
                     yield input_dict, target_dict, cur_metadata
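The per-patch metadata above is derived with dataclasses.replace, which copies a dataclass instance and overrides only the named fields. A short sketch with a hypothetical stand-in dataclass (the real SampleMetadata is defined in rslearn/train/model_context.py):

    from dataclasses import dataclass, replace


    @dataclass(frozen=True)
    class PatchMeta:  # hypothetical stand-in for SampleMetadata
        patch_bounds: tuple[int, int, int, int]
        patch_idx: int
        num_patches_in_window: int


    window_meta = PatchMeta((0, 0, 1000, 1000), patch_idx=0, num_patches_in_window=1)
    patch_meta = replace(
        window_meta,
        patch_bounds=(0, 0, 256, 256),
        patch_idx=3,
        num_patches_in_window=16,
    )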
@@ -345,8 +380,9 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
             round(self.patch_size[1] * overlap_ratio),
         )
         self.windows = self.dataset.get_dataset_examples()
+        self.inputs = dataset.inputs
         self.window_cache: dict[
-            int, tuple[dict[str, Any], dict[str, Any],
+            int, tuple[dict[str, Any], dict[str, Any], SampleMetadata]
         ] = {}
 
         # Precompute the batch boundaries for each window

@@ -360,7 +396,7 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
 
     def get_raw_inputs(
         self, index: int
-    ) -> tuple[dict[str, Any], dict[str, Any],
+    ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
         """Get the raw inputs for a single patch. Retrieve from cache if possible.
 
         Also crops/pads the tensors by patch size to protect slicing near right/bottom edges.

@@ -375,26 +411,41 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
             return self.window_cache[index]
 
         raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(index)
-        pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size)
+        pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size, self.inputs)
 
         self.window_cache[index] = (raw_inputs, passthrough_inputs, metadata)
         return self.window_cache[index]
 
-    @staticmethod
     def _crop_input_dict(
+        self,
         d: dict[str, Any],
         start_offset: tuple[int, int],
         end_offset: tuple[int, int],
         cur_geom: STGeometry,
     ) -> dict[str, Any]:
-        """Crop a dictionary of inputs to the given bounds.
+        """Crop a dictionary of inputs to the given bounds.
+
+        Crop coordinates are scaled based on each input's resolution_factor.
+        """
         cropped = {}
         for input_name, value in d.items():
             if isinstance(value, torch.Tensor):
+                # Get resolution scale for this input
+                rf = self.inputs[input_name].resolution_factor
+                scale = rf.numerator / rf.denominator
+                # Scale the crop coordinates
+                scaled_start = (
+                    int(start_offset[0] * scale),
+                    int(start_offset[1] * scale),
+                )
+                scaled_end = (
+                    int(end_offset[0] * scale),
+                    int(end_offset[1] * scale),
+                )
                 cropped[input_name] = value[
                     :,
-
-
+                    scaled_start[1] : scaled_end[1],
+                    scaled_start[0] : scaled_end[0],
                 ].clone()
             elif isinstance(value, list):
                 cropped[input_name] = [

@@ -410,13 +461,13 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
 
     def __getitem__(
        self, index: int
-    ) -> tuple[dict[str, Any], dict[str, Any],
+    ) -> tuple[dict[str, Any], dict[str, Any], SampleMetadata]:
         """Return (input_dict, target_dict, metadata) for a single flattened patch."""
         (window_id, patch_bounds, (patch_idx, num_patches)) = self.patches[index]
         raw_inputs, passthrough_inputs, metadata = self.get_raw_inputs(window_id)
-        bounds = metadata
+        bounds = metadata.patch_bounds
 
-        cur_geom = STGeometry(metadata
+        cur_geom = STGeometry(metadata.projection, shapely.box(*patch_bounds), None)
         start_offset = (patch_bounds[0] - bounds[0], patch_bounds[1] - bounds[1])
         end_offset = (patch_bounds[2] - bounds[0], patch_bounds[3] - bounds[1])
 

@@ -428,10 +479,12 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
         )
 
         # Adjust the metadata as well.
-        cur_metadata =
-
-
-
+        cur_metadata = replace(
+            metadata,
+            patch_bounds=patch_bounds,
+            patch_idx=patch_idx,
+            num_patches_in_window=num_patches,
+        )
 
         # Now we can compute input and target dicts via the task.
         input_dict, target_dict = self.dataset.task.process_inputs(

@@ -441,7 +494,6 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
         )
         input_dict.update(cur_passthrough_inputs)
         input_dict, target_dict = self.dataset.transforms(input_dict, target_dict)
-        input_dict["dataset_source"] = self.dataset.name
 
         return input_dict, target_dict, cur_metadata
 