rslearn 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- rslearn/arg_parser.py +1 -22
- rslearn/data_sources/copernicus.py +6 -4
- rslearn/data_sources/eurocrops.py +246 -0
- rslearn/data_sources/local_files.py +11 -0
- rslearn/data_sources/openstreetmap.py +2 -4
- rslearn/dataset/dataset.py +4 -1
- rslearn/models/copernicusfm.py +216 -0
- rslearn/models/copernicusfm_src/__init__.py +1 -0
- rslearn/models/copernicusfm_src/aurora/area.py +50 -0
- rslearn/models/copernicusfm_src/aurora/fourier.py +134 -0
- rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +523 -0
- rslearn/models/copernicusfm_src/flexivit/patch_embed.py +260 -0
- rslearn/models/copernicusfm_src/flexivit/utils.py +69 -0
- rslearn/models/copernicusfm_src/model_vit.py +348 -0
- rslearn/models/copernicusfm_src/util/pos_embed.py +216 -0
- rslearn/models/panopticon.py +167 -0
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +247 -0
- rslearn/models/presto/single_file_presto.py +932 -0
- rslearn/models/unet.py +15 -0
- rslearn/template_params.py +26 -0
- {rslearn-0.0.4.dist-info → rslearn-0.0.6.dist-info}/METADATA +4 -1
- {rslearn-0.0.4.dist-info → rslearn-0.0.6.dist-info}/RECORD +27 -12
- {rslearn-0.0.4.dist-info → rslearn-0.0.6.dist-info}/WHEEL +0 -0
- {rslearn-0.0.4.dist-info → rslearn-0.0.6.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.4.dist-info → rslearn-0.0.6.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.4.dist-info → rslearn-0.0.6.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ rslearn/models/copernicusfm_src/aurora/fourier.py
@@ -0,0 +1,134 @@
+# type: ignore
+"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""
+
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from .area import area, radius_earth
+
+__all__ = [
+    "FourierExpansion",
+    "pos_expansion",
+    "scale_expansion",
+    "lead_time_expansion",
+    "levels_expansion",
+    "absolute_time_expansion",
+]
+
+
+class FourierExpansion(nn.Module):
+    """A Fourier series-style expansion into a high-dimensional space.
+
+    Attributes:
+        lower (float): Lower wavelength.
+        upper (float): Upper wavelength.
+        assert_range (bool): Assert that the encoded tensor is within the specified wavelength
+            range.
+    """
+
+    def __init__(self, lower: float, upper: float, assert_range: bool = True) -> None:
+        """Initialise.
+
+        Args:
+            lower (float): Lower wavelength.
+            upper (float): Upper wavelength.
+            assert_range (bool, optional): Assert that the encoded tensor is within the specified
+                wavelength range. Defaults to `True`.
+        """
+        super().__init__()
+        self.lower = lower
+        self.upper = upper
+        self.assert_range = assert_range
+
+    def forward(self, x: torch.Tensor, d: int) -> torch.Tensor:
+        """Perform the expansion.
+
+        Adds a dimension of length `d` to the end of the shape of `x`.
+
+        Args:
+            x (:class:`torch.Tensor`): Input to expand of shape `(..., n)`. All elements of `x` must
+                lie within `[self.lower, self.upper]` if `self.assert_range` is `True`.
+            d (int): Dimensionality. Must be a multiple of two.
+
+        Raises:
+            AssertionError: If `self.assert_range` is `True` and not all elements of `x` are
+                within `[self.lower, self.upper]`.
+            ValueError: If `d` is not a multiple of two.
+
+        Returns:
+            torch.Tensor: Fourier series-style expansion of `x` of shape `(..., n, d)`.
+        """
+        # If the input is not within the configured range, the embedding might be ambiguous!
+        in_range = torch.logical_and(
+            self.lower <= x.abs(), torch.all(x.abs() <= self.upper)
+        )
+        in_range_or_zero = torch.all(
+            torch.logical_or(in_range, x == 0)
+        )  # Allow zeros to pass through.
+        if self.assert_range and not in_range_or_zero:
+            raise AssertionError(
+                f"The input tensor is not within the configured range"
+                f" `[{self.lower}, {self.upper}]`."
+            )
+
+        # We will use half of the dimensionality for `sin` and the other half for `cos`.
+        if not (d % 2 == 0):
+            raise ValueError("The dimensionality must be a multiple of two.")
+
+        # Always perform the expansion with `float64`s to avoid numerical accuracy shenanigans.
+        x = x.double()
+
+        wavelengths = torch.logspace(
+            math.log10(self.lower),
+            math.log10(self.upper),
+            d // 2,
+            base=10,
+            device=x.device,
+            dtype=x.dtype,
+        )
+        prod = torch.einsum("...i,j->...ij", x, 2 * np.pi / wavelengths)
+        encoding = torch.cat((torch.sin(prod), torch.cos(prod)), dim=-1)
+
+        return encoding.float()  # Cast to `float32` to avoid incompatibilities.
+
+
+# Determine a reasonable smallest value for the scale embedding by assuming a smallest delta in
+# latitudes and longitudes.
+_delta = 0.01  # Reasonable smallest delta in latitude and longitude
+_min_patch_area: float = area(
+    torch.tensor(
+        [
+            # The smallest patches will be at the poles. Just use the north pole.
+            [90, 0],
+            [90, _delta],
+            [90 - _delta, _delta],
+            [90 - _delta, 0],
+        ],
+        dtype=torch.float64,
+    )
+).item()
+_area_earth = 4 * np.pi * radius_earth * radius_earth
+
+pos_expansion = FourierExpansion(_delta, 720)
+
+
+scale_expansion = FourierExpansion(_min_patch_area, _area_earth)
+
+
+lead_time_expansion = FourierExpansion(1 / 60, 24 * 7 * 3)
+
+levels_expansion = FourierExpansion(0.01, 1e5)
+
+absolute_time_expansion = FourierExpansion(1, 24 * 365.25, assert_range=False)
+
+### new for SSL4EO-S ###
+# min wavelength: ultraviolet light (100 nm)
+# max wavelength: radio waves (1 m)
+spectrum_central_expansion = FourierExpansion(1e-7, 1)
+
+# min bandwidth: 10nm
+# max bandwidth: 1m
+spectrum_width_expansion = FourierExpansion(1e-7, 1)
--- /dev/null
+++ rslearn/models/copernicusfm_src/dynamic_hypernetwork.py
@@ -0,0 +1,523 @@
+# mypy: ignore-errors
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+
+# CopernicusFM: meta encoding (follow aurora)
+from .aurora.fourier import FourierExpansion
+
+# CopernicusFM: dynamic patch size (follow flexivit)
+from .flexivit.patch_embed import pi_resize_patch_embed
+from .util.pos_embed import get_1d_sincos_pos_embed_from_grid_torch
+
+
+class TransformerWeightGenerator(nn.Module):
+    def __init__(self, input_dim, output_dim, embed_dim, num_heads=4, num_layers=1):
+        super(TransformerWeightGenerator, self).__init__()
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=input_dim,
+            nhead=num_heads,
+            activation="gelu",
+            norm_first=False,
+            batch_first=False,
+            dropout=False,
+        )
+        self.transformer_encoder = nn.TransformerEncoder(
+            encoder_layer, num_layers=num_layers, enable_nested_tensor=False
+        )
+
+        # Linear layer to map transformer output to desired weight shape
+        self.fc_weight = nn.Linear(input_dim, output_dim)
+        self.fc_bias = nn.Linear(input_dim, embed_dim)
+        self.wt_num = 128
+        self.weight_tokens = nn.Parameter(torch.empty([self.wt_num, input_dim]))
+        self.bias_token = nn.Parameter(torch.empty([1, input_dim]))
+
+        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
+        torch.nn.init.normal_(self.weight_tokens, std=0.02)
+        torch.nn.init.normal_(self.bias_token, std=0.02)
+
+    def forward(self, x):
+        # x should have shape [seq_len, batch, input_dim]
+        pos_wave = x
+        x = torch.cat([self.weight_tokens, pos_wave], dim=0)
+        x = torch.cat([x, self.bias_token], dim=0)
+        transformer_output = self.transformer_encoder(x)
+        weights = self.fc_weight(transformer_output[self.wt_num : -1] + pos_wave)
+        bias = self.fc_bias(
+            transformer_output[-1]
+        )  # Using the last output to generate bias
+        return weights, bias
+
+
+class GaussianFourierFeatureTransform(torch.nn.Module):
+    """An implementation of Gaussian Fourier feature mapping.
+
+    "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains":
+    https://arxiv.org/abs/2006.10739
+    https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html
+
+    Given an input of size [batches, num_input_channels, width, height],
+    returns a tensor of size [batches, mapping_size*2, width, height].
+    """
+
+    def __init__(self, num_input_channels, mapping_size=256, scale=10):
+        super().__init__()
+
+        self._num_input_channels = num_input_channels
+        self._mapping_size = mapping_size
+        torch.manual_seed(42)
+        self._B = torch.randn((num_input_channels, mapping_size)) * scale
+
+    def forward(self, x):
+        assert x.dim() == 4, f"Expected 4D input (got {x.dim()}D input)"
+
+        batches, channels, width, height = x.shape
+
+        assert (
+            channels == self._num_input_channels
+        ), f"Expected input to have {self._num_input_channels} channels (got {channels} channels)"
+
+        # Make shape compatible for matmul with _B.
+        # From [B, C, W, H] to [(B*W*H), C].
+        x = x.permute(0, 2, 3, 1).reshape(batches * width * height, channels)
+
+        x = x @ self._B.to(x.device)
+
+        # From [(B*W*H), C] to [B, W, H, C]
+        x = x.view(batches, width, height, self._mapping_size)
+        # From [B, W, H, C] to [B, C, W, H]
+        x = x.permute(0, 3, 1, 2)
+
+        x = 2 * np.pi * x
+        return torch.cat([torch.sin(x), torch.cos(x)], dim=1)
+
+
+class Basic1d(nn.Module):
+    def __init__(self, in_channels, out_channels, bias=True):
+        super().__init__()
+        conv = nn.Linear(in_channels, out_channels, bias)
+        self.conv = nn.Sequential(
+            conv,
+        )
+        if not bias:
+            self.conv.add_module("ln", nn.LayerNorm(out_channels))
+        self.conv.add_module("relu", nn.ReLU(inplace=True))
+
+    def forward(self, x):
+        out = self.conv(x)
+        return out
+
+
+class FCResLayer(nn.Module):
+    def __init__(self, linear_size=128):
+        super(FCResLayer, self).__init__()
+        self.l_size = linear_size
+        self.nonlin1 = nn.ReLU(inplace=True)
+        self.nonlin2 = nn.ReLU(inplace=True)
+        # self.dropout1 = nn.Dropout()
+        self.w1 = nn.Linear(self.l_size, self.l_size)
+        self.w2 = nn.Linear(self.l_size, self.l_size)
+
+    def forward(self, x):
+        y = self.w1(x)
+        y = self.nonlin1(y)
+        # y = self.dropout1(y)
+        y = self.w2(y)
+        y = self.nonlin2(y)
+        out = x + y
+        return out
+
+
+class Dynamic_MLP_Decoder(nn.Module):
+    def __init__(self, wv_planes, inter_dim=128, kernel_size=16, decoder_embed=512):
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.wv_planes = wv_planes
+        self.inter_dim = inter_dim
+        self.decoder_embed = decoder_embed
+        self._num_kernel = self.kernel_size * self.kernel_size * self.decoder_embed
+
+        # self.weight_generator = nn.Sequential(Basic1d(wv_planes, self.inter_dim, bias=True),
+        #                                       nn.Linear(self.inter_dim, self._num_kernel))
+        self.weight_generator = TransformerWeightGenerator(
+            wv_planes, self._num_kernel, decoder_embed
+        )
+        self.scaler = 0.01
+
+        self._init_weights()
+
+    def _get_weights(self, waves, batch=True):
+        dweights = []
+        dynamic_weights = None
+        if batch:
+            dynamic_weights = self.weight_generator(waves)
+        else:
+            for i in range(waves.size(0)):
+                dweights.append(self.weight_generator(waves[i]))
+            dynamic_weights = torch.stack(dweights, dim=0)
+
+        return dynamic_weights
+
+    def weight_init(self, m):
+        if isinstance(m, nn.Linear):
+            init.xavier_uniform_(m.weight)
+            m.bias.data.fill_(0.01)
+
+    def _init_weights(self):
+        """Initialize the base weights and dynamic mlp weights"""
+        self.weight_generator.apply(self.weight_init)
+
+    def forward(self, img_feat, waves, kernel_size=None):
+        inplanes = waves.size(0)
+        # wv_feats: 9,128 -> 9*16*16,512
+        weight, bias = self._get_weights(waves)  # 9,16*16*512
+        # dynamic_weight = weight.view(
+        #     inplanes * self.kernel_size * self.kernel_size, self.decoder_embed
+        # )  # 9*16*16,512
+
+        # CopernicusFM: dynamic patch size
+        dynamic_weight = weight.view(
+            inplanes, self.kernel_size, self.kernel_size, self.decoder_embed
+        )
+        dynamic_weight = dynamic_weight.permute([3, 0, 1, 2])
+        # resize the weight to match different preferred kernel sizes
+        if kernel_size != None and self.kernel_size != kernel_size:
+            dynamic_weight = pi_resize_patch_embed(
+                dynamic_weight, (kernel_size, kernel_size)
+            )  # 512, 9, p, p
+        else:
+            kernel_size = self.kernel_size
+        dynamic_weight = (
+            dynamic_weight.permute([1, 2, 3, 0])
+            .contiguous()
+            .view(-1, self.decoder_embed)
+        )  # 9*p*p,512
+
+        weights = dynamic_weight * self.scaler
+
+        dynamic_out = F.linear(img_feat, weights, bias=None)
+        x = dynamic_out
+        return x
+
+
+class Dynamic_Patch_Embed(nn.Module):
+    """Input: channels of wavelength (normalized): List -> List
+    kernel size of the depth-wise convolution: kernel_size, default 3x3
+    wv_planes
+    inplanes
+    """
+
+    def __init__(self, wv_planes, inter_dim=128, kernel_size=3, embed_dim=1024):
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.wv_planes = wv_planes
+        self.embed_dim = embed_dim
+        self.kernel_size = kernel_size
+        self.patch_size = (kernel_size, kernel_size)
+        self.weight2 = nn.Parameter(
+            torch.empty([embed_dim, 2, kernel_size, kernel_size])
+        )
+        self.bias2 = nn.Parameter(torch.empty([embed_dim]))
+        self.weight3 = nn.Parameter(
+            torch.empty([embed_dim, 3, kernel_size, kernel_size])
+        )
+        self.bias3 = nn.Parameter(torch.empty([embed_dim]))
+        self.weight4 = nn.Parameter(
+            torch.empty([embed_dim, 4, kernel_size, kernel_size])
+        )
+        self.bias4 = nn.Parameter(torch.empty([embed_dim]))
+        self.weight9 = nn.Parameter(
+            torch.empty([embed_dim, 9, kernel_size, kernel_size])
+        )
+        self.bias9 = nn.Parameter(torch.empty([embed_dim]))
+        self.weight70 = nn.Parameter(
+            torch.empty([embed_dim, 70, kernel_size, kernel_size])
+        )
+        self.bias70 = nn.Parameter(torch.empty([embed_dim]))
+        self.weights = {
+            2: self.weight2,
+            3: self.weight3,
+            4: self.weight4,
+            9: self.weight9,
+            70: self.weight70,
+        }
+        self.biass = {
+            2: self.bias2,
+            3: self.bias3,
+            4: self.bias4,
+            9: self.bias9,
+            70: self.bias70,
+        }
+
+    def forward(self, img_feat, waves):
+        inplanes = waves.size(0)
+        # wv_feats: 9,128 -> 9, 3x3x3
+        weights = self.weights[inplanes]
+        bias = self.biass[inplanes]
+
+        dynamic_out = F.conv2d(
+            img_feat, weights, bias=bias, stride=self.kernel_size, padding=1, dilation=1
+        )
+
+        x = dynamic_out
+        x = x.flatten(2).transpose(1, 2)
+
+        return x
+
+
+class Dynamic_MLP_OFA(nn.Module):
+    """Input: channels of wavelength (normalized): List -> List
+    kernel size of the depth-wise convolution: kernel_size, default 3x3
+    wv_planes
+    inplanes
+    """
+
+    def __init__(self, wv_planes, inter_dim=128, kernel_size=3, embed_dim=1024):
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.wv_planes = wv_planes
+        self.embed_dim = embed_dim
+        self.kernel_size = kernel_size
+        self._num_kernel = self.kernel_size * self.kernel_size * self.embed_dim
+        self.inter_dim = inter_dim
+        self.patch_size = (kernel_size, kernel_size)
+        self.num_patches = -1
+
+        self.weight_generator = TransformerWeightGenerator(
+            wv_planes, self._num_kernel, embed_dim
+        )
+        self.scaler = 0.01
+
+        self.fclayer = FCResLayer(wv_planes)
+
+        self._init_weights()
+
+    def _get_weights(self, waves):
+        dynamic_weights = self.weight_generator(waves)
+        return dynamic_weights
+
+    def weight_init(self, m):
+        if isinstance(m, nn.Linear):
+            init.xavier_uniform_(m.weight)
+            m.bias.data.fill_(0.01)
+
+    def _init_weights(self):
+        """Initialize the base weights and dynamic mlp weights"""
+        self.weight_generator.apply(self.weight_init)
+        self.fclayer.apply(self.weight_init)
+
+    def forward(self, img_feat, wvs):
+        inplanes = wvs.size(0)
+        # wv_feats: 9,128 -> 9, 3x3x3
+        waves = get_1d_sincos_pos_embed_from_grid_torch(self.wv_planes, wvs * 1000)
+        waves = self.fclayer(waves)
+        weight, bias = self._get_weights(waves)  # 3x3x3
+        # bias = None
+
+        # dynamic_weight = weight.view(self.embed_dim, inplanes, self.kernel_size, self.kernel_size) #3xoutdx16x16
+        dynamic_weight = weight.view(
+            inplanes, self.kernel_size, self.kernel_size, self.embed_dim
+        )
+        dynamic_weight = dynamic_weight.permute([3, 0, 1, 2])
+        if bias is not None:
+            bias = bias.view([self.embed_dim]) * self.scaler
+
+        weights = dynamic_weight * self.scaler
+
+        dynamic_out = F.conv2d(
+            img_feat, weights, bias=bias, stride=self.kernel_size, padding=1, dilation=1
+        )
+
+        x = dynamic_out
+        x = x.flatten(2).transpose(1, 2)
+
+        return x, waves
+
+
+class Dynamic_MLP_OFA_spectral(nn.Module):
+    """Input: channels of wavelength and bandwidth (normalized): List -> List
+    kernel size of the depth-wise convolution: kernel_size, default 3x3
+    wv_planes
+    inplanes
+    """
+
+    def __init__(self, wv_planes, inter_dim=128, kernel_size=3, embed_dim=1024):
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.wv_planes = wv_planes
+        self.embed_dim = embed_dim
+        self.kernel_size = kernel_size
+        self._num_kernel = self.kernel_size * self.kernel_size * self.embed_dim
+        self.inter_dim = inter_dim
+        self.patch_size = (kernel_size, kernel_size)
+        self.num_patches = -1
+
+        ## CopernicusFM: fourier embedding for wavelength and bandwidth
+        # min wavelength: ultraviolet light (100 nm)
+        # max wavelength: radio waves (1 m)
+        self.spectrum_central_expansion = FourierExpansion(100, 1e9)
+        # min bandwidth: s2 ~ 10nm
+        # max bandwidth: s1 ~ 1m
+        self.spectrum_bandwidth_expansion = FourierExpansion(1, 1e9)
+
+        self.weight_generator = TransformerWeightGenerator(
+            wv_planes, self._num_kernel, embed_dim
+        )
+        self.scaler = 0.01
+
+        self.fclayer = FCResLayer(wv_planes)
+
+        self._init_weights()
+
+    def _get_weights(self, waves):
+        dynamic_weights = self.weight_generator(waves)
+
+        return dynamic_weights
+
+    def weight_init(self, m):
+        if isinstance(m, nn.Linear):
+            init.xavier_uniform_(m.weight)
+            m.bias.data.fill_(0.01)
+
+    def _init_weights(self):
+        """Initialize the base weights and dynamic mlp weights"""
+        self.weight_generator.apply(self.weight_init)
+        self.fclayer.apply(self.weight_init)
+
+    def forward(self, img_feat, wvs, bandwidths, kernel_size=None):
+        """wvs: nm
+        bandwidths: nm
+        """
+        inplanes = wvs.size(0)
+        # wv_feats: 9,128 -> 9, 3x3x3
+        # waves = get_1d_sincos_pos_embed_from_grid_torch(self.wv_planes, wvs * 1000) # dofa: fixed sincos pos embedding
+        # waves = get_1d_fourier_pos_embed_from_grid_torch(self.wv_planes, wvs * 1000) # new: fourier pos embedding
+        emb_central = self.spectrum_central_expansion(wvs, self.wv_planes)
+        emb_bandwidth = self.spectrum_bandwidth_expansion(bandwidths, self.wv_planes)
+        waves = (
+            emb_central + emb_bandwidth
+        )  # simply add two embeddings, can be more complex later
+
+        waves = self.fclayer(waves)
+        weight, bias = self._get_weights(waves)  # 3x3x3
+
+        # Fix bug
+        dynamic_weight = weight.view(
+            inplanes, self.kernel_size, self.kernel_size, self.embed_dim
+        )  # 9, 3, 3, 1024
+        dynamic_weight = dynamic_weight.permute([3, 0, 1, 2])  # 1024, 9, 3, 3
+        # resize the weight to match different preferred kernel sizes
+        if kernel_size != None and self.kernel_size != kernel_size:
+            dynamic_weight = pi_resize_patch_embed(
+                dynamic_weight, (kernel_size, kernel_size)
+            )
+        else:
+            kernel_size = self.kernel_size
+
+        if bias is not None:
+            bias = bias.view([self.embed_dim]) * self.scaler
+
+        weights = dynamic_weight * self.scaler
+
+        dynamic_out = F.conv2d(
+            img_feat, weights, bias=bias, stride=kernel_size, padding=1, dilation=1
+        )
+
+        x = dynamic_out
+        x = x.flatten(2).transpose(1, 2)
+
+        return x, waves
+
+
+class Dynamic_MLP_OFA_variable(nn.Module):
+    """Input: language embedding of variable name: Pytorch tensor
+    kernel size of the depth-wise convolution: kernel_size, default 3x3
+    wv_planes
+    inplanes
+    """
+
+    def __init__(self, wv_planes, inter_dim=128, kernel_size=3, embed_dim=1024):
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.wv_planes = wv_planes
+        self.embed_dim = embed_dim
+        self.kernel_size = kernel_size
+        self._num_kernel = self.kernel_size * self.kernel_size * self.embed_dim
+        self.inter_dim = inter_dim
+        self.patch_size = (kernel_size, kernel_size)
+        self.num_patches = -1
+
+        self.language_proj = nn.Linear(
+            2048, self.wv_planes
+        )  # project to the same dimension as wv_planes
+
+        self.weight_generator = TransformerWeightGenerator(
+            wv_planes, self._num_kernel, embed_dim
+        )
+        self.scaler = 0.01
+
+        self.fclayer = FCResLayer(wv_planes)
+
+        self._init_weights()
+
+    def _get_weights(self, waves):
+        dynamic_weights = self.weight_generator(waves)
+
+        return dynamic_weights
+
+    def weight_init(self, m):
+        if isinstance(m, nn.Linear):
+            init.xavier_uniform_(m.weight)
+            m.bias.data.fill_(0.01)
+
+    def _init_weights(self):
+        """Initialize the base weights and dynamic mlp weights"""
+        self.weight_generator.apply(self.weight_init)
+        self.fclayer.apply(self.weight_init)
+
+    def forward(self, img_feat, language_embed, kernel_size=None):
+        """language_embed: 2048-dim language embedding of the variable name"""
+        # wv_feats: 9,128 -> 9, 3x3x3
+        emb_language = language_embed.unsqueeze(0)
+        waves = self.language_proj(emb_language)
+        # print(waves.size())
+
+        waves = self.fclayer(waves)
+        # print(waves.size())
+        weight, bias = self._get_weights(waves)  # 3x3x3
+
+        # inplanes = wvs.size(0)
+        inplanes = waves.size(0)
+        # print(inplanes)
+        # Fix bug
+        dynamic_weight = weight.view(
+            inplanes, self.kernel_size, self.kernel_size, self.embed_dim
+        )  # 9, 3, 3, 1024
+        dynamic_weight = dynamic_weight.permute([3, 0, 1, 2])  # 1024, 9, 3, 3
+
+        # resize the weight to match different preferred kernel sizes
+        if kernel_size != None and self.kernel_size != kernel_size:
+            dynamic_weight = pi_resize_patch_embed(
+                dynamic_weight, (kernel_size, kernel_size)
+            )
+        else:
+            kernel_size = self.kernel_size
+
+        if bias is not None:
+            bias = bias.view([self.embed_dim]) * self.scaler
+
+        weights = dynamic_weight * self.scaler
+
+        dynamic_out = F.conv2d(
+            img_feat, weights, bias=bias, stride=kernel_size, padding=1, dilation=1
+        )
+
+        x = dynamic_out
+        x = x.flatten(2).transpose(1, 2)
+
+        return x, waves
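
For orientation, a minimal sketch of how the spectral hypernetwork patch embedding above can be driven, assuming rslearn 0.0.6 is installed; the band wavelengths, bandwidths, and the small `embed_dim` are illustrative choices, not values from the package:

import torch

from rslearn.models.copernicusfm_src.dynamic_hypernetwork import Dynamic_MLP_OFA_spectral

# Hypothetical config: 128-dim spectral embeddings, 3x3 base patch kernel,
# and a deliberately small embed_dim to keep the example light.
patch_embed = Dynamic_MLP_OFA_spectral(wv_planes=128, kernel_size=3, embed_dim=64)

# Per-band central wavelengths and bandwidths in nm (illustrative values;
# the forward docstring expects nm).
wvs = torch.tensor([490.0, 560.0, 665.0])
bandwidths = torch.tensor([65.0, 35.0, 30.0])

imgs = torch.randn(2, 3, 96, 96)  # batch of two 3-band images
tokens, waves = patch_embed(imgs, wvs, bandwidths)
# tokens: (2, 1024, 64) patch tokens; waves: (3, 128) spectral embeddings.

Passing a different `kernel_size` to the forward call would instead resize the generated kernels via `pi_resize_patch_embed` before the convolution, which is the dynamic-patch-size path the comments in the file refer to.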