rslearn 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/data_sources/local_files.py +20 -3
- rslearn/data_sources/planetary_computer.py +79 -14
- rslearn/dataset/handler_summaries.py +130 -0
- rslearn/dataset/manage.py +159 -24
- rslearn/dataset/materialize.py +21 -2
- rslearn/dataset/remap.py +29 -4
- rslearn/main.py +60 -8
- rslearn/models/clay/clay.py +29 -14
- rslearn/models/copernicusfm.py +37 -25
- rslearn/models/dinov3.py +166 -0
- rslearn/models/galileo/galileo.py +58 -12
- rslearn/models/galileo/single_file_galileo.py +7 -1
- rslearn/models/presto/presto.py +11 -0
- rslearn/models/prithvi.py +139 -52
- rslearn/models/registry.py +19 -2
- rslearn/models/resize_features.py +45 -0
- rslearn/models/simple_time_series.py +65 -10
- rslearn/models/upsample.py +2 -2
- rslearn/tile_stores/default.py +34 -7
- rslearn/train/transforms/normalize.py +34 -5
- rslearn/train/transforms/select_bands.py +67 -0
- rslearn/train/transforms/sentinel1.py +60 -0
- rslearn/train/transforms/transform.py +23 -6
- rslearn/utils/raster_format.py +44 -5
- rslearn/utils/vector_format.py +35 -4
- {rslearn-0.0.7.dist-info → rslearn-0.0.9.dist-info}/METADATA +3 -4
- {rslearn-0.0.7.dist-info → rslearn-0.0.9.dist-info}/RECORD +31 -26
- {rslearn-0.0.7.dist-info → rslearn-0.0.9.dist-info}/WHEEL +0 -0
- {rslearn-0.0.7.dist-info → rslearn-0.0.9.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.7.dist-info → rslearn-0.0.9.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.7.dist-info → rslearn-0.0.9.dist-info}/top_level.txt +0 -0
rslearn/models/galileo/galileo.py
CHANGED

@@ -2,6 +2,7 @@
 
 import math
 import tempfile
+from contextlib import nullcontext
 from enum import StrEnum
 from typing import Any, cast
 

@@ -63,6 +64,11 @@ pretrained_weights: dict[GalileoSize, str] = {
 
 DEFAULT_NORMALIZER = Normalizer()
 
+AUTOCAST_DTYPE_MAP = {
+    "bfloat16": torch.bfloat16,
+    "float32": torch.float32,
+}
+
 
 class GalileoModel(nn.Module):
     """Galileo backbones."""

@@ -85,6 +91,7 @@ class GalileoModel(nn.Module):
         size: GalileoSize,
         patch_size: int = 4,
         pretrained_path: str | UPath | None = None,
+        autocast_dtype: str | None = "bfloat16",
     ) -> None:
         """Initialize the Galileo model.
 

@@ -93,6 +100,7 @@ class GalileoModel(nn.Module):
             patch_size: The patch size to use.
             pretrained_path: the local path to the pretrained weights. Otherwise it is
                 downloaded and cached in temp directory.
+            autocast_dtype: which dtype to use for autocasting, or set None to disable.
         """
         super().__init__()
         if pretrained_path is None:

@@ -128,8 +136,14 @@ class GalileoModel(nn.Module):
             idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S1" in key
         ]
 
+        self.size = size
         self.patch_size = patch_size
 
+        if autocast_dtype is not None:
+            self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
+        else:
+            self.autocast_dtype = None
+
     @staticmethod
     def to_cartesian(
         lat: float | np.ndarray | torch.Tensor, lon: float | np.ndarray | torch.Tensor

@@ -484,18 +498,31 @@ class GalileoModel(nn.Module):
             patch_size = h
         else:
             patch_size = self.patch_size
-
-
-
-
-
-
-
-
-
-
-
-
+
+        # Decide context based on self.autocast_dtype.
+        device = galileo_input.s_t_x.device
+        if self.autocast_dtype is None:
+            context = nullcontext()
+        else:
+            assert device is not None
+            context = torch.amp.autocast(
+                device_type=device.type, dtype=self.autocast_dtype
+            )
+
+        with context:
+            outputs = self.model(
+                s_t_x=galileo_input.s_t_x,
+                s_t_m=galileo_input.s_t_m,
+                sp_x=galileo_input.sp_x,
+                sp_m=galileo_input.sp_m,
+                t_x=galileo_input.t_x,
+                t_m=galileo_input.t_m,
+                st_x=galileo_input.st_x,
+                st_m=galileo_input.st_m,
+                months=galileo_input.months,
+                patch_size=patch_size,
+            )
+
         if h == patch_size:
             # only one spatial patch, so we can just take an average
            # of all the tokens to output b c_g 1 1
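The hunk above maps a string argument to a torch dtype and falls back to contextlib.nullcontext() when autocasting is disabled. A minimal standalone sketch of the same pattern, assuming nothing beyond PyTorch itself (the function and variable names here are illustrative, not part of rslearn):

from contextlib import nullcontext

import torch

AUTOCAST_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float32": torch.float32}


def run_with_optional_autocast(model, x, autocast_dtype: str | None = "bfloat16"):
    # Pick either a real autocast context or a no-op context.
    if autocast_dtype is None:
        context = nullcontext()
    else:
        context = torch.amp.autocast(
            device_type=x.device.type, dtype=AUTOCAST_DTYPE_MAP[autocast_dtype]
        )
    with context:
        return model(x)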
@@ -515,3 +542,22 @@ class GalileoModel(nn.Module):
                 "b h w c_g d -> b c_g d h w",
             ).mean(dim=1)
         ]
+
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels is a list of (patch_size, depth) that corresponds
+        to the feature maps that the backbone returns.
+
+        Returns:
+            the output channels of the backbone as a list of (patch_size, depth) tuples.
+        """
+        if self.size == GalileoSize.BASE:
+            depth = 768
+        elif self.model_size == GalileoSize.TINY:
+            depth = 192
+        elif self.model_size == GalileoSize.NANO:
+            depth = 128
+        else:
+            raise ValueError(f"Invalid model size: {self.size}")
+        return [(self.patch_size, depth)]
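Several backbones in this release gain a get_backbone_channels() method that declares (downsample_factor, depth) tuples for the feature maps they return. A hedged sketch of how decoder-side code might consume that contract to size its layers; the helper name and the 1x1-conv design are illustrative assumptions, not rslearn API:

import torch


def build_projection_heads(backbone, out_channels: int = 64) -> torch.nn.ModuleList:
    # One 1x1 conv per declared feature map, sized from the backbone's depths.
    heads = []
    for _downsample_factor, depth in backbone.get_backbone_channels():
        heads.append(torch.nn.Conv2d(depth, out_channels, kernel_size=1))
    return torch.nn.ModuleList(heads)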
rslearn/models/galileo/single_file_galileo.py
CHANGED

@@ -1469,7 +1469,13 @@ class Encoder(GalileoBase):
             # we take the inverse of the mask because a value
             # of True indicates the value *should* take part in
             # attention
-
+            temp_mask = ~new_m.bool()
+            if temp_mask.all():
+                # if all the tokens are used in attention we can pass a None mask
+                # to the attention block
+                temp_mask = None
+
+            x = blk(x=x, y=None, attn_mask=temp_mask)
 
             if exit_ids_seq is not None:
                 assert exited_tokens is not None
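The change above replaces an all-True attention mask with None before calling the block, which typically lets PyTorch's scaled-dot-product-attention dispatch to its fused fast paths instead of applying a dense boolean mask. A small illustrative sketch of the same check (tensor and function names are made up):

import torch


def mask_or_none(pad_mask: torch.Tensor) -> torch.Tensor | None:
    # pad_mask: True where a token should be masked OUT of attention.
    keep = ~pad_mask.bool()  # True where the token should take part in attention
    if keep.all():
        # No token is masked, so skip the mask entirely.
        return None
    return keep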
rslearn/models/presto/presto.py
CHANGED
@@ -248,3 +248,14 @@ class Presto(nn.Module):
             output_features[batch_idx : batch_idx + self.pixel_batch_size] = output_b
 
         return [rearrange(output_features, "(b h w) d -> b d h w", h=h, w=w, b=b)]
+
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels is a list of (patch_size, depth) that corresponds
+        to the feature maps that the backbone returns.
+
+        Returns:
+            the output channels of the backbone as a list of (patch_size, depth) tuples.
+        """
+        return [(1, 128)]
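Presto declares a single (1, 128) entry, i.e. full-resolution feature maps with 128 channels, which matches the rearrange in the forward pass shown above. A small illustrative check with made-up dimensions, using the same einops call:

import torch
from einops import rearrange

b, h, w, d = 2, 8, 8, 128
per_pixel = torch.randn(b * h * w, d)  # one 128-d feature per pixel
feat_map = rearrange(per_pixel, "(b h w) d -> b d h w", b=b, h=h, w=w)
assert feat_map.shape == (2, 128, 8, 8)  # consistent with the declared (1, 128)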
rslearn/models/prithvi.py
CHANGED
@@ -1,57 +1,91 @@
 """Prithvi V2."""
 
+import json
 import logging
 import tempfile
 import warnings
+from enum import StrEnum
+from pathlib import Path
 from typing import Any
 
 import numpy as np
 import torch
 import torch.nn as nn
-import yaml
 from einops import rearrange
 from huggingface_hub import hf_hub_download
 from timm.layers import to_2tuple
 from timm.models.vision_transformer import Block
 from torch.nn import functional as F
-
+
+from rslearn.train.transforms.normalize import Normalize
+from rslearn.train.transforms.transform import Transform
 
 logger = logging.getLogger(__name__)
 
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+class PrithviV2Models(StrEnum):
+    """Names for different Prithvi models on torch hub."""
+
+    VIT_300 = "VIT_300"
+    VIT_600 = "VIT_600"
+
+
+MODEL_TO_HF_INFO = {
+    PrithviV2Models.VIT_300: {
+        "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-300M",
+        "weights": "Prithvi_EO_V2_300M.pt",
+        "revision": "b2f2520ab889f42a25c5361ba18761fcb4ea44ad",
+    },
+    PrithviV2Models.VIT_600: {
+        "hf_hub_id": "ibm-nasa-geospatial/Prithvi-EO-2.0-600M",
+        "weights": "Prithvi_EO_V2_600M.pt",
+        "revision": "87f15784813828dc37aa3197a143cd4689e4d080",
+    },
+}
+
+
+HF_HUB_CONFIG_FNAME = "config.json"
+DEFAULT_CACHE_DIR = Path(tempfile.gettempdir(), "rslearn_cache", "prithvi_v2")
 
 
-
+def get_config(cache_dir: Path, hf_hub_id: str, hf_hub_revision: str) -> dict[str, Any]:
+    """Get the JSON config dict.
+
+    Args:
+        cache_dir: the directory to cache the config.json file, which will be
+            downloaded from HF Hub.
+        hf_hub_id: the HF Hub ID from which to download the config.
+        hf_hub_revision: The revision (commit) to download the config from.
+    """
+    cache_fname = cache_dir / HF_HUB_CONFIG_FNAME
+    if not cache_fname.exists():
+        _ = hf_hub_download(
+            local_dir=cache_dir,
+            repo_id=hf_hub_id,
+            filename=HF_HUB_CONFIG_FNAME,
+            revision=hf_hub_revision,
+        )  # nosec
+    with cache_fname.open() as f:
+        return json.load(f)["pretrained_cfg"]
 
 
 class PrithviV2(nn.Module):
     """An Rslearn wrapper for Prithvi 2.0."""
 
-
+    INPUT_KEY = "image"
 
-    def __init__(
-
+    def __init__(
+        self,
+        cache_dir: str | Path | None = None,
+        size: PrithviV2Models = PrithviV2Models.VIT_300,
+        num_frames: int = 1,
+    ):
+        """Create a new PrithviV2.
 
-
-
-
+        Args:
+            cache_dir: The local folder in which to download the prithvi config and
+                weights. If None, it downloads to a temporary folder.
+            size: the model size, see class for various models.
             num_frames: The number of input frames (timesteps). The model was trained on 3,
                 but if there is just one timestamp examples use 1 (e.g.
                 https://github.com/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/
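The new get_config helper caches config.json from the Hugging Face Hub and returns its pretrained_cfg section. A hedged usage sketch, assuming get_config is importable from rslearn.models.prithvi as laid out above; the cache path is arbitrary and the first call needs network access plus the huggingface_hub package:

from pathlib import Path

from rslearn.models.prithvi import get_config

# Hub ID and revision taken from MODEL_TO_HF_INFO above.
cfg = get_config(
    cache_dir=Path("/tmp/rslearn_cache/prithvi_v2"),
    hf_hub_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M",
    hf_hub_revision="b2f2520ab889f42a25c5361ba18761fcb4ea44ad",
)
print(cfg["mean"], cfg["std"])  # normalization stats later used by PrithviNormalize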
@@ -59,35 +93,28 @@ class PrithviV2(nn.Module):
 
         """
         super().__init__()
-        if
-
-
-        )
+        if cache_dir is None:
+            cache_dir = DEFAULT_CACHE_DIR
+        cache_dir = Path(cache_dir)
 
-
-
-
-            repo_id=HF_HUB_ID,
-            filename="config.json",
-            revision="b2f2520ab889f42a25c5361ba18761fcb4ea44ad",
-        )
-        with (UPath(pretrained_path) / "config.json").open("r") as f:
-            config = yaml.safe_load(f)["pretrained_cfg"]
+        hub_id = MODEL_TO_HF_INFO[size]["hf_hub_id"]
+        revision = MODEL_TO_HF_INFO[size]["revision"]
+        checkpoint_fname = MODEL_TO_HF_INFO[size]["weights"]
 
+        config = get_config(cache_dir, hub_id, revision)
         config["num_frames"] = num_frames
-
         self.model = PrithviMAE(**config)
 
-        if not (
+        if not (cache_dir / checkpoint_fname).exists():
             _ = hf_hub_download(
-                local_dir=
-                repo_id=
-                filename=
-                revision=
-            )
+                local_dir=cache_dir,
+                repo_id=hub_id,
+                filename=checkpoint_fname,
+                revision=revision,
+            )  # nosec
 
         state_dict = torch.load(
-
+            cache_dir / checkpoint_fname,
             map_location="cpu",
             weights_only=True,
         )

@@ -125,16 +152,15 @@ class PrithviV2(nn.Module):
     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
         """Compute feature maps from the Prithvi V2 backbone.
 
-
-        inputs: input dicts that must include "
-
-            this naming keeps the model consistent with other rslearn models.
+        Args:
+            inputs: input dicts that must include "image" key containing HLS
+                (Harmonized Landsat-Sentinel) data.
 
         Returns:
             11 feature maps (one per transformer block in the Prithvi model),
             of shape [B, H/p_s, W/p_s, D=1024] where p_s=16 is the patch size.
         """
-        x = torch.stack([inp[
+        x = torch.stack([inp[self.INPUT_KEY] for inp in inputs], dim=0)
         x = self._resize_data(x)
         num_timesteps = x.shape[1] // len(self.bands)
         x = rearrange(x, "b (t c) h w -> b c t h w", t=num_timesteps)
@@ -147,6 +173,67 @@ class PrithviV2(nn.Module):
             features, num_timesteps
         )
 
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels is a list of (patch_size, depth) that corresponds
+        to the feature maps that the backbone returns.
+
+        Returns:
+            the output channels of the backbone as a list of (patch_size, depth) tuples.
+        """
+        return [(1, 1024)]
+
+
+class PrithviNormalize(Transform):
+    """Normalize inputs using Prithvi normalization.
+
+    Similar to the model, the input should be an image time series under the key
+    "image".
+    """
+
+    def __init__(
+        self,
+        cache_dir: str | Path | None = None,
+        size: PrithviV2Models = PrithviV2Models.VIT_300,
+    ) -> None:
+        """Initialize a new PrithviNormalize.
+
+        Args:
+            cache_dir: the local directory to cache the config.json which contains the
+                means and standard deviations used in the normalization.
+            size: the model size, see class for various models. In this case (and
+                for the current hf revision), the config values (mean and std) are the
+                same for both the 300M and 600M model, so its safe to not set this.
+        """
+        super().__init__()
+        hub_id = MODEL_TO_HF_INFO[size]["hf_hub_id"]
+        revision = MODEL_TO_HF_INFO[size]["revision"]
+        if cache_dir is None:
+            cache_dir = DEFAULT_CACHE_DIR
+        cache_dir = Path(cache_dir)
+        config = get_config(cache_dir, hub_id, revision)
+        self.normalizer = Normalize(
+            mean=config["mean"],
+            std=config["std"],
+            num_bands=len(config["mean"]),
+            selectors=[PrithviV2.INPUT_KEY],
+        )
+
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Apply Prithvi normalization on the image.
+
+        Args:
+            input_dict: the input, which must contain the "image" key.
+            target_dict: the target
+
+        Returns:
+            normalized (input_dicts, target_dicts) tuple
+        """
+        return self.normalizer(input_dict, target_dict)
+
 
 # Copyright (c) IBM Corp. 2024. All rights reserved.
 #
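Taken together, the new constructor arguments and the PrithviNormalize transform can be wired up roughly as below. This is a hedged sketch: the import path follows the file layout above, the cache directory is arbitrary, and the first call downloads the config and weights from the Hugging Face Hub.

from rslearn.models.prithvi import PrithviNormalize, PrithviV2, PrithviV2Models

# Share the cache directory so config.json is only fetched once.
cache_dir = "/tmp/rslearn_cache/prithvi_v2"
normalize = PrithviNormalize(cache_dir=cache_dir, size=PrithviV2Models.VIT_300)
model = PrithviV2(cache_dir=cache_dir, size=PrithviV2Models.VIT_300, num_frames=1)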
rslearn/models/registry.py
CHANGED
@@ -1,5 +1,22 @@
 """Model registry."""
 
-from
+from collections.abc import Callable
+from typing import Any, TypeVar
 
-
+_ModelT = TypeVar("_ModelT")
+
+
+class _ModelRegistry(dict[str, type[Any]]):
+    """Registry for Model classes."""
+
+    def register(self, name: str) -> Callable[[type[_ModelT]], type[_ModelT]]:
+        """Decorator to register a model class."""
+
+        def decorator(cls: type[_ModelT]) -> type[_ModelT]:
+            self[name] = cls
+            return cls
+
+        return decorator
+
+
+Models = _ModelRegistry()
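The registry is a plain dict subclass, so registration and lookup stay simple. A short hedged example of the decorator shown above; the model name and class are invented for illustration:

import torch

from rslearn.models.registry import Models


@Models.register("toy_backbone")
class ToyBackbone(torch.nn.Module):
    """A throwaway module, only here to show the decorator."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


# Lookup works like any dict access and returns the registered class.
model_cls = Models["toy_backbone"]
assert model_cls is ToyBackbone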
rslearn/models/resize_features.py
ADDED

@@ -0,0 +1,45 @@
+"""The ResizeFeatures module."""
+
+import torch
+
+
+class ResizeFeatures(torch.nn.Module):
+    """Resize input features to new sizes."""
+
+    def __init__(
+        self,
+        out_sizes: list[tuple[int, int]],
+        mode: str = "bilinear",
+    ):
+        """Initialize a ResizeFeatures.
+
+        Args:
+            out_sizes: the output sizes of the feature maps. There must be one entry
+                for each input feature map.
+            mode: mode to pass to torch.nn.Upsample, e.g. "bilinear" (default) or
+                "nearest".
+        """
+        super().__init__()
+        layers = []
+        for size in out_sizes:
+            layers.append(
+                torch.nn.Upsample(
+                    size=size,
+                    mode=mode,
+                )
+            )
+        self.layers = torch.nn.ModuleList(layers)
+
+    def forward(
+        self, features: list[torch.Tensor], inputs: list[torch.Tensor]
+    ) -> list[torch.Tensor]:
+        """Resize the input feature maps to new sizes.
+
+        Args:
+            features: list of feature maps at different resolutions.
+            inputs: original inputs (ignored).
+
+        Returns:
+            resized feature maps
+        """
+        return [self.layers[idx](feat_map) for idx, feat_map in enumerate(features)]
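A quick hedged sketch of the new module in isolation; the sizes and tensor shapes below are arbitrary:

import torch

from rslearn.models.resize_features import ResizeFeatures

# Two feature maps in, two target sizes out (one Upsample layer per map).
resize = ResizeFeatures(out_sizes=[(64, 64), (32, 32)], mode="bilinear")
features = [torch.randn(2, 96, 16, 16), torch.randn(2, 192, 8, 8)]
resized = resize(features, inputs=[])
print([f.shape for f in resized])  # [2, 96, 64, 64] and [2, 192, 32, 32]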
rslearn/models/simple_time_series.py
CHANGED

@@ -20,12 +20,13 @@ class SimpleTimeSeries(torch.nn.Module):
     def __init__(
         self,
         encoder: torch.nn.Module,
-        image_channels: int,
+        image_channels: int | None = None,
         op: str = "max",
         groups: list[list[int]] | None = None,
         num_layers: int | None = None,
         image_key: str = "image",
         backbone_channels: list[tuple[int, int]] | None = None,
+        image_keys: dict[str, int] | None = None,
     ) -> None:
         """Create a new SimpleTimeSeries.
 

@@ -48,13 +49,25 @@ class SimpleTimeSeries(torch.nn.Module):
             image_key: the key to access the images.
             backbone_channels: manually specify the backbone channels. Can be set if
                 the encoder does not provide get_backbone_channels function.
+            image_keys: as an alternative to setting image_channels, map from the key
+                in input dict to the number of channels per timestep for that modality.
+                This way SimpleTimeSeries can be used with multimodal inputs. One of
+                image_channels or image_keys must be specified.
         """
+        if (image_channels is None and image_keys is None) or (
+            image_channels is not None and image_keys is not None
+        ):
+            raise ValueError(
+                "exactly one of image_channels and image_keys must be specified"
+            )
+
         super().__init__()
         self.encoder = encoder
         self.image_channels = image_channels
         self.op = op
         self.groups = groups
         self.image_key = image_key
+        self.image_keys = image_keys
 
         if backbone_channels is not None:
             out_channels = backbone_channels
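image_keys describes how many channels each modality contributes per timestep, so the wrapper can recover the timestep axis from channel-stacked inputs. A hedged sketch of what a per-sample input dict looks like in that mode; the modality keys and band counts are invented, not rslearn defaults:

import torch

# Four timesteps of a 9-band and a 2-band modality, channels stacked per sample.
input_dict = {
    "sentinel2": torch.randn(4 * 9, 64, 64),
    "sentinel1": torch.randn(4 * 2, 64, 64),
}
image_keys = {"sentinel2": 9, "sentinel1": 2}
# Every modality must resolve to the same number of timesteps, otherwise the
# forward pass below raises ValueError.
for key, channels_per_timestep in image_keys.items():
    assert input_dict[key].shape[0] // channels_per_timestep == 4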
@@ -144,6 +157,26 @@ class SimpleTimeSeries(torch.nn.Module):
             out_channels.append((downsample_factor, depth * self.num_groups))
         return out_channels
 
+    def _get_batched_images(
+        self, input_dicts: list[dict[str, Any]], image_key: str, image_channels: int
+    ) -> torch.Tensor:
+        """Collect and reshape images across input dicts.
+
+        The BTCHW image time series are reshaped to (B*T)CHW so they can be passed to
+        the forward pass of a per-image (unitemporal) model.
+        """
+        images = torch.stack(
+            [input_dict[image_key] for input_dict in input_dicts], dim=0
+        )
+        n_batch = images.shape[0]
+        n_images = images.shape[1] // image_channels
+        n_height = images.shape[2]
+        n_width = images.shape[3]
+        batched_images = images.reshape(
+            n_batch * n_images, image_channels, n_height, n_width
+        )
+        return batched_images
+
     def forward(
         self,
         inputs: list[dict[str, Any]],
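_get_batched_images folds the timestep axis into the batch axis so a per-image encoder can process every timestep independently. A minimal reproduction of that reshape with made-up dimensions:

import torch

batch, timesteps, channels, height, width = 2, 4, 9, 64, 64
# Each sample stores its time series with channels stacked: (T*C, H, W).
samples = [torch.randn(timesteps * channels, height, width) for _ in range(batch)]
images = torch.stack(samples, dim=0)  # (B, T*C, H, W)
batched = images.reshape(batch * timesteps, channels, height, width)  # (B*T, C, H, W)
assert batched.shape == (8, 9, 64, 64)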
@@ -156,15 +189,37 @@ class SimpleTimeSeries(torch.nn.Module):
         """
         # First get features of each image.
         # To do so, we need to split up each grouped image into its component images (which have had their channels stacked).
-
-        n_batch =
-        n_images
-
-
-
-
-
-
+        batched_inputs: list[dict[str, Any]] | None = None
+        n_batch = len(inputs)
+        n_images: int | None = None
+
+        if self.image_keys is not None:
+            for image_key, image_channels in self.image_keys.items():
+                batched_images = self._get_batched_images(
+                    inputs, image_key, image_channels
+                )
+
+                if batched_inputs is None:
+                    batched_inputs = [{} for _ in batched_images]
+                    n_images = batched_images.shape[0] // n_batch
+                elif n_images != batched_images.shape[0] // n_batch:
+                    raise ValueError(
+                        "expected all modalities to have the same number of timesteps"
+                    )
+
+                for i, image in enumerate(batched_images):
+                    batched_inputs[i][image_key] = image
+
+        else:
+            assert self.image_channels is not None
+            batched_images = self._get_batched_images(
+                inputs, self.image_key, self.image_channels
+            )
+            batched_inputs = [{self.image_key: image} for image in batched_images]
+            n_images = batched_images.shape[0] // n_batch
+
+        assert n_images is not None
+
         all_features = [
             feat_map.reshape(
                 n_batch,
rslearn/models/upsample.py
CHANGED
@@ -23,13 +23,13 @@ class Upsample(torch.nn.Module):
     def forward(
         self, features: list[torch.Tensor], inputs: list[torch.Tensor]
     ) -> list[torch.Tensor]:
-        """
+        """Upsample each feature map.
 
         Args:
             features: list of feature maps at different resolutions.
             inputs: original inputs (ignored).
 
         Returns:
-
+            upsampled feature maps
         """
         return [self.layer(feat_map) for feat_map in features]