rslearn-0.0.19-py3-none-any.whl → rslearn-0.0.20-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/models/anysat.py +35 -33
- rslearn/models/clip.py +5 -2
- rslearn/models/croma.py +11 -3
- rslearn/models/dinov3.py +2 -1
- rslearn/models/faster_rcnn.py +2 -1
- rslearn/models/galileo/galileo.py +58 -31
- rslearn/models/module_wrapper.py +6 -1
- rslearn/models/molmo.py +4 -2
- rslearn/models/olmoearth_pretrain/model.py +93 -29
- rslearn/models/olmoearth_pretrain/norm.py +5 -3
- rslearn/models/panopticon.py +3 -1
- rslearn/models/presto/presto.py +45 -15
- rslearn/models/prithvi.py +9 -7
- rslearn/models/sam2_enc.py +3 -1
- rslearn/models/satlaspretrain.py +4 -1
- rslearn/models/simple_time_series.py +36 -16
- rslearn/models/ssl4eo_s12.py +19 -14
- rslearn/models/swin.py +3 -1
- rslearn/models/terramind.py +5 -4
- rslearn/train/all_patches_dataset.py +34 -14
- rslearn/train/dataset.py +66 -10
- rslearn/train/model_context.py +35 -1
- rslearn/train/tasks/classification.py +8 -2
- rslearn/train/tasks/detection.py +3 -2
- rslearn/train/tasks/multi_task.py +2 -3
- rslearn/train/tasks/per_pixel_regression.py +14 -5
- rslearn/train/tasks/regression.py +8 -2
- rslearn/train/tasks/segmentation.py +13 -4
- rslearn/train/tasks/task.py +2 -2
- rslearn/train/transforms/concatenate.py +45 -5
- rslearn/train/transforms/crop.py +22 -8
- rslearn/train/transforms/flip.py +13 -5
- rslearn/train/transforms/mask.py +11 -2
- rslearn/train/transforms/normalize.py +46 -15
- rslearn/train/transforms/pad.py +15 -3
- rslearn/train/transforms/resize.py +18 -9
- rslearn/train/transforms/select_bands.py +11 -2
- rslearn/train/transforms/sentinel1.py +18 -3
- {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/METADATA +1 -1
- {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/RECORD +45 -45
- {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/WHEEL +0 -0
- {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.19.dist-info → rslearn-0.0.20.dist-info}/top_level.txt +0 -0
rslearn/models/presto/presto.py
CHANGED

@@ -2,6 +2,7 @@

 import logging
 import tempfile
+from datetime import datetime

 import torch
 from einops import rearrange, repeat
@@ -118,21 +119,21 @@ class Presto(FeatureExtractor):
             of each timestep for that pixel
         """
         bs = [x.shape[0] for x in [s1, s2, era5, srtm] if x is not None]
-        hs = [x.shape[2] for x in [s1, s2, era5, srtm] if x is not None]
-        ws = [x.shape[3] for x in [s1, s2, era5, srtm] if x is not None]
+        ts = [x.shape[2] for x in [s1, s2, era5, srtm] if x is not None]
+        hs = [x.shape[3] for x in [s1, s2, era5, srtm] if x is not None]
+        ws = [x.shape[4] for x in [s1, s2, era5, srtm] if x is not None]
         devices = [x.device for x in [s1, s2, era5, srtm] if x is not None]

         assert len(set(bs)) == 1
         assert len(set(hs)) == 1
         assert len(set(ws)) == 1
         assert len(set(devices)) == 1
-
-        b, h, w, device = bs[0], hs[0], ws[0], devices[0]
+        assert len(set(ts)) == 1
+        b, h, w, t, device = bs[0], hs[0], ws[0], ts[0], devices[0]
         # these values will be initialized as
         # we iterate through the data
         x: torch.Tensor | None = None
         mask: torch.Tensor | None = None
-        t: int | None = None

         for band_group in [
             (s1, s1_bands),
@@ -146,14 +147,7 @@ class Presto(FeatureExtractor):
             else:
                 continue

-
-            if t is None:
-                t = m_t
-            else:
-                if t != m_t:
-                    raise ValueError("inconsistent values for t")
-
-            data = rearrange(data, "b (t c) h w -> b t h w c", t=m_t)
+            data = rearrange(data, "b c t h w -> b t h w c")
             if x is None:
                 x = torch.zeros(b, t, h, w, len(INPUT_PRESTO_BANDS), device=device)
             if mask is None:
@@ -184,6 +178,23 @@ class Presto(FeatureExtractor):
         x = (x + PRESTO_ADD_BY.to(device=device)) / PRESTO_DIV_BY.to(device=device)
         return x, mask, dynamic_world.long(), months.long()

+    @staticmethod
+    def time_ranges_to_timestamps(
+        time_ranges: list[tuple[datetime, datetime]],
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Turn the time ranges stored in a RasterImage to timestamps accepted by Presto.
+
+        Presto only uses the month associated with each timestamp, so we take the midpoint
+        of the time range. For some inputs (e.g. Sentinel 2) we take an image from a specific
+        time so that start_time == end_time == mid_time.
+        """
+        mid_ranges = [t[0] + ((t[1] - t[0]) / 2) for t in time_ranges]
+        # months are indexed 0-11
+        return torch.tensor(
+            [d.month - 1 for d in mid_ranges], dtype=torch.int32, device=device
+        )
+
     def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Presto backbone.

@@ -194,17 +205,36 @@ class Presto(FeatureExtractor):
             a FeatureMaps with one feature map that is at the same resolution as the
             input (since Presto operates per-pixel).
         """
+        time_modalities = ["s1", "s2", "era5"]
         stacked_inputs = {}
         latlons: torch.Tensor | None = None
+        months: torch.Tensor | None = None
         for key in context.inputs[0].keys():
             # assume all the keys in an input are consistent
             if key in self.input_keys:
                 if key == "latlon":
-                    latlons = torch.stack([inp[key] for inp in context.inputs], dim=0)
+                    latlons = torch.stack(
+                        [inp[key].image for inp in context.inputs], dim=0
+                    )
                 else:
                     stacked_inputs[key] = torch.stack(
-                        [inp[key] for inp in context.inputs], dim=0
+                        [inp[key].image for inp in context.inputs], dim=0
                     )
+                    if key in time_modalities:
+                        if months is None:
+                            if context.inputs[0][key].timestamps is not None:
+                                months = torch.stack(
+                                    [
+                                        self.time_ranges_to_timestamps(
+                                            inp[key].timestamps,  # type: ignore
+                                            device=stacked_inputs[key].device,
+                                        )
+                                        for inp in context.inputs
+                                    ],
+                                    dim=0,
+                                )
+        if months is not None:
+            stacked_inputs["months"] = months

         (
             x,
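The new `time_ranges_to_timestamps` helper reduces each `(start, end)` range to the zero-indexed month of its midpoint, which is all Presto consumes. A minimal standalone sketch of the same computation, using hypothetical time ranges:

```python
from datetime import datetime

import torch

# Hypothetical per-timestep time ranges, as stored on RasterImage.timestamps.
time_ranges = [
    (datetime(2023, 1, 1), datetime(2023, 1, 31)),   # mosaic spanning January
    (datetime(2023, 6, 10), datetime(2023, 6, 10)),  # point in time: start == end
]

# Midpoint of each range; for point-in-time inputs this is the time itself.
mids = [start + (end - start) / 2 for start, end in time_ranges]

# Presto months are indexed 0-11.
months = torch.tensor([d.month - 1 for d in mids], dtype=torch.int32)
print(months)  # tensor([0, 5], dtype=torch.int32)
```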
rslearn/models/prithvi.py
CHANGED

@@ -144,13 +144,15 @@ class PrithviV2(FeatureExtractor):
         """Process individual modality data.

         Args:
-            data: Input tensor of shape [B, C, H, W]
+            data: Input tensor of shape [B, C, T, H, W]

         Returns:
-            list of tensors of shape [B, C, H, W]
+            list of tensors of shape [B, C, T, H, W]
         """
         # Get original dimensions
-        original_height = data.shape[2]
+        B, C, T, H, W = data.shape
+        data = rearrange(data, "b c t h w -> b (c t) h w")
+        original_height = H
         new_height = self.patch_size if original_height == 1 else self.image_resolution
         data = F.interpolate(
             data,
@@ -158,6 +160,7 @@ class PrithviV2(FeatureExtractor):
             mode="bilinear",
             align_corners=False,
         )
+        data = rearrange(data, "b (c t) h w -> b c t h w", c=C, t=T)
         return data

     def forward(self, context: ModelContext) -> FeatureMaps:
@@ -171,17 +174,16 @@ class PrithviV2(FeatureExtractor):
            a FeatureMaps with one map of shape [B, H/p_s, W/p_s, 11*1024] that contains stacked
            feature maps across the 11 transformer blocks.
         """
-        x = torch.stack([inp[self.INPUT_KEY] for inp in context.inputs], dim=0)
+        # x has shape BCTHW
+        x = torch.stack([inp[self.INPUT_KEY].image for inp in context.inputs], dim=0)
         x = self._resize_data(x)
-        num_timesteps = x.shape[1] // len(self.bands)
-        x = rearrange(x, "b (t c) h w -> b c t h w", t=num_timesteps)
         features = self.model.encoder.forward_features(x)
         # prepare_features_for_image_model was slightly modified since we already
         # know the number of timesteps and don't need to recompute it.
         # in addition we average along the time dimension (instead of concatenating)
         # to keep the embeddings reasonably sized.
         result = self.model.encoder.prepare_features_for_image_model(
-            features,
+            features, x.shape[2]
         )
         return FeatureMaps([torch.cat(result, dim=1)])

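`F.interpolate` with `mode="bilinear"` only accepts 4D NCHW input, which is why `_resize_data` now folds the time axis into channels before resizing and unfolds it afterwards. A minimal sketch of that round trip on a dummy tensor:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

x = torch.randn(2, 6, 4, 32, 32)  # B, C, T, H, W
B, C, T, H, W = x.shape

# Fold T into the channel axis so 4D bilinear interpolation applies.
x4d = rearrange(x, "b c t h w -> b (c t) h w")
x4d = F.interpolate(x4d, size=(224, 224), mode="bilinear", align_corners=False)

# Unfold back to B, C, T, H', W'.
x = rearrange(x4d, "b (c t) h w -> b c t h w", c=C, t=T)
print(x.shape)  # torch.Size([2, 6, 4, 224, 224])
```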
rslearn/models/sam2_enc.py
CHANGED

@@ -95,7 +95,9 @@ class SAM2Encoder(FeatureExtractor):
         Returns:
             feature maps from the encoder.
         """
-        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+        images = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
         features = self.encoder(images)
         return FeatureMaps(features)

rslearn/models/satlaspretrain.py
CHANGED

@@ -76,7 +76,10 @@ class SatlasPretrain(FeatureExtractor):
         Returns:
             multi-resolution feature maps computed by the model.
         """
-        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+        # take the first (assumed to be only) timestep
+        images = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
         feature_maps = self.model(self.maybe_resize(images))
         return FeatureMaps(feature_maps)

rslearn/models/simple_time_series.py
CHANGED

@@ -3,8 +3,9 @@
 from typing import Any

 import torch
+from einops import rearrange

-from rslearn.train.model_context import ModelContext
+from rslearn.train.model_context import ModelContext, RasterImage

 from .component import FeatureExtractor, FeatureMaps

@@ -163,23 +164,44 @@ class SimpleTimeSeries(FeatureExtractor):

     def _get_batched_images(
         self, input_dicts: list[dict[str, Any]], image_key: str, image_channels: int
-    ) -> torch.Tensor:
+    ) -> list[RasterImage]:
         """Collect and reshape images across input dicts.

         The BTCHW image time series are reshaped to (B*T)CHW so they can be passed to
         the forward pass of a per-image (unitemporal) model.
         """
         images = torch.stack(
-            [input_dict[image_key] for input_dict in input_dicts], dim=0
+            [input_dict[image_key].image for input_dict in input_dicts], dim=0
+        )  # B, C, T, H, W
+        timestamps = [input_dict[image_key].timestamps for input_dict in input_dicts]
+        # if image_channels is not equal to the actual number of channels,
+        # then every N images should be batched together. For example, if the
+        # number of input channels c == 2, and image_channels == 4, then we
+        # want to pass 2 timesteps to the model.
+        # TODO: probably make this behaviour clearer, but leave it like
+        # this for now to not break things.
+        num_timesteps = images.shape[1] // image_channels
+        batched_timesteps = images.shape[2] // num_timesteps
+        images = rearrange(
+            images,
+            "b c (b_t k_t) h w -> (b b_t) c k_t h w",
+            b_t=batched_timesteps,
+            k_t=num_timesteps,
         )
-
-
-
-
-
-
-
-
+        if timestamps[0] is None:
+            new_timestamps = [None] * images.shape[0]
+        else:
+            # we also need to split the timestamps
+            new_timestamps = []
+            for t in timestamps:
+                for i in range(batched_timesteps):
+                    new_timestamps.append(
+                        t[i * num_timesteps : (i + 1) * num_timesteps]
+                    )
+        return [
+            RasterImage(image=image, timestamps=timestamps)
+            for image, timestamps in zip(images, new_timestamps)
+        ]  # C, T, H, W

     def forward(
         self,
@@ -208,8 +230,8 @@ class SimpleTimeSeries(FeatureExtractor):

             if batched_inputs is None:
                 batched_inputs = [{} for _ in batched_images]
-                n_images = batched_images.shape[0] // n_batch
-            elif n_images != batched_images.shape[0] // n_batch:
+                n_images = len(batched_images) // n_batch
+            elif n_images != len(batched_images) // n_batch:
                 raise ValueError(
                     "expected all modalities to have the same number of timesteps"
                 )
@@ -223,10 +245,9 @@ class SimpleTimeSeries(FeatureExtractor):
                 context.inputs, self.image_key, self.image_channels
             )
             batched_inputs = [{self.image_key: image} for image in batched_images]
-            n_images = batched_images.shape[0] // n_batch
+            n_images = len(batched_images) // n_batch

         assert n_images is not None
-
         # Now we can apply the underlying FeatureExtractor.
         # Its output must be a FeatureMaps.
         assert batched_inputs is not None
@@ -250,7 +271,6 @@ class SimpleTimeSeries(FeatureExtractor):
             )
             for feat_map in encoder_output.feature_maps
         ]
-
         # Groups defaults to flattening all the feature maps.
         groups = self.groups
         if not groups:
rslearn/models/ssl4eo_s12.py
CHANGED

@@ -13,7 +13,7 @@ class Ssl4eoS12(FeatureExtractor):

     def __init__(
         self,
-        backbone_ckpt_path: str,
+        backbone_ckpt_path: str | None,
         arch: str = "resnet50",
         output_layers: list[int] = [0, 1, 2, 3],
     ) -> None:
@@ -39,19 +39,22 @@ class Ssl4eoS12(FeatureExtractor):
         else:
             raise ValueError(f"unknown SSL4EO-S12 architecture {arch}")

-
-
-
-
-
-
-
-
-
-
-
-            f"warning: got missing_keys={missing_keys}, unexpected_keys={unexpected_keys} when loading SSL4EO-S12 state dict"
+        if backbone_ckpt_path is not None:
+            state_dict = torch.load(backbone_ckpt_path, weights_only=True)
+            state_dict = state_dict["teacher"]
+            prefix = "module.backbone."
+            state_dict = {
+                k[len(prefix) :]: v
+                for k, v in state_dict.items()
+                if k.startswith(prefix)
+            }
+            missing_keys, unexpected_keys = self.model.load_state_dict(
+                state_dict, strict=False
             )
+            if missing_keys or unexpected_keys:
+                print(
+                    f"warning: got missing_keys={missing_keys}, unexpected_keys={unexpected_keys} when loading SSL4EO-S12 state dict"
+                )

     def get_backbone_channels(self) -> list[tuple[int, int]]:
         """Returns the output channels of this model when used as a backbone.
@@ -91,7 +94,9 @@ class Ssl4eoS12(FeatureExtractor):
         Returns:
             feature maps computed by the pre-trained model.
         """
-        x = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+        x = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
         x = self.model.conv1(x)
         x = self.model.bn1(x)
         x = self.model.relu(x)
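Checkpoint loading now keeps only the `teacher` weights under the `module.backbone.` prefix before calling `load_state_dict(..., strict=False)`. A standalone sketch of that key filtering, on a fabricated state dict:

```python
import torch

# Fabricated checkpoint mirroring the SSL4EO-S12 layout.
ckpt = {
    "teacher": {
        "module.backbone.conv1.weight": torch.zeros(1),
        "module.backbone.bn1.weight": torch.zeros(1),
        "module.head.weight": torch.zeros(1),  # dropped: not under the prefix
    }
}

prefix = "module.backbone."
state_dict = {
    k[len(prefix) :]: v
    for k, v in ckpt["teacher"].items()
    if k.startswith(prefix)
}
print(sorted(state_dict))  # ['bn1.weight', 'conv1.weight']
```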
rslearn/models/swin.py
CHANGED

@@ -151,7 +151,9 @@ class Swin(FeatureExtractor):
            a FeatureVector if the configured output_layers is None, or a FeatureMaps
            otherwise containing one feature map per configured output layer.
         """
-        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
+        images = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )

         if self.output_layers:
             layer_features = []
rslearn/models/terramind.py
CHANGED

@@ -143,7 +143,8 @@ class Terramind(FeatureExtractor):
             if modality not in context.inputs[0]:
                 continue
             cur = torch.stack(
-                [inp[modality] for inp in context.inputs],
+                [inp[modality].single_ts_to_chw_tensor() for inp in context.inputs],
+                dim=0,
             )  # (B, C, H, W)
             if self.do_resizing and (
                 cur.shape[2] != IMAGE_SIZE or cur.shape[3] != IMAGE_SIZE
@@ -219,7 +220,7 @@ class TerramindNormalize(Transform):
         Returns:
             The normalized image.
         """
-        images = image.float()  # (C, H, W)
+        images = image.float()  # (C, 1, H, W)
         if images.shape[0] % len(means) != 0:
             raise ValueError(
                 f"the number of image channels {images.shape[0]} is not multiple of expected number of bands {len(means)}"
@@ -247,8 +248,8 @@ class TerramindNormalize(Transform):
             band_info = PRETRAINED_BANDS[modality]
             means = [band_info[band][0] for band in band_info]
             stds = [band_info[band][1] for band in band_info]
-            input_dict[modality] = self.apply_image(
-                input_dict[modality],
+            input_dict[modality].image = self.apply_image(
+                input_dict[modality].image,
                 means,
                 stds,
             )
rslearn/train/all_patches_dataset.py
CHANGED

@@ -10,7 +10,7 @@ import torch

 from rslearn.dataset import Window
 from rslearn.train.dataset import DataInput, ModelDataset
-from rslearn.train.model_context import SampleMetadata
+from rslearn.train.model_context import RasterImage, SampleMetadata
 from rslearn.utils.geometry import PixelBounds, STGeometry


@@ -99,6 +99,30 @@ def pad_slice_protect(
     return raw_inputs, passthrough_inputs


+def crop_tensor_or_rasterimage(
+    x: torch.Tensor | RasterImage, start: tuple[int, int], end: tuple[int, int]
+) -> torch.Tensor | RasterImage:
+    """Crop a tensor or a RasterImage."""
+    if isinstance(x, torch.Tensor):
+        # Crop the CHW tensor with scaled coordinates.
+        return x[
+            :,
+            start[1] : end[1],
+            start[0] : end[0],
+        ].clone()
+    else:
+        # Crop the CTHW tensor with scaled coordinates.
+        return RasterImage(
+            x.image[
+                :,
+                :,
+                start[1] : end[1],
+                start[0] : end[0],
+            ].clone(),
+            x.timestamps,
+        )
+
+
 class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
     """This wraps a ModelDataset to iterate over all patches in that dataset.

@@ -281,7 +305,7 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
         def crop_input_dict(d: dict[str, Any]) -> dict[str, Any]:
             cropped = {}
             for input_name, value in d.items():
-                if isinstance(value, torch.Tensor):
+                if isinstance(value, torch.Tensor | RasterImage):
                     # Get resolution scale for this input
                     rf = self.inputs[input_name].resolution_factor
                     scale = rf.numerator / rf.denominator
@@ -294,12 +318,9 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
                         int(end_offset[0] * scale),
                         int(end_offset[1] * scale),
                     )
-                    # Crop the CHW tensor with scaled coordinates.
-                    cropped[input_name] = value[
-                        :,
-                        scaled_start[1] : scaled_end[1],
-                        scaled_start[0] : scaled_end[0],
-                    ].clone()
+                    cropped[input_name] = crop_tensor_or_rasterimage(
+                        value, scaled_start, scaled_end
+                    )
                 elif isinstance(value, list):
                     cropped[input_name] = [
                         feat
@@ -429,7 +450,7 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
         """
         cropped = {}
         for input_name, value in d.items():
-            if isinstance(value, torch.Tensor):
+            if isinstance(value, torch.Tensor | RasterImage):
                # Get resolution scale for this input
                 rf = self.inputs[input_name].resolution_factor
                 scale = rf.numerator / rf.denominator
@@ -442,11 +463,10 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
                     int(end_offset[0] * scale),
                     int(end_offset[1] * scale),
                 )
-                cropped[input_name] = value[
-                    :,
-                    scaled_start[1] : scaled_end[1],
-                    scaled_start[0] : scaled_end[0],
-                ].clone()
+                cropped[input_name] = crop_tensor_or_rasterimage(
+                    value, scaled_start, scaled_end
+                )
+
             elif isinstance(value, list):
                 cropped[input_name] = [
                     feat for feat in value if cur_geom.intersects(feat.geometry)
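`crop_tensor_or_rasterimage` lets both dataset classes crop plain CHW tensors and CTHW `RasterImage`s with the same `(x, y)` window, carrying timestamps through unchanged. A quick sketch of calling it (assuming the imports resolve against this release):

```python
import torch

from rslearn.train.all_patches_dataset import crop_tensor_or_rasterimage
from rslearn.train.model_context import RasterImage

# Plain CHW tensor: start=(x0, y0), end=(x1, y1).
chw = torch.randn(3, 64, 64)
print(crop_tensor_or_rasterimage(chw, (8, 16), (40, 48)).shape)  # torch.Size([3, 32, 32])

# CTHW RasterImage: the T axis and timestamps pass through untouched.
raster = RasterImage(torch.randn(3, 5, 64, 64))
print(crop_tensor_or_rasterimage(raster, (8, 16), (40, 48)).shape)  # torch.Size([3, 5, 32, 32])
```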
rslearn/train/dataset.py
CHANGED

@@ -8,6 +8,7 @@ import random
 import tempfile
 import time
 import uuid
+from datetime import datetime
 from typing import Any

 import torch
@@ -19,10 +20,16 @@ from rslearn.config import (
     DType,
     LayerConfig,
 )
+from rslearn.data_sources.data_source import Item
 from rslearn.dataset.dataset import Dataset
 from rslearn.dataset.storage.file import FileWindowStorage
-from rslearn.dataset.window import Window, get_layer_and_group_from_dir_name
+from rslearn.dataset.window import (
+    Window,
+    WindowLayerData,
+    get_layer_and_group_from_dir_name,
+)
 from rslearn.log_utils import get_logger
+from rslearn.train.model_context import RasterImage
 from rslearn.utils.feature import Feature
 from rslearn.utils.geometry import PixelBounds, ResolutionFactor
 from rslearn.utils.mp import star_imap_unordered
@@ -198,7 +205,8 @@ def read_raster_layer_for_data_input(
     group_idx: int,
     layer_config: LayerConfig,
     data_input: DataInput,
-) -> torch.Tensor:
+    layer_data: WindowLayerData | None,
+) -> tuple[torch.Tensor, tuple[datetime, datetime] | None]:
     """Read a raster layer for a DataInput.

     This scans the available rasters for the layer at the window to determine which
@@ -211,9 +219,11 @@ def read_raster_layer_for_data_input(
         group_idx: the item group.
         layer_config: the layer configuration.
         data_input: the DataInput that specifies the bands and dtype.
+        layer_data: the WindowLayerData associated with this layer and window.

     Returns:
-
+        RasterImage containing raster data and the timestamp associated
+        with that data.
     """
     # See what different sets of bands we need to read to get all the
     # configured bands.
@@ -284,7 +294,34 @@ def read_raster_layer_for_data_input(
             src[src_indexes, :, :].astype(data_input.dtype.get_numpy_dtype())
         )

-    return image
+    # add the timestamp. this is a tuple defining the start and end of the time range.
+    time_range = None
+    if layer_data is not None:
+        item = Item.deserialize(layer_data.serialized_item_groups[group_idx][0])
+        if item.geometry.time_range is not None:
+            # we assume if one layer data has a geometry & time range, all of them do
+            time_ranges = [
+                (
+                    datetime.fromisoformat(
+                        Item.deserialize(
+                            layer_data.serialized_item_groups[group_idx][idx]
+                        ).geometry.time_range[0]  # type: ignore
+                    ),
+                    datetime.fromisoformat(
+                        Item.deserialize(
+                            layer_data.serialized_item_groups[group_idx][idx]
+                        ).geometry.time_range[1]  # type: ignore
+                    ),
+                )
+                for idx in range(len(layer_data.serialized_item_groups[group_idx]))
+            ]
+            # take the min and max
+            time_range = (
+                min([t[0] for t in time_ranges]),
+                max([t[1] for t in time_ranges]),
+            )
+
+    return image, time_range


 def read_data_input(
@@ -293,7 +330,7 @@ def read_data_input(
     bounds: PixelBounds,
     data_input: DataInput,
     rng: random.Random,
-) -> torch.Tensor | list[Feature]:
+) -> RasterImage | list[Feature]:
     """Read the data specified by the DataInput from the window.

     Args:
@@ -335,15 +372,34 @@ def read_data_input(
         layers_to_read = [rng.choice(layer_options)]

     if data_input.data_type == "raster":
+        # load it once here
+        layer_datas = window.load_layer_datas()
         images: list[torch.Tensor] = []
+        time_ranges: list[tuple[datetime, datetime] | None] = []
         for layer_name, group_idx in layers_to_read:
             layer_config = dataset.layers[layer_name]
-
-
-
-
+            image, time_range = read_raster_layer_for_data_input(
+                window,
+                bounds,
+                layer_name,
+                group_idx,
+                layer_config,
+                data_input,
+                # some layers (e.g. "label_raster") won't have associated
+                # layer datas
+                layer_datas[layer_name] if layer_name in layer_datas else None,
             )
-
+            if len(time_ranges) > 0:
+                if type(time_ranges[-1]) is not type(time_range):
+                    raise ValueError(
+                        f"All time ranges should be datetime tuples or None. Got {type(time_range)} and {type(time_ranges[-1])}"
+                    )
+            images.append(image)
+            time_ranges.append(time_range)
+        return RasterImage(
+            torch.stack(images, dim=1),
+            time_ranges if time_ranges[0] is not None else None,  # type: ignore
+        )

     elif data_input.data_type == "vector":
         # We don't really support time series for vector data currently, we just
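Because a mosaic group can contain several items, the layer's time range is taken as the envelope of the per-item ranges (minimum start, maximum end). A reduced sketch of that reduction with hypothetical ranges:

```python
from datetime import datetime

item_ranges = [
    (datetime(2023, 3, 1), datetime(2023, 3, 5)),
    (datetime(2023, 2, 20), datetime(2023, 3, 2)),
]

# Envelope: earliest start, latest end across all items in the group.
time_range = (
    min(start for start, _ in item_ranges),
    max(end for _, end in item_ranges),
)
print(time_range)
# (datetime.datetime(2023, 2, 20, 0, 0), datetime.datetime(2023, 3, 5, 0, 0))
```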
rslearn/train/model_context.py
CHANGED

@@ -10,6 +10,40 @@ import torch
 from rslearn.utils.geometry import PixelBounds, Projection


+@dataclass
+class RasterImage:
+    """A raster image is a torch.tensor containing the images and their associated timestamps."""
+
+    # image is a 4D CTHW tensor
+    image: torch.Tensor
+    # if timestamps is not None, len(timestamps) must match the T dimension of the tensor
+    timestamps: list[tuple[datetime, datetime]] | None = None
+
+    @property
+    def shape(self) -> torch.Size:
+        """The shape of the image."""
+        return self.image.shape
+
+    def dim(self) -> int:
+        """The dim of the image."""
+        return self.image.dim()
+
+    @property
+    def dtype(self) -> torch.dtype:
+        """The image dtype."""
+        return self.image.dtype
+
+    def single_ts_to_chw_tensor(self) -> torch.Tensor:
+        """Single timestep models expect single timestep inputs.
+
+        This function (1) checks this raster image only has 1 timestep and
+        (2) returns the tensor for that (single) timestep (going from CTHW to CHW).
+        """
+        if self.image.shape[1] != 1:
+            raise ValueError(f"Expected a single timestep, got {self.image.shape[1]}")
+        return self.image[:, 0]
+
+
 @dataclass
 class SampleMetadata:
     """Metadata pertaining to an example."""
@@ -32,7 +66,7 @@ class ModelContext:
     """Context to pass to all model components."""

     # One input dict per example in the batch.
-    inputs: list[dict[str, torch.Tensor]]
+    inputs: list[dict[str, torch.Tensor | RasterImage]]
     # One SampleMetadata per example in the batch.
     metadatas: list[SampleMetadata]
     # Arbitrary dict that components can add to.