rslearn 0.0.18-py3-none-any.whl → 0.0.20-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +2 -9
- rslearn/config/dataset.py +15 -16
- rslearn/dataset/dataset.py +28 -22
- rslearn/lightning_cli.py +22 -11
- rslearn/main.py +1 -1
- rslearn/models/anysat.py +35 -33
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clip.py +5 -2
- rslearn/models/component.py +12 -0
- rslearn/models/croma.py +11 -3
- rslearn/models/dinov3.py +2 -1
- rslearn/models/faster_rcnn.py +2 -1
- rslearn/models/galileo/galileo.py +58 -31
- rslearn/models/module_wrapper.py +6 -1
- rslearn/models/molmo.py +4 -2
- rslearn/models/olmoearth_pretrain/model.py +206 -51
- rslearn/models/olmoearth_pretrain/norm.py +5 -3
- rslearn/models/panopticon.py +3 -1
- rslearn/models/presto/presto.py +45 -15
- rslearn/models/prithvi.py +9 -7
- rslearn/models/sam2_enc.py +3 -1
- rslearn/models/satlaspretrain.py +4 -1
- rslearn/models/simple_time_series.py +43 -17
- rslearn/models/ssl4eo_s12.py +19 -14
- rslearn/models/swin.py +3 -1
- rslearn/models/terramind.py +5 -4
- rslearn/train/all_patches_dataset.py +96 -28
- rslearn/train/dataset.py +102 -53
- rslearn/train/model_context.py +35 -1
- rslearn/train/scheduler.py +15 -0
- rslearn/train/tasks/classification.py +8 -2
- rslearn/train/tasks/detection.py +3 -2
- rslearn/train/tasks/multi_task.py +2 -3
- rslearn/train/tasks/per_pixel_regression.py +14 -5
- rslearn/train/tasks/regression.py +8 -2
- rslearn/train/tasks/segmentation.py +13 -4
- rslearn/train/tasks/task.py +2 -2
- rslearn/train/transforms/concatenate.py +45 -5
- rslearn/train/transforms/crop.py +22 -8
- rslearn/train/transforms/flip.py +13 -5
- rslearn/train/transforms/mask.py +11 -2
- rslearn/train/transforms/normalize.py +46 -15
- rslearn/train/transforms/pad.py +15 -3
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +11 -2
- rslearn/train/transforms/sentinel1.py +18 -3
- rslearn/utils/geometry.py +73 -0
- rslearn/utils/jsonargparse.py +66 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/METADATA +1 -1
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/RECORD +55 -53
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/WHEEL +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/top_level.txt +0 -0
rslearn/models/presto/presto.py CHANGED

@@ -2,6 +2,7 @@

 import logging
 import tempfile
+from datetime import datetime

 import torch
 from einops import rearrange, repeat
@@ -118,21 +119,21 @@ class Presto(FeatureExtractor):
             of each timestep for that pixel
         """
         bs = [x.shape[0] for x in [s1, s2, era5, srtm] if x is not None]
-
-
+        ts = [x.shape[2] for x in [s1, s2, era5, srtm] if x is not None]
+        hs = [x.shape[3] for x in [s1, s2, era5, srtm] if x is not None]
+        ws = [x.shape[4] for x in [s1, s2, era5, srtm] if x is not None]
         devices = [x.device for x in [s1, s2, era5, srtm] if x is not None]

         assert len(set(bs)) == 1
         assert len(set(hs)) == 1
         assert len(set(ws)) == 1
         assert len(set(devices)) == 1
-
-
+        assert len(set(ts)) == 1
+        b, h, w, t, device = bs[0], hs[0], ws[0], ts[0], devices[0]
         # these values will be initialized as
         # we iterate through the data
         x: torch.Tensor | None = None
         mask: torch.Tensor | None = None
-        t: int | None = None

         for band_group in [
             (s1, s1_bands),
@@ -146,14 +147,7 @@ class Presto(FeatureExtractor):
             else:
                 continue

-
-            if t is None:
-                t = m_t
-            else:
-                if t != m_t:
-                    raise ValueError("inconsistent values for t")
-
-            data = rearrange(data, "b (t c) h w -> b t h w c", t=m_t)
+            data = rearrange(data, "b c t h w -> b t h w c")
             if x is None:
                 x = torch.zeros(b, t, h, w, len(INPUT_PRESTO_BANDS), device=device)
             if mask is None:
@@ -184,6 +178,23 @@ class Presto(FeatureExtractor):
         x = (x + PRESTO_ADD_BY.to(device=device)) / PRESTO_DIV_BY.to(device=device)
         return x, mask, dynamic_world.long(), months.long()

+    @staticmethod
+    def time_ranges_to_timestamps(
+        time_ranges: list[tuple[datetime, datetime]],
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Turn the time ranges stored in a RasterImage to timestamps accepted by Presto.
+
+        Presto only uses the month associated with each timestamp, so we take the midpoint
+        the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
+        time so that start_time == end_time == mid_time.
+        """
+        mid_ranges = [t[0] + ((t[1] - t[0]) / 2) for t in time_ranges]
+        # months are indexed 0-11
+        return torch.tensor(
+            [d.month - 1 for d in mid_ranges], dtype=torch.int32, device=device
+        )
+
     def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Presto backbone.

@@ -194,17 +205,36 @@ class Presto(FeatureExtractor):
             a FeatureMaps with one feature map that is at the same resolution as the
             input (since Presto operates per-pixel).
         """
+        time_modalities = ["s1", "s2", "era5"]
         stacked_inputs = {}
         latlons: torch.Tensor | None = None
+        months: torch.Tensor | None = None
         for key in context.inputs[0].keys():
             # assume all the keys in an input are consistent
             if key in self.input_keys:
                 if key == "latlon":
-                    latlons = torch.stack(
+                    latlons = torch.stack(
+                        [inp[key].image for inp in context.inputs], dim=0
+                    )
                 else:
                     stacked_inputs[key] = torch.stack(
-                        [inp[key] for inp in context.inputs], dim=0
+                        [inp[key].image for inp in context.inputs], dim=0
                     )
+                if key in time_modalities:
+                    if months is None:
+                        if context.inputs[0][key].timestamps is not None:
+                            months = torch.stack(
+                                [
+                                    self.time_ranges_to_timestamps(
+                                        inp[key].timestamps,  # type: ignore
+                                        device=stacked_inputs[key].device,
+                                    )
+                                    for inp in context.inputs
+                                ],
+                                dim=0,
+                            )
+        if months is not None:
+            stacked_inputs["months"] = months

         (
             x,
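The new `time_ranges_to_timestamps` helper reduces each `(start, end)` time range to the month of its midpoint. The following is a minimal standalone sketch of that midpoint-to-month computation, for illustration only; the real helper is the `@staticmethod` shown in the diff above, and the example dates are made up:

```python
from datetime import datetime

import torch

# Illustrative time ranges; Presto only consumes the 0-indexed month of the midpoint.
time_ranges = [
    (datetime(2022, 1, 1), datetime(2022, 3, 1)),    # midpoint in January -> 0
    (datetime(2022, 7, 15), datetime(2022, 7, 15)),  # start == end == mid -> 6
]
mid_points = [start + (end - start) / 2 for start, end in time_ranges]
months = torch.tensor([d.month - 1 for d in mid_points], dtype=torch.int32)
print(months)  # tensor([0, 6], dtype=torch.int32)
```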
rslearn/models/prithvi.py CHANGED

@@ -144,13 +144,15 @@ class PrithviV2(FeatureExtractor):
         """Process individual modality data.

         Args:
-            data: Input tensor of shape [B, C, H, W]
+            data: Input tensor of shape [B, C, T, H, W]

         Returns:
-            list of tensors of shape [B, C, H, W]
+            list of tensors of shape [B, C, T, H, W]
         """
         # Get original dimensions
-
+        B, C, T, H, W = data.shape
+        data = rearrange(data, "b c t h w -> b (c t) h w")
+        original_height = H
         new_height = self.patch_size if original_height == 1 else self.image_resolution
         data = F.interpolate(
             data,
@@ -158,6 +160,7 @@ class PrithviV2(FeatureExtractor):
             mode="bilinear",
             align_corners=False,
         )
+        data = rearrange(data, "b (c t) h w -> b c t h w", c=C, t=T)
         return data

     def forward(self, context: ModelContext) -> FeatureMaps:
@@ -171,17 +174,16 @@ class PrithviV2(FeatureExtractor):
             a FeatureMaps with one map of shape [B, H/p_s, W/p_s, 11*1024] that contains stacked
             feature maps across the 11 transformer blocks.
         """
-        x
+        # x has shape BCTHW
+        x = torch.stack([inp[self.INPUT_KEY].image for inp in context.inputs], dim=0)
         x = self._resize_data(x)
-        num_timesteps = x.shape[1] // len(self.bands)
-        x = rearrange(x, "b (t c) h w -> b c t h w", t=num_timesteps)
         features = self.model.encoder.forward_features(x)
         # prepare_features_for_image_model was slightly modified since we already
         # know the number of timesteps and don't need to recompute it.
         # in addition we average along the time dimension (instead of concatenating)
         # to keep the embeddings reasonably sized.
         result = self.model.encoder.prepare_features_for_image_model(
-            features,
+            features, x.shape[2]
         )
         return FeatureMaps([torch.cat(result, dim=1)])

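The `_resize_data` change folds the time axis into the channel axis so that 2D bilinear interpolation can be applied, then unfolds it again. Below is a minimal sketch of that pattern with assumed shapes; it mirrors the rearrange/interpolate calls in the diff but is not the rslearn API itself:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

B, C, T, H, W = 2, 6, 3, 4, 4
x = torch.randn(B, C, T, H, W)

# Fold T into channels, resize spatially, then restore the [B, C, T, H, W] layout.
flat = rearrange(x, "b c t h w -> b (c t) h w")
resized = F.interpolate(flat, size=(224, 224), mode="bilinear", align_corners=False)
out = rearrange(resized, "b (c t) h w -> b c t h w", c=C, t=T)
print(out.shape)  # torch.Size([2, 6, 3, 224, 224])
```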
rslearn/models/sam2_enc.py CHANGED

@@ -95,7 +95,9 @@ class SAM2Encoder(FeatureExtractor):
         Returns:
             feature maps from the encoder.
         """
-        images = torch.stack(
+        images = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
         features = self.encoder(images)
         return FeatureMaps(features)

rslearn/models/satlaspretrain.py CHANGED

@@ -76,7 +76,10 @@ class SatlasPretrain(FeatureExtractor):
         Returns:
             multi-resolution feature maps computed by the model.
         """
-
+        # take the first (assumed to be only) timestep
+        images = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
         feature_maps = self.model(self.maybe_resize(images))
         return FeatureMaps(feature_maps)

rslearn/models/simple_time_series.py CHANGED

@@ -3,8 +3,9 @@
 from typing import Any

 import torch
+from einops import rearrange

-from rslearn.train.model_context import ModelContext
+from rslearn.train.model_context import ModelContext, RasterImage

 from .component import FeatureExtractor, FeatureMaps

@@ -163,23 +164,44 @@ class SimpleTimeSeries(FeatureExtractor):

     def _get_batched_images(
         self, input_dicts: list[dict[str, Any]], image_key: str, image_channels: int
-    ) ->
+    ) -> list[RasterImage]:
         """Collect and reshape images across input dicts.

         The BTCHW image time series are reshaped to (B*T)CHW so they can be passed to
         the forward pass of a per-image (unitemporal) model.
         """
         images = torch.stack(
-            [input_dict[image_key] for input_dict in input_dicts], dim=0
+            [input_dict[image_key].image for input_dict in input_dicts], dim=0
+        )  # B, C, T, H, W
+        timestamps = [input_dict[image_key].timestamps for input_dict in input_dicts]
+        # if image channels is not equal to the actual number of channels, then
+        # then every N images should be batched together. For example, if the
+        # number of input channels c == 2, and image_channels == 4, then we
+        # want to pass 2 timesteps to the model.
+        # TODO is probably to make this behaviour clearer but lets leave it like
+        # this for now to not break things.
+        num_timesteps = images.shape[1] // image_channels
+        batched_timesteps = images.shape[2] // num_timesteps
+        images = rearrange(
+            images,
+            "b c (b_t k_t) h w -> (b b_t) c k_t h w",
+            b_t=batched_timesteps,
+            k_t=num_timesteps,
         )
-
-
-
-
-
-
-
-
+        if timestamps[0] is None:
+            new_timestamps = [None] * images.shape[0]
+        else:
+            # we also need to split the timestamps
+            new_timestamps = []
+            for t in timestamps:
+                for i in range(batched_timesteps):
+                    new_timestamps.append(
+                        t[i * num_timesteps : (i + 1) * num_timesteps]
+                    )
+        return [
+            RasterImage(image=image, timestamps=timestamps)
+            for image, timestamps in zip(images, new_timestamps)
+        ]  # C, T, H, W

     def forward(
         self,
@@ -208,8 +230,8 @@ class SimpleTimeSeries(FeatureExtractor):

         if batched_inputs is None:
             batched_inputs = [{} for _ in batched_images]
-            n_images = batched_images
-        elif n_images != batched_images
+            n_images = len(batched_images) // n_batch
+        elif n_images != len(batched_images) // n_batch:
             raise ValueError(
                 "expected all modalities to have the same number of timesteps"
             )
@@ -223,13 +245,18 @@ class SimpleTimeSeries(FeatureExtractor):
                 context.inputs, self.image_key, self.image_channels
             )
             batched_inputs = [{self.image_key: image} for image in batched_images]
-            n_images = batched_images
+            n_images = len(batched_images) // n_batch

         assert n_images is not None
-
         # Now we can apply the underlying FeatureExtractor.
         # Its output must be a FeatureMaps.
-
+        assert batched_inputs is not None
+        encoder_output = self.encoder(
+            ModelContext(
+                inputs=batched_inputs,
+                metadatas=context.metadatas,
+            )
+        )
         if not isinstance(encoder_output, FeatureMaps):
             raise ValueError(
                 "output of underlying FeatureExtractor in SimpleTimeSeries must be a FeatureMaps"
@@ -244,7 +271,6 @@ class SimpleTimeSeries(FeatureExtractor):
             )
             for feat_map in encoder_output.feature_maps
         ]
-
         # Groups defaults to flattening all the feature maps.
         groups = self.groups
         if not groups:
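The rearrange in `_get_batched_images` splits a BCTHW series along the time axis into groups, so each group can be handed to the wrapped per-image encoder as its own sample. The following is a minimal sketch with assumed group sizes (the actual sizes are derived from `image_channels` as in the diff above):

```python
import torch
from einops import rearrange

B, C, T, H, W = 1, 2, 6, 8, 8
timesteps_per_group = 2                # timesteps handed to the encoder at once
num_groups = T // timesteps_per_group  # groups produced per original sample

images = torch.randn(B, C, T, H, W)
# Split the time axis into num_groups chunks and fold the chunk index into the batch.
grouped = rearrange(
    images,
    "b c (b_t k_t) h w -> (b b_t) c k_t h w",
    b_t=num_groups,
    k_t=timesteps_per_group,
)
print(grouped.shape)  # torch.Size([3, 2, 2, 8, 8])
```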
rslearn/models/ssl4eo_s12.py CHANGED

@@ -13,7 +13,7 @@ class Ssl4eoS12(FeatureExtractor):

     def __init__(
         self,
-        backbone_ckpt_path: str,
+        backbone_ckpt_path: str | None,
         arch: str = "resnet50",
         output_layers: list[int] = [0, 1, 2, 3],
     ) -> None:
@@ -39,19 +39,22 @@ class Ssl4eoS12(FeatureExtractor):
         else:
             raise ValueError(f"unknown SSL4EO-S12 architecture {arch}")

-
-
-
-
-
-
-
-
-
-
-
-                f"warning: got missing_keys={missing_keys}, unexpected_keys={unexpected_keys} when loading SSL4EO-S12 state dict"
+        if backbone_ckpt_path is not None:
+            state_dict = torch.load(backbone_ckpt_path, weights_only=True)
+            state_dict = state_dict["teacher"]
+            prefix = "module.backbone."
+            state_dict = {
+                k[len(prefix) :]: v
+                for k, v in state_dict.items()
+                if k.startswith(prefix)
+            }
+            missing_keys, unexpected_keys = self.model.load_state_dict(
+                state_dict, strict=False
             )
+            if missing_keys or unexpected_keys:
+                print(
+                    f"warning: got missing_keys={missing_keys}, unexpected_keys={unexpected_keys} when loading SSL4EO-S12 state dict"
+                )

     def get_backbone_channels(self) -> list[tuple[int, int]]:
         """Returns the output channels of this model when used as a backbone.
@@ -91,7 +94,9 @@ class Ssl4eoS12(FeatureExtractor):
         Returns:
             feature maps computed by the pre-trained model.
         """
-        x = torch.stack(
+        x = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
         x = self.model.conv1(x)
         x = self.model.bn1(x)
         x = self.model.relu(x)
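The new checkpoint-loading branch keeps only the `teacher` backbone weights by stripping the `module.backbone.` prefix before calling `load_state_dict`. A minimal sketch of that dictionary filtering follows; the keys and tensor shapes are made up for illustration:

```python
import torch

state_dict = {
    "module.backbone.conv1.weight": torch.zeros(64, 13, 7, 7),
    "module.head.fc.weight": torch.zeros(10, 2048),  # dropped: not under the prefix
}
prefix = "module.backbone."
# Keep only keys under the prefix and remove the prefix so names match the backbone.
backbone_sd = {
    k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)
}
print(list(backbone_sd))  # ['conv1.weight']
```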
rslearn/models/swin.py CHANGED

@@ -151,7 +151,9 @@ class Swin(FeatureExtractor):
             a FeatureVector if the configured output_layers is None, or a FeatureMaps
             otherwise containing one feature map per configured output layer.
         """
-        images = torch.stack(
+        images = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )

         if self.output_layers:
             layer_features = []
rslearn/models/terramind.py CHANGED

@@ -143,7 +143,8 @@ class Terramind(FeatureExtractor):
             if modality not in context.inputs[0]:
                 continue
             cur = torch.stack(
-                [inp[modality] for inp in context.inputs],
+                [inp[modality].single_ts_to_chw_tensor() for inp in context.inputs],
+                dim=0,
             )  # (B, C, H, W)
             if self.do_resizing and (
                 cur.shape[2] != IMAGE_SIZE or cur.shape[3] != IMAGE_SIZE
@@ -219,7 +220,7 @@ class TerramindNormalize(Transform):
         Returns:
             The normalized image.
         """
-        images = image.float()  # (C, H, W)
+        images = image.float()  # (C, 1, H, W)
         if images.shape[0] % len(means) != 0:
             raise ValueError(
                 f"the number of image channels {images.shape[0]} is not multiple of expected number of bands {len(means)}"
@@ -247,8 +248,8 @@
             band_info = PRETRAINED_BANDS[modality]
             means = [band_info[band][0] for band in band_info]
             stds = [band_info[band][1] for band in band_info]
-            input_dict[modality] = self.apply_image(
-                input_dict[modality],
+            input_dict[modality].image = self.apply_image(
+                input_dict[modality].image,
                 means,
                 stds,
             )
rslearn/train/all_patches_dataset.py CHANGED

@@ -9,8 +9,8 @@ import shapely
 import torch

 from rslearn.dataset import Window
-from rslearn.train.dataset import ModelDataset
-from rslearn.train.model_context import SampleMetadata
+from rslearn.train.dataset import DataInput, ModelDataset
+from rslearn.train.model_context import RasterImage, SampleMetadata
 from rslearn.utils.geometry import PixelBounds, STGeometry


@@ -34,22 +34,28 @@ def get_window_patch_options(
     bottommost patches may extend beyond the provided bounds.
     """
     # We stride the patches by patch_size - overlap_size until the last patch.
+    # We handle the first patch with a special case to ensure it is always used.
     # We handle the last patch with a special case to ensure it does not exceed the
     # window bounds. Instead, it may overlap the previous patch.
-    cols = list(
+    cols = [bounds[0]] + list(
         range(
-            bounds[0],
+            bounds[0] + patch_size[0],
             bounds[2] - patch_size[0],
             patch_size[0] - overlap_size[0],
         )
-    )
-    rows = list(
+    )
+    rows = [bounds[1]] + list(
         range(
-            bounds[1],
+            bounds[1] + patch_size[1],
             bounds[3] - patch_size[1],
             patch_size[1] - overlap_size[1],
         )
-    )
+    )
+    # Add last patches only if the input is larger than one patch.
+    if bounds[2] - patch_size[0] > bounds[0]:
+        cols.append(bounds[2] - patch_size[0])
+    if bounds[3] - patch_size[1] > bounds[1]:
+        rows.append(bounds[3] - patch_size[1])

     patch_bounds: list[PixelBounds] = []
     for col in cols:
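The updated `get_window_patch_options` always includes the first patch, strides interior patches by `patch_size - overlap_size`, and appends a final patch clamped to the window bounds whenever the window is larger than one patch. A one-dimensional sketch of that start-coordinate logic (the helper name here is hypothetical):

```python
def patch_starts(lo: int, hi: int, patch: int, overlap: int) -> list[int]:
    # First patch, then strided interior patches, then a last patch clamped to hi.
    starts = [lo] + list(range(lo + patch, hi - patch, patch - overlap))
    if hi - patch > lo:
        starts.append(hi - patch)
    return starts

print(patch_starts(0, 100, 32, 0))  # [0, 32, 64, 68]
```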
@@ -62,13 +68,17 @@ def pad_slice_protect(
     raw_inputs: dict[str, Any],
     passthrough_inputs: dict[str, Any],
     patch_size: tuple[int, int],
+    inputs: dict[str, DataInput],
 ) -> tuple[dict[str, Any], dict[str, Any]]:
     """Pad tensors in-place by patch size to protect slicing near right/bottom edges.

+    The padding is scaled based on each input's resolution_factor.
+
     Args:
         raw_inputs: the raw inputs to pad.
         passthrough_inputs: the passthrough inputs to pad.
-        patch_size: the size of the patches to extract.
+        patch_size: the size of the patches to extract (at window resolution).
+        inputs: the DataInput definitions, used to get resolution_factor per input.

     Returns:
         a tuple of (raw_inputs, passthrough_inputs).
@@ -77,12 +87,42 @@
         for input_name, value in list(d.items()):
             if not isinstance(value, torch.Tensor):
                 continue
+            # Get resolution scale for this input
+            rf = inputs[input_name].resolution_factor
+            scale = rf.numerator / rf.denominator
+            # Scale the padding amount
+            scaled_pad_x = int(patch_size[0] * scale)
+            scaled_pad_y = int(patch_size[1] * scale)
             d[input_name] = torch.nn.functional.pad(
-                value, pad=(0,
+                value, pad=(0, scaled_pad_x, 0, scaled_pad_y)
             )
     return raw_inputs, passthrough_inputs


+def crop_tensor_or_rasterimage(
+    x: torch.Tensor | RasterImage, start: tuple[int, int], end: tuple[int, int]
+) -> torch.Tensor | RasterImage:
+    """Crop a tensor or a RasterImage."""
+    if isinstance(x, torch.Tensor):
+        # Crop the CHW tensor with scaled coordinates.
+        return x[
+            :,
+            start[1] : end[1],
+            start[0] : end[0],
+        ].clone()
+    else:
+        # Crop the CTHW tensor with scaled coordinates.
+        return RasterImage(
+            x.image[
+                :,
+                :,
+                start[1] : end[1],
+                start[0] : end[0],
+            ].clone(),
+            x.timestamps,
+        )
+
+
 class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
     """This wraps a ModelDataset to iterate over all patches in that dataset.

@@ -123,6 +163,7 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
         self.rank = rank
         self.world_size = world_size
         self.windows = self.dataset.get_dataset_examples()
+        self.inputs = dataset.inputs

     def set_name(self, name: str) -> None:
         """Sets dataset name.
@@ -235,8 +276,10 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):

         # For simplicity, pad tensors by patch size to ensure that any patch bounds
         # extending outside the window bounds will not have issues when we slice
-        # the tensors later.
-        pad_slice_protect(
+        # the tensors later. Padding is scaled per-input based on resolution_factor.
+        pad_slice_protect(
+            raw_inputs, passthrough_inputs, self.patch_size, self.inputs
+        )

         # Now iterate over the patches and extract/yield the crops.
         # Note that, in case user is leveraging RslearnWriter, it is important that
@@ -258,16 +301,26 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
             )

             # Define a helper function to handle each input dict.
+            # Crop coordinates are scaled based on each input's resolution_factor.
             def crop_input_dict(d: dict[str, Any]) -> dict[str, Any]:
                 cropped = {}
                 for input_name, value in d.items():
-                    if isinstance(value, torch.Tensor):
-                        #
-
-
-
-
-
+                    if isinstance(value, torch.Tensor | RasterImage):
+                        # Get resolution scale for this input
+                        rf = self.inputs[input_name].resolution_factor
+                        scale = rf.numerator / rf.denominator
+                        # Scale the crop coordinates
+                        scaled_start = (
+                            int(start_offset[0] * scale),
+                            int(start_offset[1] * scale),
+                        )
+                        scaled_end = (
+                            int(end_offset[0] * scale),
+                            int(end_offset[1] * scale),
+                        )
+                        cropped[input_name] = crop_tensor_or_rasterimage(
+                            value, scaled_start, scaled_end
+                        )
                     elif isinstance(value, list):
                         cropped[input_name] = [
                             feat
@@ -348,6 +401,7 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
             round(self.patch_size[1] * overlap_ratio),
         )
         self.windows = self.dataset.get_dataset_examples()
+        self.inputs = dataset.inputs
         self.window_cache: dict[
             int, tuple[dict[str, Any], dict[str, Any], SampleMetadata]
         ] = {}
@@ -378,27 +432,41 @@
             return self.window_cache[index]

         raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(index)
-        pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size)
+        pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size, self.inputs)

         self.window_cache[index] = (raw_inputs, passthrough_inputs, metadata)
         return self.window_cache[index]

-    @staticmethod
     def _crop_input_dict(
+        self,
         d: dict[str, Any],
         start_offset: tuple[int, int],
         end_offset: tuple[int, int],
         cur_geom: STGeometry,
     ) -> dict[str, Any]:
-        """Crop a dictionary of inputs to the given bounds.
+        """Crop a dictionary of inputs to the given bounds.
+
+        Crop coordinates are scaled based on each input's resolution_factor.
+        """
         cropped = {}
         for input_name, value in d.items():
-            if isinstance(value, torch.Tensor):
-
-
-
-
-
+            if isinstance(value, torch.Tensor | RasterImage):
+                # Get resolution scale for this input
+                rf = self.inputs[input_name].resolution_factor
+                scale = rf.numerator / rf.denominator
+                # Scale the crop coordinates
+                scaled_start = (
+                    int(start_offset[0] * scale),
+                    int(start_offset[1] * scale),
+                )
+                scaled_end = (
+                    int(end_offset[0] * scale),
+                    int(end_offset[1] * scale),
+                )
+                cropped[input_name] = crop_tensor_or_rasterimage(
+                    value, scaled_start, scaled_end
+                )
+
             elif isinstance(value, list):
                 cropped[input_name] = [
                     feat for feat in value if cur_geom.intersects(feat.geometry)