rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +2 -9
- rslearn/config/__init__.py +2 -0
- rslearn/config/dataset.py +64 -20
- rslearn/dataset/add_windows.py +1 -1
- rslearn/dataset/dataset.py +34 -84
- rslearn/dataset/materialize.py +5 -5
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +26 -80
- rslearn/lightning_cli.py +22 -11
- rslearn/main.py +12 -37
- rslearn/models/anysat.py +11 -9
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +8 -9
- rslearn/models/clip.py +18 -15
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +21 -11
- rslearn/models/conv.py +15 -8
- rslearn/models/croma.py +13 -8
- rslearn/models/detr/detr.py +25 -14
- rslearn/models/dinov3.py +11 -6
- rslearn/models/faster_rcnn.py +19 -9
- rslearn/models/feature_center_crop.py +12 -9
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/galileo.py +23 -18
- rslearn/models/module_wrapper.py +26 -57
- rslearn/models/molmo.py +16 -14
- rslearn/models/multitask.py +102 -73
- rslearn/models/olmoearth_pretrain/model.py +135 -38
- rslearn/models/panopticon.py +8 -7
- rslearn/models/pick_features.py +18 -24
- rslearn/models/pooling_decoder.py +22 -14
- rslearn/models/presto/presto.py +16 -10
- rslearn/models/presto/single_file_presto.py +4 -10
- rslearn/models/prithvi.py +12 -8
- rslearn/models/resize_features.py +21 -7
- rslearn/models/sam2_enc.py +11 -9
- rslearn/models/satlaspretrain.py +15 -9
- rslearn/models/simple_time_series.py +37 -17
- rslearn/models/singletask.py +24 -17
- rslearn/models/ssl4eo_s12.py +15 -10
- rslearn/models/swin.py +22 -13
- rslearn/models/terramind.py +24 -7
- rslearn/models/trunk.py +6 -3
- rslearn/models/unet.py +18 -9
- rslearn/models/upsample.py +22 -9
- rslearn/train/all_patches_dataset.py +89 -37
- rslearn/train/dataset.py +105 -97
- rslearn/train/lightning_module.py +51 -32
- rslearn/train/model_context.py +54 -0
- rslearn/train/prediction_writer.py +111 -41
- rslearn/train/scheduler.py +15 -0
- rslearn/train/tasks/classification.py +34 -15
- rslearn/train/tasks/detection.py +24 -31
- rslearn/train/tasks/embedding.py +33 -29
- rslearn/train/tasks/multi_task.py +7 -7
- rslearn/train/tasks/per_pixel_regression.py +41 -19
- rslearn/train/tasks/regression.py +38 -21
- rslearn/train/tasks/segmentation.py +33 -15
- rslearn/train/tasks/task.py +3 -2
- rslearn/train/transforms/resize.py +74 -0
- rslearn/utils/geometry.py +73 -0
- rslearn/utils/jsonargparse.py +66 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
- rslearn/dataset/index.py +0 -173
- rslearn/models/registry.py +0 -22
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
- {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
rslearn/models/attention_pooling.py
ADDED
@@ -0,0 +1,177 @@
+"""An attention pooling layer."""
+
+import math
+from typing import Any
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+
+from rslearn.models.component import (
+    FeatureMaps,
+    IntermediateComponent,
+    TokenFeatureMaps,
+)
+from rslearn.train.model_context import ModelContext
+
+
+class SimpleAttentionPool(IntermediateComponent):
+    """Simple Attention Pooling.
+
+    Given a token feature map of shape BCHWN,
+    learn an attention layer which aggregates over
+    the N dimension.
+
+    This is done simply by learning a mapping D->1 which is the weight
+    which should be assigned to each token during averaging:
+
+    output = sum [feat_token * W(feat_token) for feat_token in feat_tokens]
+    """
+
+    def __init__(self, in_dim: int, hidden_linear: bool = False) -> None:
+        """Initialize the simple attention pooling layer.
+
+        Args:
+            in_dim: the encoding dimension D
+            hidden_linear: whether to apply an additional linear transformation D -> D
+                to the feat tokens. If this is True, a ReLU activation is applied
+                after the first linear transformation.
+        """
+        super().__init__()
+        if hidden_linear:
+            self.hidden_linear = nn.Linear(in_features=in_dim, out_features=in_dim)
+        else:
+            self.hidden_linear = None
+        self.linear = nn.Linear(in_features=in_dim, out_features=1)
+
+    def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
+        """Attention pooling for a single feature map (BCHWN tensor)."""
+        B, D, H, W, N = feat_tokens.shape
+        feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
+        if self.hidden_linear is not None:
+            feat_tokens = torch.nn.functional.relu(self.hidden_linear(feat_tokens))
+        attention_scores = torch.nn.functional.softmax(self.linear(feat_tokens), dim=1)
+        feat_tokens = (attention_scores * feat_tokens).sum(dim=1)
+        return rearrange(feat_tokens, "(b h w) d -> b d h w", b=B, h=H, w=W)
+
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Forward pass for attention pooling linear probe.
+
+        Args:
+            intermediates: the output from the previous component, which must be a TokenFeatureMaps.
+                We pool over the final dimension in the TokenFeatureMaps. If multiple maps
+                are passed, we apply the same linear layers to all of them.
+            context: the model context.
+            feat_tokens (torch.Tensor): Input feature tokens of shape (B, C, H, W, N).
+
+        Returns:
+            torch.Tensor:
+                - output, attentioned pool over the last dimension (B, C, H, W)
+        """
+        if not isinstance(intermediates, TokenFeatureMaps):
+            raise ValueError("input to Attention Pool must be a TokenFeatureMaps")
+
+        features = []
+        for feat in intermediates.feature_maps:
+            features.append(self.forward_for_map(feat))
+        return FeatureMaps(features)
+
+
+class AttentionPool(IntermediateComponent):
+    """Attention Pooling.
+
+    Given a feature map of shape BCHWN,
+    learn an attention layer which aggregates over
+    the N dimension.
+
+    We do this by learning a query token, and applying a standard
+    attention mechanism against this learned query token.
+    """
+
+    def __init__(self, in_dim: int, num_heads: int, linear_on_kv: bool = True) -> None:
+        """Initialize the attention pooling layer.
+
+        Args:
+            in_dim: the encoding dimension D
+            num_heads: the number of heads to use
+            linear_on_kv: Whether to apply a linear layer on the input tokens
+                to create the key and value tokens.
+        """
+        super().__init__()
+        self.query_token: nn.Parameter = nn.Parameter(torch.empty(in_dim))
+        if linear_on_kv:
+            self.k_linear = nn.Linear(in_dim, in_dim)
+            self.v_linear = nn.Linear(in_dim, in_dim)
+        else:
+            self.k_linear = None
+            self.v_linear = None
+        if in_dim % num_heads != 0:
+            raise ValueError(
+                f"in_dim must be divisible by num_heads. Got {in_dim} and {num_heads}."
+            )
+        self.num_heads = num_heads
+        self.init_weights()
+
+    def init_weights(self) -> None:
+        """Initialize weights for the probe."""
+        nn.init.trunc_normal_(self.query_token, std=0.02)
+
+    def forward_for_map(self, feat_tokens: torch.Tensor) -> torch.Tensor:
+        """Attention pooling for a single feature map (BCHWN tensor)."""
+        B, D, H, W, N = feat_tokens.shape
+        feat_tokens = rearrange(feat_tokens, "b d h w n -> (b h w) n d")
+        collapsed_dim = B * H * W
+        q = self.query_token.expand(collapsed_dim, 1, -1)
+        q = q.reshape(
+            collapsed_dim, 1, self.num_heads, D // self.num_heads
+        )  # [B, 1, head, D_head]
+        q = rearrange(q, "b h n d -> b n h d")
+        if self.k_linear is not None:
+            assert self.v_linear is not None
+            k = self.k_linear(feat_tokens).reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+            v = self.v_linear(feat_tokens).reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+        else:
+            k = feat_tokens.reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+            v = feat_tokens.reshape(
+                collapsed_dim, N, self.num_heads, D // self.num_heads
+            )
+        k = rearrange(k, "b n h d -> b h n d")
+        v = rearrange(v, "b n h d -> b h n d")
+
+        # Compute attention scores
+        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
+            D // self.num_heads
+        )
+        attn_weights = F.softmax(attn_scores, dim=-1)
+        x = torch.matmul(attn_weights, v)  # [B, head, 1, D_head]
+        return x.reshape(B, D, H, W)
+
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Forward pass for attention pooling linear probe.
+
+        Args:
+            intermediates: the output from the previous component, which must be a TokenFeatureMaps.
+                We pool over the final dimension in the TokenFeatureMaps. If multiple feature
+                maps are passed, we apply the same attention weights (query token and linear k, v layers)
+                to all the maps.
+            context: the model context.
+            feat_tokens (torch.Tensor): Input feature tokens of shape (B, C, H, W, N).
+
+        Returns:
+            torch.Tensor:
+                - output, attentioned pool over the last dimension (B, C, H, W)
+        """
+        if not isinstance(intermediates, TokenFeatureMaps):
+            raise ValueError("input to Attention Pool must be a TokenFeatureMaps")
+
+        features = []
+        for feat in intermediates.feature_maps:
+            features.append(self.forward_for_map(feat))
+        return FeatureMaps(features)

rslearn/models/clay/clay.py
CHANGED
@@ -16,6 +16,8 @@ from huggingface_hub import hf_hub_download
 # from claymodel.module import ClayMAEModule
 from terratorch.models.backbones.clay_v15.module import ClayMAEModule
 
+from rslearn.models.component import FeatureExtractor, FeatureMaps
+from rslearn.train.model_context import ModelContext
 from rslearn.train.transforms.normalize import Normalize
 from rslearn.train.transforms.transform import Transform
 
@@ -42,7 +44,7 @@ def get_clay_checkpoint_path(
     return hf_hub_download(repo_id=repo_id, filename=filename)  # nosec B615
 
 
-class Clay(torch.nn.Module):
+class Clay(FeatureExtractor):
     """Clay backbones."""
 
     def __init__(
@@ -108,23 +110,20 @@ class Clay(torch.nn.Module):
             image, size=(new_hw, new_hw), mode="bilinear", align_corners=False
        )
 
-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Forward pass for the Clay model.
 
         Args:
-
+            context: the model context. Input dicts must include `self.modality` as a key
 
         Returns:
-
+            a FeatureMaps consisting of one feature map, computed by Clay.
         """
-        if self.modality not in inputs[0]:
-            raise ValueError(f"Missing modality {self.modality} in inputs.")
-
         param = next(self.model.parameters())
         device = param.device
 
         chips = torch.stack(
-            [inp[self.modality] for inp in inputs], dim=0
+            [inp[self.modality] for inp in context.inputs], dim=0
         )  # (B, C, H, W)
         if self.do_resizing:
             chips = self._resize_image(chips, chips.shape[2])
@@ -163,7 +162,7 @@ class Clay(torch.nn.Module):
         )
 
         features = rearrange(spatial, "b (h w) d -> b d h w", h=side, w=side)
-        return [features]
+        return FeatureMaps([features])
 
     def get_backbone_channels(self) -> list:
         """Return output channels of this model when used as a backbone."""

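The clay.py changes above follow the migration pattern that recurs throughout this release: forward() now takes a ModelContext instead of a list of input dicts, reads the per-sample dicts from context.inputs, and wraps the returned list of feature maps in the FeatureMaps dataclass. A rough sketch of that pattern for a hypothetical single-modality backbone (the class and its conv layer are illustrative, not part of rslearn):

import torch

from rslearn.models.component import FeatureExtractor, FeatureMaps
from rslearn.train.model_context import ModelContext


class ToyBackbone(FeatureExtractor):
    """Hypothetical backbone written against the 0.0.19 component API."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)

    def forward(self, context: ModelContext) -> FeatureMaps:
        # Stack per-sample "image" tensors from the input dicts into a BxCxHxW batch.
        images = torch.stack([inp["image"] for inp in context.inputs], dim=0)
        # Wrap the single feature map in the FeatureMaps dataclass.
        return FeatureMaps([self.conv(images)])
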
rslearn/models/clip.py
CHANGED
@@ -1,12 +1,13 @@
 """OpenAI CLIP models."""
 
-from typing import Any
-
-import torch
 from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
 
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
+
 
-class CLIP(torch.nn.Module):
+class CLIP(FeatureExtractor):
     """CLIP image encoder."""
 
     def __init__(
@@ -31,17 +32,17 @@ class CLIP(torch.nn.Module):
         self.height = crop_size["height"] // stride[0]
         self.width = crop_size["width"] // stride[1]
 
-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute outputs from the backbone.
 
-
-
-                process. The images should have values 0-255.
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the image to process. The images should have values 0-255.
 
         Returns:
-
+            a FeatureMaps with one feature map from the ViT, which is always Bx24x24x1024.
         """
+        inputs = context.inputs
         device = inputs[0]["image"].device
         clip_inputs = self.processor(
             images=[inp["image"].cpu().numpy().transpose(1, 2, 0) for inp in inputs],
@@ -55,8 +56,10 @@ class CLIP(torch.nn.Module):
         batch_size = image_features.shape[0]
 
         # 576x1024 -> HxWxC
-        return
-
-
-
-
+        return FeatureMaps(
+            [
+                image_features.reshape(
+                    batch_size, self.height, self.width, self.num_features
+                ).permute(0, 3, 1, 2)
+            ]
+        )

@@ -0,0 +1,111 @@
|
|
|
1
|
+
"""Model component API."""
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from rslearn.train.model_context import ModelContext, ModelOutput
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class FeatureExtractor(torch.nn.Module, abc.ABC):
|
|
13
|
+
"""A feature extractor that performs initial processing of the inputs.
|
|
14
|
+
|
|
15
|
+
The FeatureExtractor is the first component in the encoders list for
|
|
16
|
+
SingleTaskModel and MultiTaskModel.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
@abc.abstractmethod
|
|
20
|
+
def forward(self, context: ModelContext) -> Any:
|
|
21
|
+
"""Extract an initial intermediate from the model context.
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
context: the model context.
|
|
25
|
+
|
|
26
|
+
Returns:
|
|
27
|
+
any intermediate to pass to downstream components. Oftentimes this is a
|
|
28
|
+
FeatureMaps.
|
|
29
|
+
"""
|
|
30
|
+
raise NotImplementedError
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class IntermediateComponent(torch.nn.Module, abc.ABC):
|
|
34
|
+
"""An intermediate component in the model.
|
|
35
|
+
|
|
36
|
+
In SingleTaskModel and MultiTaskModel, modules after the first module
|
|
37
|
+
in the encoders list are IntermediateComponents, as are modules before the last
|
|
38
|
+
module in the decoders list(s).
|
|
39
|
+
"""
|
|
40
|
+
|
|
41
|
+
@abc.abstractmethod
|
|
42
|
+
def forward(self, intermediates: Any, context: ModelContext) -> Any:
|
|
43
|
+
"""Process the given intermediate into another intermediate.
|
|
44
|
+
|
|
45
|
+
Args:
|
|
46
|
+
intermediates: the output from the previous component (either a
|
|
47
|
+
FeatureExtractor or another IntermediateComponent).
|
|
48
|
+
context: the model context.
|
|
49
|
+
|
|
50
|
+
Returns:
|
|
51
|
+
any intermediate to pass to downstream components.
|
|
52
|
+
"""
|
|
53
|
+
raise NotImplementedError
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class Predictor(torch.nn.Module, abc.ABC):
|
|
57
|
+
"""A predictor that computes task-specific outputs and a loss dict.
|
|
58
|
+
|
|
59
|
+
In SingleTaskModel and MultiTaskModel, the last module(s) in the decoders list(s)
|
|
60
|
+
are Predictors.
|
|
61
|
+
"""
|
|
62
|
+
|
|
63
|
+
@abc.abstractmethod
|
|
64
|
+
def forward(
|
|
65
|
+
self,
|
|
66
|
+
intermediates: Any,
|
|
67
|
+
context: ModelContext,
|
|
68
|
+
targets: list[dict[str, torch.Tensor]] | None = None,
|
|
69
|
+
) -> ModelOutput:
|
|
70
|
+
"""Compute task-specific outputs and loss dict.
|
|
71
|
+
|
|
72
|
+
Args:
|
|
73
|
+
intermediates: the output from the previous component.
|
|
74
|
+
context: the model context.
|
|
75
|
+
targets: the training targets, or None during prediction.
|
|
76
|
+
|
|
77
|
+
Returns:
|
|
78
|
+
a tuple of the task-specific outputs (which should be compatible with the
|
|
79
|
+
configured Task) and loss dict. The loss dict maps from a name for each
|
|
80
|
+
loss to a scalar tensor.
|
|
81
|
+
"""
|
|
82
|
+
raise NotImplementedError
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclass
|
|
86
|
+
class FeatureMaps:
|
|
87
|
+
"""An intermediate output type for multi-resolution feature maps."""
|
|
88
|
+
|
|
89
|
+
# List of BxCxHxW feature maps at different scales, ordered from highest resolution
|
|
90
|
+
# (most fine-grained) to lowest resolution (coarsest).
|
|
91
|
+
feature_maps: list[torch.Tensor]
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@dataclass
|
|
95
|
+
class TokenFeatureMaps:
|
|
96
|
+
"""An intermediate output type for multi-resolution BCHWN feature maps with a token dimension.
|
|
97
|
+
|
|
98
|
+
Unlike `FeatureMaps`, these include an additional dimension for unpooled tokens.
|
|
99
|
+
"""
|
|
100
|
+
|
|
101
|
+
# List of BxCxHxWxN feature maps at different scales, ordered from highest resolution
|
|
102
|
+
# (most fine-grained) to lowest resolution (coarsest).
|
|
103
|
+
feature_maps: list[torch.Tensor]
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@dataclass
|
|
107
|
+
class FeatureVector:
|
|
108
|
+
"""An intermediate output type for a flat feature vector."""
|
|
109
|
+
|
|
110
|
+
# Flat BxC feature vector.
|
|
111
|
+
feature_vector: torch.Tensor
|
|
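component.py defines the three roles that the refactored modules in this diff implement: a FeatureExtractor consumes the ModelContext, IntermediateComponents transform intermediates such as FeatureMaps, and a Predictor produces the final ModelOutput. A minimal sketch of a custom IntermediateComponent under that contract (the normalization layer is illustrative, not part of rslearn):

from typing import Any

import torch

from rslearn.models.component import FeatureMaps, IntermediateComponent
from rslearn.train.model_context import ModelContext


class NormalizeFeatures(IntermediateComponent):
    """Hypothetical component that L2-normalizes each feature map over channels."""

    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
        if not isinstance(intermediates, FeatureMaps):
            raise ValueError("input to NormalizeFeatures must be a FeatureMaps")
        # Apply the same normalization at every scale, preserving the BxCxHxW layout.
        normalized = [
            torch.nn.functional.normalize(feat, dim=1)
            for feat in intermediates.feature_maps
        ]
        return FeatureMaps(normalized)
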
rslearn/models/concatenate_features.py
CHANGED
@@ -4,8 +4,12 @@ from typing import Any
 
 import torch
 
+from rslearn.train.model_context import ModelContext
 
-class ConcatenateFeatures(torch.nn.Module):
+from .component import FeatureMaps, IntermediateComponent
+
+
+class ConcatenateFeatures(IntermediateComponent):
     """Concatenate feature map with additional raw data inputs."""
 
     def __init__(
@@ -55,26 +59,32 @@ class ConcatenateFeatures(torch.nn.Module):
 
         self.conv_layers = torch.nn.Sequential(*conv_layers)
 
-    def forward(
-        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
-    ) -> list[torch.Tensor]:
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
         """Concatenate the feature map with the raw data inputs.
 
         Args:
-
-
+            intermediates: the previous output, which must be a FeatureMaps.
+            context: the model context. The input dicts must have a key matching the
+                configured key.
 
         Returns:
            concatenated feature maps.
         """
-        if
-
+        if (
+            not isinstance(intermediates, FeatureMaps)
+            or len(intermediates.feature_maps) == 0
+        ):
+            raise ValueError(
+                "Expected input to be FeatureMaps with at least one feature map"
+            )
 
-        add_data = torch.stack(
+        add_data = torch.stack(
+            [input_data[self.key] for input_data in context.inputs], dim=0
+        )
         add_features = self.conv_layers(add_data)
 
         new_features: list[torch.Tensor] = []
-        for feature_map in
+        for feature_map in intermediates.feature_maps:
            # Shape of feature map: BCHW
            feat_h, feat_w = feature_map.shape[2], feature_map.shape[3]
 
@@ -90,4 +100,4 @@ class ConcatenateFeatures(torch.nn.Module):
 
            new_features.append(torch.cat([feature_map, resized_add_features], dim=1))
 
-        return new_features
+        return FeatureMaps(new_features)

rslearn/models/conv.py
CHANGED
@@ -4,8 +4,12 @@ from typing import Any
 
 import torch
 
+from rslearn.train.model_context import ModelContext
 
-class Conv(torch.nn.Module):
+from .component import FeatureMaps, IntermediateComponent
+
+
+class Conv(IntermediateComponent):
     """A single convolutional layer.
 
     It inputs a set of feature maps; the conv layer is applied to each feature map
@@ -38,19 +42,22 @@ class Conv(torch.nn.Module):
        )
        self.activation = activation
 
-    def forward(self,
-        """
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Apply conv layer on each feature map.
 
         Args:
-
-
+            intermediates: the previous output, which must be a FeatureMaps.
+            context: the model context.
 
         Returns:
-
+            the resulting feature maps after applying the same Conv2d on each one.
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to Conv must be FeatureMaps")
+
         new_features = []
-        for feat_map in
+        for feat_map in intermediates.feature_maps:
            feat_map = self.layer(feat_map)
            feat_map = self.activation(feat_map)
            new_features.append(feat_map)
-        return new_features
+        return FeatureMaps(new_features)

rslearn/models/croma.py
CHANGED
@@ -12,9 +12,11 @@ from einops import rearrange
 from upath import UPath
 
 from rslearn.log_utils import get_logger
+from rslearn.train.model_context import ModelContext
 from rslearn.train.transforms.transform import Transform
 from rslearn.utils.fsspec import open_atomic
 
+from .component import FeatureExtractor, FeatureMaps
 from .use_croma import PretrainedCROMA
 
 logger = get_logger(__name__)
@@ -76,7 +78,7 @@ MODALITY_BANDS = {
 }
 
 
-class Croma(torch.nn.Module):
+class Croma(FeatureExtractor):
     """CROMA backbones.
 
     There are two model sizes, base and large.
@@ -160,20 +162,23 @@ class Croma(torch.nn.Module):
            align_corners=False,
        )
 
-    def forward(self,
+    def forward(self, context: ModelContext) -> FeatureMaps:
         """Compute feature maps from the Croma backbone.
 
-
-
-                "sentinel1" keys depending on the configured modality.
+        Args:
+            context: the model context. Input dicts must include either/both of
+                "sentinel2" or "sentinel1" keys depending on the configured modality.
+
+        Returns:
+            a FeatureMaps with one feature map at 1/8 the input resolution.
         """
         sentinel1: torch.Tensor | None = None
         sentinel2: torch.Tensor | None = None
         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL1]:
-            sentinel1 = torch.stack([inp["sentinel1"] for inp in inputs], dim=0)
+            sentinel1 = torch.stack([inp["sentinel1"] for inp in context.inputs], dim=0)
            sentinel1 = self._resize_image(sentinel1) if self.do_resizing else sentinel1
         if self.modality in [CromaModality.BOTH, CromaModality.SENTINEL2]:
-            sentinel2 = torch.stack([inp["sentinel2"] for inp in inputs], dim=0)
+            sentinel2 = torch.stack([inp["sentinel2"] for inp in context.inputs], dim=0)
            sentinel2 = self._resize_image(sentinel2) if self.do_resizing else sentinel2
 
         outputs = self.model(
@@ -200,7 +205,7 @@ class Croma(torch.nn.Module):
            w=num_patches_per_dim,
        )
 
-        return [features]
+        return FeatureMaps([features])
 
     def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.

rslearn/models/detr/detr.py
CHANGED
@@ -13,6 +13,8 @@ import torch.nn.functional as F
 from torch import nn
 
 import rslearn.models.detr.box_ops as box_ops
+from rslearn.models.component import FeatureMaps, Predictor
+from rslearn.train.model_context import ModelContext, ModelOutput
 
 from .matcher import HungarianMatcher
 from .position_encoding import PositionEmbeddingSine
@@ -405,7 +407,7 @@ class PostProcess(nn.Module):
         return results
 
 
-class Detr(nn.Module):
+class Detr(Predictor):
     """DETR prediction module.
 
     This combines PositionEmbeddingSine, DetrPredictor, SetCriterion, and PostProcess.
@@ -440,33 +442,39 @@ class Detr(nn.Module):
 
     def forward(
         self,
-
-
+        intermediates: Any,
+        context: ModelContext,
         targets: list[dict[str, Any]] | None = None,
-    ) ->
+    ) -> ModelOutput:
         """Compute the detection outputs and loss from features.
 
         DETR will use only the last feature map, which should correspond to the lowest
         resolution one.
 
         Args:
-
-
-
+            intermediates: the output from the previous component. It must be a FeatureMaps.
+            context: the model context. Input dicts must contain an "image" key which we will
+                be used to establish the original image size.
+            targets: must contain class key that stores the class label.
 
         Returns:
-
+            the model output.
         """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to Detr must be a FeatureMaps")
+
+        # We only use the last feature map (most fine-grained).
+        features = intermediates.feature_maps[-1]
+
         # Get image sizes.
         image_sizes = torch.tensor(
-            [[inp["image"].shape[2], inp["image"].shape[1]] for inp in inputs],
+            [[inp["image"].shape[2], inp["image"].shape[1]] for inp in context.inputs],
            dtype=torch.int32,
-            device=features
+            device=features.device,
        )
 
-
-
-        outputs = self.predictor(feat_map, pos_embedding)
+        pos_embedding = self.pos_embedding(features)
+        outputs = self.predictor(features, pos_embedding)
 
         if targets is not None:
            # Convert boxes from [x0, y0, x1, y1] to [cx, cy, w, h].
@@ -490,4 +498,7 @@ class Detr(nn.Module):
 
         results = self.postprocess(outputs, image_sizes)
 
-        return
+        return ModelOutput(
+            outputs=results,
+            loss_dict=losses,
+        )
