rslearn 0.0.18__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +2 -9
- rslearn/config/dataset.py +15 -16
- rslearn/dataset/dataset.py +28 -22
- rslearn/lightning_cli.py +22 -11
- rslearn/main.py +1 -1
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/component.py +12 -0
- rslearn/models/olmoearth_pretrain/model.py +125 -34
- rslearn/models/simple_time_series.py +7 -1
- rslearn/train/all_patches_dataset.py +67 -19
- rslearn/train/dataset.py +36 -43
- rslearn/train/scheduler.py +15 -0
- rslearn/train/transforms/resize.py +74 -0
- rslearn/utils/geometry.py +73 -0
- rslearn/utils/jsonargparse.py +66 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
- {rslearn-0.0.18.dist-info → rslearn-0.0.19.dist-info}/RECORD +22 -20
- {rslearn-0.0.18.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.18.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
|
@@ -229,7 +229,13 @@ class SimpleTimeSeries(FeatureExtractor):
|
|
|
229
229
|
|
|
230
230
|
# Now we can apply the underlying FeatureExtractor.
|
|
231
231
|
# Its output must be a FeatureMaps.
|
|
232
|
-
|
|
232
|
+
assert batched_inputs is not None
|
|
233
|
+
encoder_output = self.encoder(
|
|
234
|
+
ModelContext(
|
|
235
|
+
inputs=batched_inputs,
|
|
236
|
+
metadatas=context.metadatas,
|
|
237
|
+
)
|
|
238
|
+
)
|
|
233
239
|
if not isinstance(encoder_output, FeatureMaps):
|
|
234
240
|
raise ValueError(
|
|
235
241
|
"output of underlying FeatureExtractor in SimpleTimeSeries must be a FeatureMaps"
|
|
@@ -9,7 +9,7 @@ import shapely
|
|
|
9
9
|
import torch
|
|
10
10
|
|
|
11
11
|
from rslearn.dataset import Window
|
|
12
|
-
from rslearn.train.dataset import ModelDataset
|
|
12
|
+
from rslearn.train.dataset import DataInput, ModelDataset
|
|
13
13
|
from rslearn.train.model_context import SampleMetadata
|
|
14
14
|
from rslearn.utils.geometry import PixelBounds, STGeometry
|
|
15
15
|
|
|
@@ -34,22 +34,28 @@ def get_window_patch_options(
|
|
|
34
34
|
bottommost patches may extend beyond the provided bounds.
|
|
35
35
|
"""
|
|
36
36
|
# We stride the patches by patch_size - overlap_size until the last patch.
|
|
37
|
+
# We handle the first patch with a special case to ensure it is always used.
|
|
37
38
|
# We handle the last patch with a special case to ensure it does not exceed the
|
|
38
39
|
# window bounds. Instead, it may overlap the previous patch.
|
|
39
|
-
cols = list(
|
|
40
|
+
cols = [bounds[0]] + list(
|
|
40
41
|
range(
|
|
41
|
-
bounds[0],
|
|
42
|
+
bounds[0] + patch_size[0],
|
|
42
43
|
bounds[2] - patch_size[0],
|
|
43
44
|
patch_size[0] - overlap_size[0],
|
|
44
45
|
)
|
|
45
|
-
)
|
|
46
|
-
rows = list(
|
|
46
|
+
)
|
|
47
|
+
rows = [bounds[1]] + list(
|
|
47
48
|
range(
|
|
48
|
-
bounds[1],
|
|
49
|
+
bounds[1] + patch_size[1],
|
|
49
50
|
bounds[3] - patch_size[1],
|
|
50
51
|
patch_size[1] - overlap_size[1],
|
|
51
52
|
)
|
|
52
|
-
)
|
|
53
|
+
)
|
|
54
|
+
# Add last patches only if the input is larger than one patch.
|
|
55
|
+
if bounds[2] - patch_size[0] > bounds[0]:
|
|
56
|
+
cols.append(bounds[2] - patch_size[0])
|
|
57
|
+
if bounds[3] - patch_size[1] > bounds[1]:
|
|
58
|
+
rows.append(bounds[3] - patch_size[1])
|
|
53
59
|
|
|
54
60
|
patch_bounds: list[PixelBounds] = []
|
|
55
61
|
for col in cols:
|
|
@@ -62,13 +68,17 @@ def pad_slice_protect(
|
|
|
62
68
|
raw_inputs: dict[str, Any],
|
|
63
69
|
passthrough_inputs: dict[str, Any],
|
|
64
70
|
patch_size: tuple[int, int],
|
|
71
|
+
inputs: dict[str, DataInput],
|
|
65
72
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
66
73
|
"""Pad tensors in-place by patch size to protect slicing near right/bottom edges.
|
|
67
74
|
|
|
75
|
+
The padding is scaled based on each input's resolution_factor.
|
|
76
|
+
|
|
68
77
|
Args:
|
|
69
78
|
raw_inputs: the raw inputs to pad.
|
|
70
79
|
passthrough_inputs: the passthrough inputs to pad.
|
|
71
|
-
patch_size: the size of the patches to extract.
|
|
80
|
+
patch_size: the size of the patches to extract (at window resolution).
|
|
81
|
+
inputs: the DataInput definitions, used to get resolution_factor per input.
|
|
72
82
|
|
|
73
83
|
Returns:
|
|
74
84
|
a tuple of (raw_inputs, passthrough_inputs).
|
|
@@ -77,8 +87,14 @@ def pad_slice_protect(
|
|
|
77
87
|
for input_name, value in list(d.items()):
|
|
78
88
|
if not isinstance(value, torch.Tensor):
|
|
79
89
|
continue
|
|
90
|
+
# Get resolution scale for this input
|
|
91
|
+
rf = inputs[input_name].resolution_factor
|
|
92
|
+
scale = rf.numerator / rf.denominator
|
|
93
|
+
# Scale the padding amount
|
|
94
|
+
scaled_pad_x = int(patch_size[0] * scale)
|
|
95
|
+
scaled_pad_y = int(patch_size[1] * scale)
|
|
80
96
|
d[input_name] = torch.nn.functional.pad(
|
|
81
|
-
value, pad=(0,
|
|
97
|
+
value, pad=(0, scaled_pad_x, 0, scaled_pad_y)
|
|
82
98
|
)
|
|
83
99
|
return raw_inputs, passthrough_inputs
|
|
84
100
|
|
|
@@ -123,6 +139,7 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
|
|
|
123
139
|
self.rank = rank
|
|
124
140
|
self.world_size = world_size
|
|
125
141
|
self.windows = self.dataset.get_dataset_examples()
|
|
142
|
+
self.inputs = dataset.inputs
|
|
126
143
|
|
|
127
144
|
def set_name(self, name: str) -> None:
|
|
128
145
|
"""Sets dataset name.
|
|
@@ -235,8 +252,10 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
|
|
|
235
252
|
|
|
236
253
|
# For simplicity, pad tensors by patch size to ensure that any patch bounds
|
|
237
254
|
# extending outside the window bounds will not have issues when we slice
|
|
238
|
-
# the tensors later.
|
|
239
|
-
pad_slice_protect(
|
|
255
|
+
# the tensors later. Padding is scaled per-input based on resolution_factor.
|
|
256
|
+
pad_slice_protect(
|
|
257
|
+
raw_inputs, passthrough_inputs, self.patch_size, self.inputs
|
|
258
|
+
)
|
|
240
259
|
|
|
241
260
|
# Now iterate over the patches and extract/yield the crops.
|
|
242
261
|
# Note that, in case user is leveraging RslearnWriter, it is important that
|
|
@@ -258,15 +277,28 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
|
|
|
258
277
|
)
|
|
259
278
|
|
|
260
279
|
# Define a helper function to handle each input dict.
|
|
280
|
+
# Crop coordinates are scaled based on each input's resolution_factor.
|
|
261
281
|
def crop_input_dict(d: dict[str, Any]) -> dict[str, Any]:
|
|
262
282
|
cropped = {}
|
|
263
283
|
for input_name, value in d.items():
|
|
264
284
|
if isinstance(value, torch.Tensor):
|
|
265
|
-
#
|
|
285
|
+
# Get resolution scale for this input
|
|
286
|
+
rf = self.inputs[input_name].resolution_factor
|
|
287
|
+
scale = rf.numerator / rf.denominator
|
|
288
|
+
# Scale the crop coordinates
|
|
289
|
+
scaled_start = (
|
|
290
|
+
int(start_offset[0] * scale),
|
|
291
|
+
int(start_offset[1] * scale),
|
|
292
|
+
)
|
|
293
|
+
scaled_end = (
|
|
294
|
+
int(end_offset[0] * scale),
|
|
295
|
+
int(end_offset[1] * scale),
|
|
296
|
+
)
|
|
297
|
+
# Crop the CHW tensor with scaled coordinates.
|
|
266
298
|
cropped[input_name] = value[
|
|
267
299
|
:,
|
|
268
|
-
|
|
269
|
-
|
|
300
|
+
scaled_start[1] : scaled_end[1],
|
|
301
|
+
scaled_start[0] : scaled_end[0],
|
|
270
302
|
].clone()
|
|
271
303
|
elif isinstance(value, list):
|
|
272
304
|
cropped[input_name] = [
|
|
@@ -348,6 +380,7 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
|
|
|
348
380
|
round(self.patch_size[1] * overlap_ratio),
|
|
349
381
|
)
|
|
350
382
|
self.windows = self.dataset.get_dataset_examples()
|
|
383
|
+
self.inputs = dataset.inputs
|
|
351
384
|
self.window_cache: dict[
|
|
352
385
|
int, tuple[dict[str, Any], dict[str, Any], SampleMetadata]
|
|
353
386
|
] = {}
|
|
@@ -378,26 +411,41 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
|
|
|
378
411
|
return self.window_cache[index]
|
|
379
412
|
|
|
380
413
|
raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(index)
|
|
381
|
-
pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size)
|
|
414
|
+
pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size, self.inputs)
|
|
382
415
|
|
|
383
416
|
self.window_cache[index] = (raw_inputs, passthrough_inputs, metadata)
|
|
384
417
|
return self.window_cache[index]
|
|
385
418
|
|
|
386
|
-
@staticmethod
|
|
387
419
|
def _crop_input_dict(
|
|
420
|
+
self,
|
|
388
421
|
d: dict[str, Any],
|
|
389
422
|
start_offset: tuple[int, int],
|
|
390
423
|
end_offset: tuple[int, int],
|
|
391
424
|
cur_geom: STGeometry,
|
|
392
425
|
) -> dict[str, Any]:
|
|
393
|
-
"""Crop a dictionary of inputs to the given bounds.
|
|
426
|
+
"""Crop a dictionary of inputs to the given bounds.
|
|
427
|
+
|
|
428
|
+
Crop coordinates are scaled based on each input's resolution_factor.
|
|
429
|
+
"""
|
|
394
430
|
cropped = {}
|
|
395
431
|
for input_name, value in d.items():
|
|
396
432
|
if isinstance(value, torch.Tensor):
|
|
433
|
+
# Get resolution scale for this input
|
|
434
|
+
rf = self.inputs[input_name].resolution_factor
|
|
435
|
+
scale = rf.numerator / rf.denominator
|
|
436
|
+
# Scale the crop coordinates
|
|
437
|
+
scaled_start = (
|
|
438
|
+
int(start_offset[0] * scale),
|
|
439
|
+
int(start_offset[1] * scale),
|
|
440
|
+
)
|
|
441
|
+
scaled_end = (
|
|
442
|
+
int(end_offset[0] * scale),
|
|
443
|
+
int(end_offset[1] * scale),
|
|
444
|
+
)
|
|
397
445
|
cropped[input_name] = value[
|
|
398
446
|
:,
|
|
399
|
-
|
|
400
|
-
|
|
447
|
+
scaled_start[1] : scaled_end[1],
|
|
448
|
+
scaled_start[0] : scaled_end[0],
|
|
401
449
|
].clone()
|
|
402
450
|
elif isinstance(value, list):
|
|
403
451
|
cropped[input_name] = [
|
rslearn/train/dataset.py
CHANGED
|
@@ -24,7 +24,7 @@ from rslearn.dataset.storage.file import FileWindowStorage
|
|
|
24
24
|
from rslearn.dataset.window import Window, get_layer_and_group_from_dir_name
|
|
25
25
|
from rslearn.log_utils import get_logger
|
|
26
26
|
from rslearn.utils.feature import Feature
|
|
27
|
-
from rslearn.utils.geometry import PixelBounds
|
|
27
|
+
from rslearn.utils.geometry import PixelBounds, ResolutionFactor
|
|
28
28
|
from rslearn.utils.mp import star_imap_unordered
|
|
29
29
|
|
|
30
30
|
from .model_context import SampleMetadata
|
|
@@ -130,6 +130,10 @@ class DataInput:
|
|
|
130
130
|
"""Specification of a piece of data from a window that is needed for training.
|
|
131
131
|
|
|
132
132
|
The DataInput includes which layer(s) the data can be obtained from for each window.
|
|
133
|
+
|
|
134
|
+
Note that this class is not a dataclass because jsonargparse does not play well
|
|
135
|
+
with dataclasses without enabling specialized options which we have not validated
|
|
136
|
+
will work with the rest of our code.
|
|
133
137
|
"""
|
|
134
138
|
|
|
135
139
|
def __init__(
|
|
@@ -143,7 +147,9 @@ class DataInput:
|
|
|
143
147
|
dtype: DType = DType.FLOAT32,
|
|
144
148
|
load_all_layers: bool = False,
|
|
145
149
|
load_all_item_groups: bool = False,
|
|
146
|
-
|
|
150
|
+
resolution_factor: ResolutionFactor = ResolutionFactor(),
|
|
151
|
+
resampling: Resampling = Resampling.nearest,
|
|
152
|
+
):
|
|
147
153
|
"""Initialize a new DataInput.
|
|
148
154
|
|
|
149
155
|
Args:
|
|
@@ -166,6 +172,11 @@ class DataInput:
|
|
|
166
172
|
are reading from. By default, we assume the specified layer name is of
|
|
167
173
|
the form "{layer_name}.{group_idx}" and read that item group only. With
|
|
168
174
|
this option enabled, we ignore the group_idx and read all item groups.
|
|
175
|
+
resolution_factor: controls the resolution at which raster data is loaded for training.
|
|
176
|
+
By default (factor=1), data is loaded at the window resolution.
|
|
177
|
+
E.g. for a 64x64 window at 10 m/pixel with resolution_factor=1/2,
|
|
178
|
+
the resulting tensor is 32x32 (covering the same geographic area at 20 m/pixel).
|
|
179
|
+
resampling: resampling method (default nearest neighbor).
|
|
169
180
|
"""
|
|
170
181
|
self.data_type = data_type
|
|
171
182
|
self.layers = layers
|
|
@@ -176,6 +187,8 @@ class DataInput:
|
|
|
176
187
|
self.dtype = dtype
|
|
177
188
|
self.load_all_layers = load_all_layers
|
|
178
189
|
self.load_all_item_groups = load_all_item_groups
|
|
190
|
+
self.resolution_factor = resolution_factor
|
|
191
|
+
self.resampling = resampling
|
|
179
192
|
|
|
180
193
|
|
|
181
194
|
def read_raster_layer_for_data_input(
|
|
@@ -233,15 +246,23 @@ def read_raster_layer_for_data_input(
|
|
|
233
246
|
+ f"window {window.name} layer {layer_name} group {group_idx}"
|
|
234
247
|
)
|
|
235
248
|
|
|
249
|
+
# Get the projection and bounds to read under (multiply window resolution by
|
|
250
|
+
# the specified resolution factor).
|
|
251
|
+
final_projection = data_input.resolution_factor.multiply_projection(
|
|
252
|
+
window.projection
|
|
253
|
+
)
|
|
254
|
+
final_bounds = data_input.resolution_factor.multiply_bounds(bounds)
|
|
255
|
+
|
|
236
256
|
image = torch.zeros(
|
|
237
|
-
(
|
|
257
|
+
(
|
|
258
|
+
len(needed_bands),
|
|
259
|
+
final_bounds[3] - final_bounds[1],
|
|
260
|
+
final_bounds[2] - final_bounds[0],
|
|
261
|
+
),
|
|
238
262
|
dtype=get_torch_dtype(data_input.dtype),
|
|
239
263
|
)
|
|
240
264
|
|
|
241
265
|
for band_set, src_indexes, dst_indexes in needed_sets_and_indexes:
|
|
242
|
-
final_projection, final_bounds = band_set.get_final_projection_and_bounds(
|
|
243
|
-
window.projection, bounds
|
|
244
|
-
)
|
|
245
266
|
if band_set.format is None:
|
|
246
267
|
raise ValueError(f"No format specified for {layer_name}")
|
|
247
268
|
raster_format = band_set.instantiate_raster_format()
|
|
@@ -249,44 +270,16 @@ def read_raster_layer_for_data_input(
|
|
|
249
270
|
layer_name, band_set.bands, group_idx=group_idx
|
|
250
271
|
)
|
|
251
272
|
|
|
252
|
-
#
|
|
253
|
-
#
|
|
254
|
-
#
|
|
255
|
-
#
|
|
256
|
-
#
|
|
257
|
-
#
|
|
258
|
-
is_bounds_zoomable = True
|
|
259
|
-
if band_set.zoom_offset < 0:
|
|
260
|
-
zoom_factor = 2 ** (-band_set.zoom_offset)
|
|
261
|
-
is_bounds_zoomable = (final_bounds[2] - final_bounds[0]) * zoom_factor == (
|
|
262
|
-
bounds[2] - bounds[0]
|
|
263
|
-
) and (final_bounds[3] - final_bounds[1]) * zoom_factor == (
|
|
264
|
-
bounds[3] - bounds[1]
|
|
265
|
-
)
|
|
266
|
-
|
|
267
|
-
if is_bounds_zoomable:
|
|
268
|
-
src = raster_format.decode_raster(
|
|
269
|
-
raster_dir, final_projection, final_bounds
|
|
270
|
-
)
|
|
271
|
-
|
|
272
|
-
# Resize to patch size if needed.
|
|
273
|
-
# This is for band sets that are stored at a lower resolution.
|
|
274
|
-
# Here we assume that it is a multiple.
|
|
275
|
-
if src.shape[1:3] != image.shape[1:3]:
|
|
276
|
-
if src.shape[1] < image.shape[1]:
|
|
277
|
-
factor = image.shape[1] // src.shape[1]
|
|
278
|
-
src = src.repeat(repeats=factor, axis=1).repeat(
|
|
279
|
-
repeats=factor, axis=2
|
|
280
|
-
)
|
|
281
|
-
else:
|
|
282
|
-
factor = src.shape[1] // image.shape[1]
|
|
283
|
-
src = src[:, ::factor, ::factor]
|
|
284
|
-
|
|
285
|
-
else:
|
|
286
|
-
src = raster_format.decode_raster(
|
|
287
|
-
raster_dir, window.projection, bounds, resampling=Resampling.nearest
|
|
288
|
-
)
|
|
273
|
+
# TODO: previously we tried to read based on band_set.zoom_offset when possible,
|
|
274
|
+
# and handled zooming in with torch.repeat (if resampling method is nearest
|
|
275
|
+
# neighbor). However, we have not benchmarked whether this actually improves
|
|
276
|
+
# data loading speed, so for simplicity, for now we let rasterio handle the
|
|
277
|
+
# resampling. If it really is much faster to handle it via torch, then it may
|
|
278
|
+
# make sense to bring back that functionality.
|
|
289
279
|
|
|
280
|
+
src = raster_format.decode_raster(
|
|
281
|
+
raster_dir, final_projection, final_bounds, resampling=Resampling.nearest
|
|
282
|
+
)
|
|
290
283
|
image[dst_indexes, :, :] = torch.as_tensor(
|
|
291
284
|
src[src_indexes, :, :].astype(data_input.dtype.get_numpy_dtype())
|
|
292
285
|
)
|
rslearn/train/scheduler.py
CHANGED
|
@@ -8,6 +8,7 @@ from torch.optim.lr_scheduler import (
|
|
|
8
8
|
CosineAnnealingLR,
|
|
9
9
|
CosineAnnealingWarmRestarts,
|
|
10
10
|
LRScheduler,
|
|
11
|
+
MultiStepLR,
|
|
11
12
|
ReduceLROnPlateau,
|
|
12
13
|
)
|
|
13
14
|
|
|
@@ -50,6 +51,20 @@ class PlateauScheduler(SchedulerFactory):
|
|
|
50
51
|
return ReduceLROnPlateau(optimizer, **self.get_kwargs())
|
|
51
52
|
|
|
52
53
|
|
|
54
|
+
@dataclass
|
|
55
|
+
class MultiStepScheduler(SchedulerFactory):
|
|
56
|
+
"""Multi-step learning rate scheduler."""
|
|
57
|
+
|
|
58
|
+
milestones: list[int]
|
|
59
|
+
gamma: float | None = None
|
|
60
|
+
last_epoch: int | None = None
|
|
61
|
+
|
|
62
|
+
def build(self, optimizer: Optimizer) -> LRScheduler:
|
|
63
|
+
"""Build the MultiStepLR scheduler."""
|
|
64
|
+
super().build(optimizer)
|
|
65
|
+
return MultiStepLR(optimizer, **self.get_kwargs())
|
|
66
|
+
|
|
67
|
+
|
|
53
68
|
@dataclass
|
|
54
69
|
class CosineAnnealingScheduler(SchedulerFactory):
|
|
55
70
|
"""Cosine annealing learning rate scheduler."""
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""Resize transform."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torchvision
|
|
7
|
+
from torchvision.transforms import InterpolationMode
|
|
8
|
+
|
|
9
|
+
from .transform import Transform
|
|
10
|
+
|
|
11
|
+
INTERPOLATION_MODES = {
|
|
12
|
+
"nearest": InterpolationMode.NEAREST,
|
|
13
|
+
"nearest_exact": InterpolationMode.NEAREST_EXACT,
|
|
14
|
+
"bilinear": InterpolationMode.BILINEAR,
|
|
15
|
+
"bicubic": InterpolationMode.BICUBIC,
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Resize(Transform):
|
|
20
|
+
"""Resizes inputs to a target size."""
|
|
21
|
+
|
|
22
|
+
def __init__(
|
|
23
|
+
self,
|
|
24
|
+
target_size: tuple[int, int],
|
|
25
|
+
selectors: list[str] = [],
|
|
26
|
+
interpolation: str = "nearest",
|
|
27
|
+
):
|
|
28
|
+
"""Initialize a resize transform.
|
|
29
|
+
|
|
30
|
+
Args:
|
|
31
|
+
target_size: the (height, width) to resize to.
|
|
32
|
+
selectors: items to transform.
|
|
33
|
+
interpolation: the interpolation mode to use for resizing.
|
|
34
|
+
Must be one of "nearest", "nearest_exact", "bilinear", or "bicubic".
|
|
35
|
+
"""
|
|
36
|
+
super().__init__()
|
|
37
|
+
self.target_size = target_size
|
|
38
|
+
self.selectors = selectors
|
|
39
|
+
self.interpolation = INTERPOLATION_MODES[interpolation]
|
|
40
|
+
|
|
41
|
+
def apply_resize(self, image: torch.Tensor) -> torch.Tensor:
|
|
42
|
+
"""Apply resizing on the specified image.
|
|
43
|
+
|
|
44
|
+
If the image is 2D, it is unsqueezed to 3D and then squeezed
|
|
45
|
+
back after resizing.
|
|
46
|
+
|
|
47
|
+
Args:
|
|
48
|
+
image: the image to transform.
|
|
49
|
+
"""
|
|
50
|
+
if image.dim() == 2:
|
|
51
|
+
image = image.unsqueeze(0) # (H, W) -> (1, H, W)
|
|
52
|
+
result = torchvision.transforms.functional.resize(
|
|
53
|
+
image, self.target_size, self.interpolation
|
|
54
|
+
)
|
|
55
|
+
return result.squeeze(0) # (1, H, W) -> (H, W)
|
|
56
|
+
|
|
57
|
+
return torchvision.transforms.functional.resize(
|
|
58
|
+
image, self.target_size, self.interpolation
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
def forward(
|
|
62
|
+
self, input_dict: dict[str, Any], target_dict: dict[str, Any]
|
|
63
|
+
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
64
|
+
"""Apply transform over the inputs and targets.
|
|
65
|
+
|
|
66
|
+
Args:
|
|
67
|
+
input_dict: the input
|
|
68
|
+
target_dict: the target
|
|
69
|
+
|
|
70
|
+
Returns:
|
|
71
|
+
transformed (input_dicts, target_dicts) tuple
|
|
72
|
+
"""
|
|
73
|
+
self.apply_fn(self.apply_resize, input_dict, target_dict, self.selectors)
|
|
74
|
+
return input_dict, target_dict
|
rslearn/utils/geometry.py
CHANGED
|
@@ -116,6 +116,79 @@ class Projection:
|
|
|
116
116
|
WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)
|
|
117
117
|
|
|
118
118
|
|
|
119
|
+
class ResolutionFactor:
|
|
120
|
+
"""Multiplier for the resolution in a Projection.
|
|
121
|
+
|
|
122
|
+
The multiplier is either an integer x, or the inverse of an integer (1/x).
|
|
123
|
+
|
|
124
|
+
Factors greater than 1 decrease the projection_units/pixel value, increasing
|
|
125
|
+
the resolution (more pixels per projection unit). Factors less than 1 make it coarser
|
|
126
|
+
(fewer pixels).
|
|
127
|
+
"""
|
|
128
|
+
|
|
129
|
+
def __init__(self, numerator: int = 1, denominator: int = 1):
|
|
130
|
+
"""Create a new ResolutionFactor.
|
|
131
|
+
|
|
132
|
+
Args:
|
|
133
|
+
numerator: the numerator of the fraction.
|
|
134
|
+
denominator: the denominator of the fraction. If set, numerator must be 1.
|
|
135
|
+
"""
|
|
136
|
+
if numerator != 1 and denominator != 1:
|
|
137
|
+
raise ValueError("one of numerator or denominator must be 1")
|
|
138
|
+
if not isinstance(numerator, int) or not isinstance(denominator, int):
|
|
139
|
+
raise ValueError("numerator and denominator must be integers")
|
|
140
|
+
if numerator < 1 or denominator < 1:
|
|
141
|
+
raise ValueError("numerator and denominator must be >= 1")
|
|
142
|
+
self.numerator = numerator
|
|
143
|
+
self.denominator = denominator
|
|
144
|
+
|
|
145
|
+
def multiply_projection(self, projection: Projection) -> Projection:
|
|
146
|
+
"""Multiply the projection by this factor."""
|
|
147
|
+
if self.denominator > 1:
|
|
148
|
+
return Projection(
|
|
149
|
+
projection.crs,
|
|
150
|
+
projection.x_resolution * self.denominator,
|
|
151
|
+
projection.y_resolution * self.denominator,
|
|
152
|
+
)
|
|
153
|
+
else:
|
|
154
|
+
return Projection(
|
|
155
|
+
projection.crs,
|
|
156
|
+
projection.x_resolution // self.numerator,
|
|
157
|
+
projection.y_resolution // self.numerator,
|
|
158
|
+
)
|
|
159
|
+
|
|
160
|
+
def multiply_bounds(self, bounds: PixelBounds) -> PixelBounds:
|
|
161
|
+
"""Multiply the bounds by this factor.
|
|
162
|
+
|
|
163
|
+
When coarsening, the width and height of the given bounds must be a multiple of
|
|
164
|
+
the denominator.
|
|
165
|
+
"""
|
|
166
|
+
if self.denominator > 1:
|
|
167
|
+
# Verify the width and height are multiples of the denominator.
|
|
168
|
+
# Otherwise the new width and height is not an integer.
|
|
169
|
+
width = bounds[2] - bounds[0]
|
|
170
|
+
height = bounds[3] - bounds[1]
|
|
171
|
+
if width % self.denominator != 0 or height % self.denominator != 0:
|
|
172
|
+
raise ValueError(
|
|
173
|
+
f"width {width} or height {height} is not a multiple of the resolution factor {self.denominator}"
|
|
174
|
+
)
|
|
175
|
+
# TODO: an offset could be introduced by bounds not being a multiple
|
|
176
|
+
# of the denominator -> will need to decide how to handle that.
|
|
177
|
+
return (
|
|
178
|
+
bounds[0] // self.denominator,
|
|
179
|
+
bounds[1] // self.denominator,
|
|
180
|
+
bounds[2] // self.denominator,
|
|
181
|
+
bounds[3] // self.denominator,
|
|
182
|
+
)
|
|
183
|
+
else:
|
|
184
|
+
return (
|
|
185
|
+
bounds[0] * self.numerator,
|
|
186
|
+
bounds[1] * self.numerator,
|
|
187
|
+
bounds[2] * self.numerator,
|
|
188
|
+
bounds[3] * self.numerator,
|
|
189
|
+
)
|
|
190
|
+
|
|
191
|
+
|
|
119
192
|
class STGeometry:
|
|
120
193
|
"""A spatiotemporal geometry.
|
|
121
194
|
|
rslearn/utils/jsonargparse.py
CHANGED
|
@@ -8,6 +8,7 @@ from rasterio.crs import CRS
|
|
|
8
8
|
from upath import UPath
|
|
9
9
|
|
|
10
10
|
from rslearn.config.dataset import LayerConfig
|
|
11
|
+
from rslearn.utils.geometry import ResolutionFactor
|
|
11
12
|
|
|
12
13
|
if TYPE_CHECKING:
|
|
13
14
|
from rslearn.data_sources.data_source import DataSourceContext
|
|
@@ -91,6 +92,68 @@ def data_source_context_deserializer(v: dict[str, Any]) -> "DataSourceContext":
|
|
|
91
92
|
)
|
|
92
93
|
|
|
93
94
|
|
|
95
|
+
def resolution_factor_serializer(v: ResolutionFactor) -> str:
|
|
96
|
+
"""Serialize ResolutionFactor for jsonargparse.
|
|
97
|
+
|
|
98
|
+
Args:
|
|
99
|
+
v: the ResolutionFactor object.
|
|
100
|
+
|
|
101
|
+
Returns:
|
|
102
|
+
the ResolutionFactor encoded to string
|
|
103
|
+
"""
|
|
104
|
+
if hasattr(v, "init_args"):
|
|
105
|
+
init_args = v.init_args
|
|
106
|
+
return f"{init_args.numerator}/{init_args.denominator}"
|
|
107
|
+
|
|
108
|
+
return f"{v.numerator}/{v.denominator}"
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def resolution_factor_deserializer(v: int | str | dict) -> ResolutionFactor:
|
|
112
|
+
"""Deserialize ResolutionFactor for jsonargparse.
|
|
113
|
+
|
|
114
|
+
Args:
|
|
115
|
+
v: the encoded ResolutionFactor.
|
|
116
|
+
|
|
117
|
+
Returns:
|
|
118
|
+
the decoded ResolutionFactor object
|
|
119
|
+
"""
|
|
120
|
+
# Handle already-instantiated ResolutionFactor
|
|
121
|
+
if isinstance(v, ResolutionFactor):
|
|
122
|
+
return v
|
|
123
|
+
|
|
124
|
+
# Handle Namespace from class_path syntax (used during config save/validation)
|
|
125
|
+
if hasattr(v, "init_args"):
|
|
126
|
+
init_args = v.init_args
|
|
127
|
+
return ResolutionFactor(
|
|
128
|
+
numerator=init_args.numerator,
|
|
129
|
+
denominator=init_args.denominator,
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
# Handle dict from class_path syntax in YAML config
|
|
133
|
+
if isinstance(v, dict) and "init_args" in v:
|
|
134
|
+
init_args = v["init_args"]
|
|
135
|
+
return ResolutionFactor(
|
|
136
|
+
numerator=init_args.get("numerator", 1),
|
|
137
|
+
denominator=init_args.get("denominator", 1),
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
if isinstance(v, int):
|
|
141
|
+
return ResolutionFactor(numerator=v)
|
|
142
|
+
elif isinstance(v, str):
|
|
143
|
+
parts = v.split("/")
|
|
144
|
+
if len(parts) == 1:
|
|
145
|
+
return ResolutionFactor(numerator=int(parts[0]))
|
|
146
|
+
elif len(parts) == 2:
|
|
147
|
+
return ResolutionFactor(
|
|
148
|
+
numerator=int(parts[0]),
|
|
149
|
+
denominator=int(parts[1]),
|
|
150
|
+
)
|
|
151
|
+
else:
|
|
152
|
+
raise ValueError("expected resolution factor to be of the form x or 1/x")
|
|
153
|
+
else:
|
|
154
|
+
raise ValueError("expected resolution factor to be str or int")
|
|
155
|
+
|
|
156
|
+
|
|
94
157
|
def init_jsonargparse() -> None:
|
|
95
158
|
"""Initialize custom jsonargparse serializers."""
|
|
96
159
|
global INITIALIZED
|
|
@@ -100,6 +163,9 @@ def init_jsonargparse() -> None:
|
|
|
100
163
|
jsonargparse.typing.register_type(
|
|
101
164
|
datetime, datetime_serializer, datetime_deserializer
|
|
102
165
|
)
|
|
166
|
+
jsonargparse.typing.register_type(
|
|
167
|
+
ResolutionFactor, resolution_factor_serializer, resolution_factor_deserializer
|
|
168
|
+
)
|
|
103
169
|
|
|
104
170
|
from rslearn.data_sources.data_source import DataSourceContext
|
|
105
171
|
|