rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/ssl4eo_s12.py CHANGED

@@ -1,20 +1,22 @@
 """SSL4EO-S12 models."""
 
-from typing import Any
-
 import torch
 import torchvision
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
+
 
-class Ssl4eoS12(torch.nn.Module):
+class Ssl4eoS12(FeatureExtractor):
     """The SSL4EO-S12 family of pretrained models."""
 
     def __init__(
         self,
-        backbone_ckpt_path: str,
+        backbone_ckpt_path: str | None,
         arch: str = "resnet50",
         output_layers: list[int] = [0, 1, 2, 3],
-    ):
+    ) -> None:
         """Instantiate a new Swin instance.
 
         Args:
@@ -37,21 +39,24 @@ class Ssl4eoS12(torch.nn.Module):
         else:
             raise ValueError(f"unknown SSL4EO-S12 architecture {arch}")
 
-        [old lines 40-50 removed; content not captured in this diff view]
-                f"warning: got missing_keys={missing_keys}, unexpected_keys={unexpected_keys} when loading SSL4EO-S12 state dict"
+        if backbone_ckpt_path is not None:
+            state_dict = torch.load(backbone_ckpt_path, weights_only=True)
+            state_dict = state_dict["teacher"]
+            prefix = "module.backbone."
+            state_dict = {
+                k[len(prefix) :]: v
+                for k, v in state_dict.items()
+                if k.startswith(prefix)
+            }
+            missing_keys, unexpected_keys = self.model.load_state_dict(
+                state_dict, strict=False
             )
+            if missing_keys or unexpected_keys:
+                print(
+                    f"warning: got missing_keys={missing_keys}, unexpected_keys={unexpected_keys} when loading SSL4EO-S12 state dict"
+                )
 
-    def get_backbone_channels(self):
+    def get_backbone_channels(self) -> list[tuple[int, int]]:
         """Returns the output channels of this model when used as a backbone.
 
         The output channels is a list of (downsample_factor, depth) that corresponds
@@ -65,28 +70,33 @@ class Ssl4eoS12(torch.nn.Module):
         """
         if self.arch == "resnet50":
             all_out_channels = [
-                [old lines 68-71 removed; content not captured in this diff view]
+                (4, 256),
+                (8, 512),
+                (16, 1024),
+                (32, 2048),
             ]
         return [all_out_channels[idx] for idx in self.output_layers]
 
     def forward(
-        self,
-        [old line 77 removed; content not captured in this diff view]
+        self,
+        context: ModelContext,
+    ) -> FeatureMaps:
         """Compute outputs from the backbone.
 
         If output_layers is set, then the outputs are multi-scale feature maps;
         otherwise, the model is being used for classification so the outputs are class
         probabilities and the loss.
 
-        [old lines 84-85 removed; content not captured in this diff view]
-            process.
-        [old line 87 removed; content not captured in this diff view]
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the images to process.
+
+        Returns:
+            feature maps computed by the pre-trained model.
         """
-        x = torch.stack(
+        x = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
         x = self.model.conv1(x)
         x = self.model.bn1(x)
         x = self.model.relu(x)
@@ -97,4 +107,4 @@ class Ssl4eoS12(torch.nn.Module):
         layer3 = self.model.layer3(layer2)
         layer4 = self.model.layer4(layer3)
         all_features = [layer1, layer2, layer3, layer4]
-        return [all_features[idx] for idx in self.output_layers]
+        return FeatureMaps([all_features[idx] for idx in self.output_layers])
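Note: a minimal usage sketch of the updated constructor, assuming the import path matches the file location above and that the constructor otherwise builds the torchvision ResNet-50 as before. With backbone_ckpt_path now typed str | None, the backbone can be built without a checkpoint, and get_backbone_channels() reports the (downsample_factor, depth) pair for each selected output layer. This is an illustration, not part of the diff.

    from rslearn.models.ssl4eo_s12 import Ssl4eoS12

    # No checkpoint: the ResNet-50 backbone is left randomly initialized.
    model = Ssl4eoS12(backbone_ckpt_path=None, arch="resnet50", output_layers=[1, 3])
    print(model.get_backbone_channels())  # expected: [(8, 512), (16, 1024)]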
rslearn/models/swin.py CHANGED

@@ -1,7 +1,5 @@
 """Swin Transformer."""
 
-from typing import Any
-
 import torch
 import torchvision
 from torchvision.models.swin_transformer import (
@@ -13,8 +11,12 @@ from torchvision.models.swin_transformer import (
     Swin_V2_T_Weights,
 )
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps, FeatureVector
 
-class Swin(torch.nn.Module):
+
+class Swin(FeatureExtractor):
     """A Swin Transformer model.
 
     It can either be used stand-alone for classification, or as a feature extractor in
@@ -28,15 +30,18 @@ class Swin(torch.nn.Module):
         input_channels: int = 3,
         output_layers: list[int] | None = None,
         num_outputs: int = 1000,
-    ):
+    ) -> None:
         """Instantiate a new Swin instance.
 
         Args:
             arch: the architecture, e.g. "swin_v2_b" (default) or "swin_t"
             pretrained: set True to use ImageNet pre-trained weights
-            input_channels: number of input channels (default 3)
+            input_channels: number of input channels (default 3). If not 3, the first
+                layer is updated and will be randomly initialized even if pretrained is
+                set.
             output_layers: list of layers to output, default use as classification
-                model. For feature extraction, [1, 3, 5, 7] is
+                model (output FeatureVector). For feature extraction, [1, 3, 5, 7] is
+                recommended.
             num_outputs: number of output logits, defaults to 1000 which matches the
                 pretrained models.
         """
@@ -89,7 +94,7 @@ class Swin(torch.nn.Module):
         if num_outputs != self.model.head.out_features:
             self.model.head = torch.nn.Linear(self.model.head.in_features, num_outputs)
 
-    def get_backbone_channels(self):
+    def get_backbone_channels(self) -> list[tuple[int, int]]:
         """Returns the output channels of this model when used as a backbone.
 
         The output channels is a list of (downsample_factor, depth) that corresponds
@@ -105,43 +110,50 @@ class Swin(torch.nn.Module):
 
         if self.arch in ["swin_b", "swin_v2_b"]:
             all_out_channels = [
-                [old lines 108-115 removed; content not captured in this diff view]
+                (4, 128),
+                (4, 128),
+                (4, 128),
+                (8, 256),
+                (8, 256),
+                (16, 512),
+                (16, 512),
+                (32, 1024),
+                (32, 1024),
             ]
         elif self.arch in ["swin_s", "swin_v2_s", "swin_t", "swin_v2_t"]:
             all_out_channels = [
-                [old lines 119-126 removed; content not captured in this diff view]
+                (4, 96),
+                (4, 96),
+                (8, 192),
+                (8, 192),
+                (16, 384),
+                (16, 384),
+                (32, 768),
+                (32, 768),
             ]
         return [all_out_channels[idx] for idx in self.output_layers]
 
     def forward(
-        self,
-        [old line 132 removed; content not captured in this diff view]
+        self,
+        context: ModelContext,
+    ) -> FeatureVector | FeatureMaps:
         """Compute outputs from the backbone.
 
         If output_layers is set, then the outputs are multi-scale feature maps;
         otherwise, the model is being used for classification so the outputs are class
         probabilities and the loss.
 
-        [old lines 139-140 removed; content not captured in this diff view]
-            process.
-        [old line 142 removed; content not captured in this diff view]
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the image to process.
+
+        Returns:
+            a FeatureVector if the configured output_layers is None, or a FeatureMaps
+            otherwise containing one feature map per configured output layer.
         """
-        images = torch.stack(
+        images = torch.stack(
+            [inp["image"].single_ts_to_chw_tensor() for inp in context.inputs], dim=0
+        )
 
         if self.output_layers:
             layer_features = []
@@ -149,7 +161,7 @@ class Swin(torch.nn.Module):
             for layer in self.model.features:
                 x = layer(x)
                 layer_features.append(x.permute(0, 3, 1, 2))
-            return [layer_features[idx] for idx in self.output_layers]
+            return FeatureMaps([layer_features[idx] for idx in self.output_layers])
 
         else:
-            return self.model(images)
+            return FeatureVector(self.model(images))
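Note: a hedged sketch of the two output modes the updated Swin docstring describes. With output_layers set, the model acts as a backbone and forward() returns a FeatureMaps; without it, forward() returns a FeatureVector of class logits. The import path is assumed from the file location, and pretrained=False is used only to avoid a weight download.

    from rslearn.models.swin import Swin

    # Backbone mode with the recommended layers; get_backbone_channels() lists
    # the (downsample_factor, depth) pair for each selected layer.
    backbone = Swin(arch="swin_v2_b", pretrained=False, output_layers=[1, 3, 5, 7])
    print(backbone.get_backbone_channels())
    # expected: [(4, 128), (8, 256), (16, 512), (32, 1024)]

    # Classification mode: no output_layers, 10 output logits.
    classifier = Swin(arch="swin_v2_b", pretrained=False, num_outputs=10)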
rslearn/models/task_embedding.py ADDED

@@ -0,0 +1,250 @@
+"""Task embedding modules."""
+
+import math
+from typing import Any
+
+import torch
+from torch import nn
+
+
+class PositionalEncoding(nn.Module):
+    """Simple sinusoidal positional encoding for the task embedding. From torch docs."""
+
+    def __init__(self, d_model: int, dropout: float = 0.0, max_len: int = 1024):
+        """Initialize the positional encoding module.
+
+        Args:
+            d_model: The dimension of the model.
+            dropout: The dropout rate.
+            max_len: The maximum length of the sequence.
+        """
+        super().__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        position = torch.arange(max_len).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
+        )
+        pe = torch.zeros(max_len, 1, d_model)
+        pe[:, 0, 0::2] = torch.sin(position * div_term)
+        pe[:, 0, 1::2] = torch.cos(position * div_term)
+        self.register_buffer("pe", pe)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply positional encoding to the input tensor.
+
+        Args:
+            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
+        """
+        x = x + self.pe[: x.size(0)]
+        return self.dropout(x)
+
+
+class BaseTaskEmbedding(torch.nn.Module):
+    """Base class for task embedding modules."""
+
+    def __init__(self, encoder_embedding_size: int) -> None:
+        """Initialize the base task embedding module.
+
+        Args:
+            encoder_embedding_size: The size of the encoder embedding.
+        """
+        super().__init__()
+        self.encoder_embedding_size = encoder_embedding_size
+
+    def register_tasks(self, task_names: list[str]) -> None:
+        """Register the tasks.
+
+        This must happen post-init so that we can dynamically determine
+        the tasks to use, so it doesn't have to be specified in the config.
+
+        Args:
+            task_names: The names of the tasks.
+        """
+        raise NotImplementedError
+
+    def compute_embeds(
+        self,
+        features: list[torch.tensor],
+        inputs: list[dict[str, Any]],
+    ) -> torch.Tensor:
+        """Compute the task-specific embeddings.
+
+        Args:
+            features: The encoder features.
+            inputs: The inputs to the model.
+
+        Returns:
+            The task-specific embeddings.
+        """
+        raise NotImplementedError
+
+
+class TaskChannelEmbedding(BaseTaskEmbedding):
+    """Registers task-specific 'tokens', i.e. embeddings.
+
+    Each embedding is learned per-channel and copied over the full spatial dimensions.
+    Optionally, add a spatial sinusoidal positional embedding to the task embedding.
+    """
+
+    def __init__(
+        self,
+        encoder_embedding_size: int,
+        default_idx: int = 0,
+        add_spatial_embed: bool = False,
+    ) -> None:
+        """Initialize the task channel embedding module.
+
+        Args:
+            encoder_embedding_size: The size of the encoder embedding.
+            default_idx: The index of the default task, useful if loading a merged model.
+            add_spatial_embed: if true, add a spatial sinusoidal positional embedding to the task embedding
+        """
+        super().__init__(encoder_embedding_size)
+        self.default_idx = default_idx
+        self.add_spatial_embed = add_spatial_embed
+        if add_spatial_embed:
+            self.pos_embed = PositionalEncoding(encoder_embedding_size)
+
+    def register_tasks(self, task_names: list[str]) -> None:
+        """Register the tasks.
+
+        This must happen post-init so that we can dynamically determine
+        the tasks to use, so it doesn't have to be specified in the config.
+
+        Args:
+            task_names: The names of the tasks.
+        """
+        self.embed = torch.nn.Embedding(len(task_names), self.encoder_embedding_size)
+        self.target_to_embed_idx = {name: i for i, name in enumerate(task_names)}
+
+    def compute_embeds(
+        self,
+        features: list[torch.tensor],
+        inputs: list[dict[str, Any]],
+    ) -> torch.Tensor:
+        """Compute the task-specific embeddings.
+
+        Args:
+            inputs: The inputs to the model.
+            features: computed encoder features
+            device: The device to compute the embeddings on.
+
+        Returns:
+            The task-specific embeddings, shape (B, T, C), T = HW
+            The embeddings are repeated over the spatial dimensions, and optionally
+            a sinusoidal positional embedding is added.
+        """
+        try:
+            idx = [self.target_to_embed_idx[inp["dataset_source"]] for inp in inputs]
+        except KeyError:
+            idx = [self.default_idx] * len(inputs)
+        embeds = self.embed(torch.tensor(idx).to(features[0].device))
+        seq_len = features[0].shape[-1] * features[0].shape[-2]  # T = HW
+        embeds = embeds.unsqueeze(0).repeat(seq_len, 1, 1)  # T x B x C
+        if self.add_spatial_embed:
+            embeds = self.pos_embed(embeds)
+        embeds = torch.einsum("tbc->btc", embeds)  # B x T x C
+        return embeds
+
+    def forward(
+        self,
+        features: list[torch.tensor],
+        inputs: list[dict[str, Any]],
+        embeds: torch.Tensor | None = None,
+    ) -> list[torch.tensor]:
+        """Compute and apply task-specific embeddings to encoder features.
+
+        Optionally, add a spatial sinusoidal positional embedding to the task embedding.
+        Otherwise, the task embedding is repeated over the spatial dimensions.
+
+        Args:
+            features: The encoder features, a 1-list of B x C x H x W features.
+            inputs: The inputs to the model.
+            embeds: Already-computed task embeddings, if provided, skip the computation.
+
+        Returns:
+            The encoder features with the task-specific embeddings added.
+        """
+        height, width = features[0].shape[-2:]
+        assert all(f.shape[-2:] == (height, width) for f in features), (
+            "features must have the same spatial dimensions"
+        )
+        if embeds is None:
+            embeds = self.compute_embeds(features, inputs)  # B x HW x C
+        embeds = embeds.unflatten(dim=1, sizes=(height, width))  # B x H x W x C
+        for i in range(len(features)):
+            features[i] += torch.einsum("bhwc->bchw", embeds)  # B x C x H x W
+        return features
+
+
+class TaskMHAEmbedding(TaskChannelEmbedding):
+    """Multi-headed cross-attention over the spatial dimensions.
+
+    The task embedding is the query and the features are the key and value.
+    We copy the task embedding over the spatial dimensions, and optionally
+    add a sinusoidal positional embedding before the MHA layer.
+    """
+
+    def __init__(
+        self,
+        encoder_embedding_size: int,
+        num_heads: int,
+        default_idx: int = 0,
+        add_spatial_embed: bool = True,
+    ) -> None:
+        """Initialize the task MHA embedding module.
+
+        Args:
+            encoder_embedding_size: The size of the encoder embedding.
+            num_heads: The number of attention heads.
+            default_idx: The index of the default task, useful if loading a merged model.
+            add_spatial_embed: if true, add a spatial sinusoidal positional embedding to the task embedding
+        """
+        super().__init__(encoder_embedding_size, default_idx, add_spatial_embed)
+        self.mha = torch.nn.MultiheadAttention(
+            encoder_embedding_size, num_heads, batch_first=True
+        )
+
+    def register_tasks(self, task_names: list[str]) -> None:
+        """Register the tasks.
+
+        This must happen post-init so that we can dynamically determine
+        the tasks to use, so it doesn't have to be specified in the config.
+
+        Args:
+            task_names: The names of the tasks.
+        """
+        super().register_tasks(task_names)
+
+    def forward(
+        self,
+        features: list[torch.tensor],
+        inputs: list[dict[str, Any]],
+        embeds: torch.Tensor | None = None,
+    ) -> list[torch.tensor]:
+        """Compute and apply task-specific embeddings to encoder features.
+
+        Also apply the MHA layer across the spatial dimension, with the task embedding
+        as the query and the features as the key and value.
+
+        Args:
+            features: The encoder features, a 1-list of B x C x H x W features.
+            inputs: The inputs to the model.
+            embeds: Already-computed task embeddings, if provided, skip the computation.
+
+        Returns:
+            The encoder features with the task-specific embeddings added.
+        """
+        assert len(features) == 1, "TaskMHAEmbedding only supports one feature"
+        x = torch.flatten(features[0], start_dim=2)  # B x C x T, T = HW
+        if embeds is None:
+            embeds = self.compute_embeds(features, inputs)  # B x T x C
+        out = self.mha(
+            embeds,  # B x T x C
+            torch.einsum("bct->btc", x),
+            torch.einsum("bct->btc", x),
+        )[0]  # B x T x C
+        out = torch.einsum("btc->bct", out)
+        out = out.view(*features[0].shape)  # B x C x H x W
+        return [out]
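Note: a hedged usage sketch for the new TaskChannelEmbedding. Tasks are registered after construction, and forward() adds a learned per-task embedding (broadcast over H x W) to each B x C x H x W feature map. The task names, batch contents, and shapes below are illustrative only; the import path is assumed from the file location.

    import torch

    from rslearn.models.task_embedding import TaskChannelEmbedding

    emb = TaskChannelEmbedding(encoder_embedding_size=64, add_spatial_embed=True)
    emb.register_tasks(["task_a", "task_b"])  # hypothetical task names

    # A 1-list of B x C x H x W encoder features and the matching per-example inputs.
    features = [torch.zeros(2, 64, 8, 8)]
    inputs = [{"dataset_source": "task_a"}, {"dataset_source": "task_b"}]

    out = emb(features, inputs)
    print(out[0].shape)  # torch.Size([2, 64, 8, 8]); embeddings added per channel

If any input lacks the "dataset_source" key (or names an unregistered task), compute_embeds falls back to default_idx for the whole batch.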