rslearn 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,134 @@
+ # type: ignore
+ """Copyright (c) Microsoft Corporation. Licensed under the MIT license."""
+
+ import math
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from .area import area, radius_earth
+
+ __all__ = [
+     "FourierExpansion",
+     "pos_expansion",
+     "scale_expansion",
+     "lead_time_expansion",
+     "levels_expansion",
+     "absolute_time_expansion",
+ ]
+
+
+ class FourierExpansion(nn.Module):
+     """A Fourier series-style expansion into a high-dimensional space.
+
+     Attributes:
+         lower (float): Lower wavelength.
+         upper (float): Upper wavelength.
+         assert_range (bool): Assert that the encoded tensor is within the specified wavelength
+             range.
+     """
+
+     def __init__(self, lower: float, upper: float, assert_range: bool = True) -> None:
+         """Initialise.
+
+         Args:
+             lower (float): Lower wavelength.
+             upper (float): Upper wavelength.
+             assert_range (bool, optional): Assert that the encoded tensor is within the
+                 specified wavelength range. Defaults to `True`.
+         """
+         super().__init__()
+         self.lower = lower
+         self.upper = upper
+         self.assert_range = assert_range
+
+     def forward(self, x: torch.Tensor, d: int) -> torch.Tensor:
+         """Perform the expansion.
+
+         Adds a dimension of length `d` to the end of the shape of `x`.
+
+         Args:
+             x (:class:`torch.Tensor`): Input to expand of shape `(..., n)`. All elements of
+                 `x` must lie within `[self.lower, self.upper]` if `self.assert_range` is
+                 `True`.
+             d (int): Dimensionality. Must be a multiple of two.
+
+         Raises:
+             AssertionError: If `self.assert_range` is `True` and some element of `x` lies
+                 outside `[self.lower, self.upper]`.
+             ValueError: If `d` is not a multiple of two.
+
+         Returns:
+             torch.Tensor: Fourier series-style expansion of `x` of shape `(..., n, d)`.
+         """
+         # If the input is not within the configured range, the embedding might be ambiguous!
+         # The range check must be element-wise, so that a single out-of-range value cannot
+         # mask the others.
+         in_range = torch.logical_and(self.lower <= x.abs(), x.abs() <= self.upper)
+         in_range_or_zero = torch.all(
+             torch.logical_or(in_range, x == 0)
+         )  # Allow zeros to pass through.
+         if self.assert_range and not in_range_or_zero:
+             raise AssertionError(
+                 f"The input tensor is not within the configured range"
+                 f" `[{self.lower}, {self.upper}]`."
+             )
+
+         # We will use half of the dimensionality for `sin` and the other half for `cos`.
+         if d % 2 != 0:
+             raise ValueError("The dimensionality must be a multiple of two.")
+
+         # Always perform the expansion with `float64`s to avoid numerical accuracy shenanigans.
+         x = x.double()
+
+         wavelengths = torch.logspace(
+             math.log10(self.lower),
+             math.log10(self.upper),
+             d // 2,
+             base=10,
+             device=x.device,
+             dtype=x.dtype,
+         )
+         prod = torch.einsum("...i,j->...ij", x, 2 * np.pi / wavelengths)
+         encoding = torch.cat((torch.sin(prod), torch.cos(prod)), dim=-1)
+
+         return encoding.float()  # Cast to `float32` to avoid incompatibilities.
+
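+ # Usage sketch (illustrative, not part of the released file): expanding three
+ # scalars into 8-dimensional Fourier features, per the docstring above.
+ # >>> expansion = FourierExpansion(lower=1.0, upper=100.0)
+ # >>> x = torch.tensor([1.0, 10.0, 100.0])
+ # >>> expansion(x, d=8).shape
+ # torch.Size([3, 8])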
+
+
+ # Determine a reasonable smallest value for the scale embedding by assuming a smallest delta
+ # in latitudes and longitudes.
+ _delta = 0.01  # Reasonable smallest delta in latitude and longitude
+ _min_patch_area: float = area(
+     torch.tensor(
+         [
+             # The smallest patches will be at the poles. Just use the north pole.
+             [90, 0],
+             [90, _delta],
+             [90 - _delta, _delta],
+             [90 - _delta, 0],
+         ],
+         dtype=torch.float64,
+     )
+ ).item()
+ _area_earth = 4 * np.pi * radius_earth * radius_earth
+
+ pos_expansion = FourierExpansion(_delta, 720)
+
+ scale_expansion = FourierExpansion(_min_patch_area, _area_earth)
+
+ lead_time_expansion = FourierExpansion(1 / 60, 24 * 7 * 3)
+
+ levels_expansion = FourierExpansion(0.01, 1e5)
+
+ absolute_time_expansion = FourierExpansion(1, 24 * 365.25, assert_range=False)
+
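+ # Usage sketch (illustrative): lead times are presumably in hours, given the
+ # configured range of one minute (1 / 60) to three weeks (24 * 7 * 3).
+ # >>> t = torch.tensor([6.0, 24.0, 120.0])
+ # >>> lead_time_expansion(t, 32).shape
+ # torch.Size([3, 32])
+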
+ ### new for SSL4EO-S ###
+ # min wavelength: ultraviolet light (100 nm)
+ # max wavelength: radio waves (1 m)
+ spectrum_central_expansion = FourierExpansion(1e-7, 1)
+
+ # min bandwidth: 10 nm
+ # max bandwidth: 1 m
+ spectrum_width_expansion = FourierExpansion(1e-7, 1)
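+
+ # Usage sketch (illustrative): central wavelengths here are in metres, e.g.
+ # Sentinel-2 B02 (~490 nm) and B08 (~842 nm).
+ # >>> wvs = torch.tensor([490e-9, 842e-9])
+ # >>> spectrum_central_expansion(wvs, 128).shape
+ # torch.Size([2, 128])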
@@ -0,0 +1,523 @@
+ # mypy: ignore-errors
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.nn.init as init
+
+ # CopernicusFM: meta encoding (follows Aurora)
+ from .aurora.fourier import FourierExpansion
+
+ # CopernicusFM: dynamic patch size (follows FlexiViT)
+ from .flexivit.patch_embed import pi_resize_patch_embed
+ from .util.pos_embed import get_1d_sincos_pos_embed_from_grid_torch
+
+
+ class TransformerWeightGenerator(nn.Module):
+     def __init__(self, input_dim, output_dim, embed_dim, num_heads=4, num_layers=1):
+         super().__init__()
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=input_dim,
+             nhead=num_heads,
+             activation="gelu",
+             norm_first=False,
+             batch_first=False,
+             dropout=0.0,
+         )
+         self.transformer_encoder = nn.TransformerEncoder(
+             encoder_layer, num_layers=num_layers, enable_nested_tensor=False
+         )
+
+         # Linear layers to map transformer outputs to the desired weight and bias shapes.
+         self.fc_weight = nn.Linear(input_dim, output_dim)
+         self.fc_bias = nn.Linear(input_dim, embed_dim)
+         self.wt_num = 128
+         self.weight_tokens = nn.Parameter(torch.empty([self.wt_num, input_dim]))
+         self.bias_token = nn.Parameter(torch.empty([1, input_dim]))
+
+         # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as the cutoff is
+         # too big (2.).
+         torch.nn.init.normal_(self.weight_tokens, std=0.02)
+         torch.nn.init.normal_(self.bias_token, std=0.02)
+
+     def forward(self, x):
+         # x should have shape [seq_len, input_dim]: one token per input channel.
+         pos_wave = x
+         x = torch.cat([self.weight_tokens, pos_wave], dim=0)
+         x = torch.cat([x, self.bias_token], dim=0)
+         transformer_output = self.transformer_encoder(x)
+         weights = self.fc_weight(transformer_output[self.wt_num : -1] + pos_wave)
+         bias = self.fc_bias(
+             transformer_output[-1]
+         )  # Use the last output (the bias token) to generate the bias.
+         return weights, bias
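+
+ # Usage sketch (illustrative; names and sizes are examples only): generate a
+ # flattened 3x3x1024 patch-embed kernel for four channels from 128-dim tokens.
+ # >>> gen = TransformerWeightGenerator(input_dim=128, output_dim=3 * 3 * 1024, embed_dim=1024)
+ # >>> waves = torch.randn(4, 128)  # one token per input channel
+ # >>> weights, bias = gen(waves)
+ # >>> weights.shape, bias.shape
+ # (torch.Size([4, 9216]), torch.Size([1024]))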
+
+
+ class GaussianFourierFeatureTransform(torch.nn.Module):
+     """An implementation of Gaussian Fourier feature mapping.
+
+     "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains":
+     https://arxiv.org/abs/2006.10739
+     https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html
+
+     Given an input of size [batches, num_input_channels, width, height],
+     returns a tensor of size [batches, mapping_size*2, width, height].
+     """
+
+     def __init__(self, num_input_channels, mapping_size=256, scale=10):
+         super().__init__()
+
+         self._num_input_channels = num_input_channels
+         self._mapping_size = mapping_size
+         # Seed the global RNG so that the random projection `_B` is reproducible.
+         torch.manual_seed(42)
+         self._B = torch.randn((num_input_channels, mapping_size)) * scale
+
+     def forward(self, x):
+         assert x.dim() == 4, f"Expected 4D input (got {x.dim()}D input)"
+
+         batches, channels, width, height = x.shape
+
+         assert (
+             channels == self._num_input_channels
+         ), f"Expected input to have {self._num_input_channels} channels (got {channels} channels)"
+
+         # Make shape compatible for matmul with _B.
+         # From [B, C, W, H] to [(B*W*H), C].
+         x = x.permute(0, 2, 3, 1).reshape(batches * width * height, channels)
+
+         x = x @ self._B.to(x.device)
+
+         # From [(B*W*H), C] to [B, W, H, C]
+         x = x.view(batches, width, height, self._mapping_size)
+         # From [B, W, H, C] to [B, C, W, H]
+         x = x.permute(0, 3, 1, 2)
+
+         x = 2 * np.pi * x
+         return torch.cat([torch.sin(x), torch.cos(x)], dim=1)
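+
+ # Usage sketch (illustrative): the output doubles `mapping_size` along the
+ # channel axis, per the docstring above.
+ # >>> gfft = GaussianFourierFeatureTransform(num_input_channels=3, mapping_size=256)
+ # >>> gfft(torch.randn(2, 3, 32, 32)).shape
+ # torch.Size([2, 512, 32, 32])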
+
+
+ class Basic1d(nn.Module):
+     def __init__(self, in_channels, out_channels, bias=True):
+         super().__init__()
+         conv = nn.Linear(in_channels, out_channels, bias)
+         self.conv = nn.Sequential(
+             conv,
+         )
+         if not bias:
+             self.conv.add_module("ln", nn.LayerNorm(out_channels))
+             self.conv.add_module("relu", nn.ReLU(inplace=True))
+
+     def forward(self, x):
+         out = self.conv(x)
+         return out
+
+
+ class FCResLayer(nn.Module):
+     def __init__(self, linear_size=128):
+         super().__init__()
+         self.l_size = linear_size
+         self.nonlin1 = nn.ReLU(inplace=True)
+         self.nonlin2 = nn.ReLU(inplace=True)
+         # self.dropout1 = nn.Dropout()
+         self.w1 = nn.Linear(self.l_size, self.l_size)
+         self.w2 = nn.Linear(self.l_size, self.l_size)
+
+     def forward(self, x):
+         y = self.w1(x)
+         y = self.nonlin1(y)
+         # y = self.dropout1(y)
+         y = self.w2(y)
+         y = self.nonlin2(y)
+         out = x + y
+         return out
+
+
+ class Dynamic_MLP_Decoder(nn.Module):
+     def __init__(self, wv_planes, inter_dim=128, kernel_size=16, decoder_embed=512):
+         super().__init__()
+         self.kernel_size = kernel_size
+         self.wv_planes = wv_planes
+         self.inter_dim = inter_dim
+         self.decoder_embed = decoder_embed
+         self._num_kernel = self.kernel_size * self.kernel_size * self.decoder_embed
+
+         # self.weight_generator = nn.Sequential(Basic1d(wv_planes, self.inter_dim, bias=True),
+         #                                       nn.Linear(self.inter_dim, self._num_kernel))
+         self.weight_generator = TransformerWeightGenerator(
+             wv_planes, self._num_kernel, decoder_embed
+         )
+         self.scaler = 0.01
+
+         self._init_weights()
+
+     def _get_weights(self, waves, batch=True):
+         dweights = []
+         dynamic_weights = None
+         if batch:
+             dynamic_weights = self.weight_generator(waves)
+         else:
+             # NOTE: this path stacks (weights, bias) tuples and appears to be unused.
+             for i in range(waves.size(0)):
+                 dweights.append(self.weight_generator(waves[i]))
+             dynamic_weights = torch.stack(dweights, dim=0)
+
+         return dynamic_weights
+
+     def weight_init(self, m):
+         if isinstance(m, nn.Linear):
+             init.xavier_uniform_(m.weight)
+             m.bias.data.fill_(0.01)
+
+     def _init_weights(self):
+         """Initialize the base weights and dynamic MLP weights."""
+         self.weight_generator.apply(self.weight_init)
+
+     def forward(self, img_feat, waves, kernel_size=None):
+         inplanes = waves.size(0)
+         # wv_feats: 9,128 -> 9*16*16,512
+         weight, bias = self._get_weights(waves)  # 9,16*16*512
+         # dynamic_weight = weight.view(
+         #     inplanes * self.kernel_size * self.kernel_size, self.decoder_embed
+         # )  # 9*16*16,512
+
+         # CopernicusFM: dynamic patch size
+         dynamic_weight = weight.view(
+             inplanes, self.kernel_size, self.kernel_size, self.decoder_embed
+         )
+         dynamic_weight = dynamic_weight.permute([3, 0, 1, 2])
+         # Resize the weight to match a different preferred kernel size.
+         if kernel_size is not None and self.kernel_size != kernel_size:
+             dynamic_weight = pi_resize_patch_embed(
+                 dynamic_weight, (kernel_size, kernel_size)
+             )  # 512, 9, p, p
+         else:
+             kernel_size = self.kernel_size
+         dynamic_weight = (
+             dynamic_weight.permute([1, 2, 3, 0])
+             .contiguous()
+             .view(-1, self.decoder_embed)
+         )  # 9*p*p,512
+
+         weights = dynamic_weight * self.scaler
+
+         dynamic_out = F.linear(img_feat, weights, bias=None)
+         x = dynamic_out
+         return x
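+
+ # Usage sketch (illustrative): project 512-dim decoder tokens to per-channel
+ # patch pixels for four channels and a 16x16 patch.
+ # >>> dec = Dynamic_MLP_Decoder(wv_planes=128, kernel_size=16, decoder_embed=512)
+ # >>> tokens = torch.randn(2, 196, 512)
+ # >>> waves = torch.randn(4, 128)  # one embedding per output channel
+ # >>> dec(tokens, waves).shape  # 4 * 16 * 16 = 1024 outputs per token
+ # torch.Size([2, 196, 1024])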
+
+
+ class Dynamic_Patch_Embed(nn.Module):
+     """Patch embedding that selects a fixed kernel by input channel count.
+
+     A separate convolution weight and bias is kept for each supported number of
+     input channels (2, 3, 4, 9, or 70); `forward` dispatches on `waves.size(0)`.
+
+     Args:
+         wv_planes: Dimensionality of the wavelength embedding.
+         inter_dim: Intermediate dimension (unused).
+         kernel_size: Kernel size of the convolution, default 3x3.
+         embed_dim: Output embedding dimension.
+     """
+
+     def __init__(self, wv_planes, inter_dim=128, kernel_size=3, embed_dim=1024):
+         super().__init__()
+         self.kernel_size = kernel_size
+         self.wv_planes = wv_planes
+         self.embed_dim = embed_dim
+         self.patch_size = (kernel_size, kernel_size)
+         # NOTE: these parameters are created uninitialized (`torch.empty`); they are
+         # presumably meant to be initialized externally, e.g. from a checkpoint.
+         self.weight2 = nn.Parameter(
+             torch.empty([embed_dim, 2, kernel_size, kernel_size])
+         )
+         self.bias2 = nn.Parameter(torch.empty([embed_dim]))
+         self.weight3 = nn.Parameter(
+             torch.empty([embed_dim, 3, kernel_size, kernel_size])
+         )
+         self.bias3 = nn.Parameter(torch.empty([embed_dim]))
+         self.weight4 = nn.Parameter(
+             torch.empty([embed_dim, 4, kernel_size, kernel_size])
+         )
+         self.bias4 = nn.Parameter(torch.empty([embed_dim]))
+         self.weight9 = nn.Parameter(
+             torch.empty([embed_dim, 9, kernel_size, kernel_size])
+         )
+         self.bias9 = nn.Parameter(torch.empty([embed_dim]))
+         self.weight70 = nn.Parameter(
+             torch.empty([embed_dim, 70, kernel_size, kernel_size])
+         )
+         self.bias70 = nn.Parameter(torch.empty([embed_dim]))
+         self.weights = {
+             2: self.weight2,
+             3: self.weight3,
+             4: self.weight4,
+             9: self.weight9,
+             70: self.weight70,
+         }
+         self.biass = {
+             2: self.bias2,
+             3: self.bias3,
+             4: self.bias4,
+             9: self.bias9,
+             70: self.bias70,
+         }
+
+     def forward(self, img_feat, waves):
+         inplanes = waves.size(0)
+         # Select the kernel that matches the number of input channels.
+         weights = self.weights[inplanes]
+         bias = self.biass[inplanes]
+
+         dynamic_out = F.conv2d(
+             img_feat, weights, bias=bias, stride=self.kernel_size, padding=1, dilation=1
+         )
+
+         x = dynamic_out
+         x = x.flatten(2).transpose(1, 2)
+
+         return x
+
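+ # Usage sketch (illustrative): dispatch is by channel count, so only 2-, 3-, 4-,
+ # 9-, or 70-channel inputs are supported (weights here are still uninitialized,
+ # so only the output shape is meaningful).
+ # >>> embed = Dynamic_Patch_Embed(wv_planes=128, kernel_size=3, embed_dim=1024)
+ # >>> tokens = embed(torch.randn(2, 9, 96, 96), torch.zeros(9))
+ # >>> tokens.shape  # (96 + 2*1 - 3) // 3 + 1 = 32 patches per side
+ # torch.Size([2, 1024, 1024])
+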
+
+ class Dynamic_MLP_OFA(nn.Module):
+     """Dynamic patch embedding whose kernel is generated from channel wavelengths.
+
+     The (normalized) central wavelength of each input channel is embedded with a
+     fixed sincos encoding and fed to a transformer weight generator, which
+     produces the patch-embedding convolution kernel and bias.
+
+     Args:
+         wv_planes: Dimensionality of the wavelength embedding.
+         inter_dim: Intermediate dimension.
+         kernel_size: Kernel size of the convolution, default 3x3.
+         embed_dim: Output embedding dimension.
+     """
+
+     def __init__(self, wv_planes, inter_dim=128, kernel_size=3, embed_dim=1024):
+         super().__init__()
+         self.kernel_size = kernel_size
+         self.wv_planes = wv_planes
+         self.embed_dim = embed_dim
+         self._num_kernel = self.kernel_size * self.kernel_size * self.embed_dim
+         self.inter_dim = inter_dim
+         self.patch_size = (kernel_size, kernel_size)
+         self.num_patches = -1
+
+         self.weight_generator = TransformerWeightGenerator(
+             wv_planes, self._num_kernel, embed_dim
+         )
+         self.scaler = 0.01
+
+         self.fclayer = FCResLayer(wv_planes)
+
+         self._init_weights()
+
+     def _get_weights(self, waves):
+         dynamic_weights = self.weight_generator(waves)
+         return dynamic_weights
+
+     def weight_init(self, m):
+         if isinstance(m, nn.Linear):
+             init.xavier_uniform_(m.weight)
+             m.bias.data.fill_(0.01)
+
+     def _init_weights(self):
+         """Initialize the base weights and dynamic MLP weights."""
+         self.weight_generator.apply(self.weight_init)
+         self.fclayer.apply(self.weight_init)
+
+     def forward(self, img_feat, wvs):
+         inplanes = wvs.size(0)
+         # wv_feats: 9,128 -> 9, 3x3x3
+         waves = get_1d_sincos_pos_embed_from_grid_torch(self.wv_planes, wvs * 1000)
+         waves = self.fclayer(waves)
+         weight, bias = self._get_weights(waves)  # 3x3x3
+         # bias = None
+
+         # dynamic_weight = weight.view(
+         #     self.embed_dim, inplanes, self.kernel_size, self.kernel_size
+         # )  # 3 x outd x 16 x 16
+         dynamic_weight = weight.view(
+             inplanes, self.kernel_size, self.kernel_size, self.embed_dim
+         )
+         dynamic_weight = dynamic_weight.permute([3, 0, 1, 2])
+         if bias is not None:
+             bias = bias.view([self.embed_dim]) * self.scaler
+
+         weights = dynamic_weight * self.scaler
+
+         dynamic_out = F.conv2d(
+             img_feat, weights, bias=bias, stride=self.kernel_size, padding=1, dilation=1
+         )
+
+         x = dynamic_out
+         x = x.flatten(2).transpose(1, 2)
+
+         return x, waves
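+
+ # Usage sketch (illustrative), assuming wavelengths are given in micrometres
+ # (the code scales them by 1000 before embedding).
+ # >>> embed = Dynamic_MLP_OFA(wv_planes=128, kernel_size=16, embed_dim=1024)
+ # >>> img = torch.randn(2, 4, 224, 224)
+ # >>> wvs = torch.tensor([0.49, 0.56, 0.665, 0.842])  # B, G, R, NIR
+ # >>> tokens, waves = embed(img, wvs)
+ # >>> tokens.shape  # stride 16, padding 1 -> 14x14 patches
+ # torch.Size([2, 196, 1024])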
+
+
+ class Dynamic_MLP_OFA_spectral(nn.Module):
+     """Dynamic patch embedding conditioned on channel wavelengths and bandwidths.
+
+     Each channel's central wavelength and bandwidth are embedded with Fourier
+     expansions, summed, and fed to a transformer weight generator that produces
+     the patch-embedding convolution kernel and bias.
+
+     Args:
+         wv_planes: Dimensionality of the spectral embedding.
+         inter_dim: Intermediate dimension.
+         kernel_size: Kernel size of the convolution, default 3x3.
+         embed_dim: Output embedding dimension.
+     """
+
+     def __init__(self, wv_planes, inter_dim=128, kernel_size=3, embed_dim=1024):
+         super().__init__()
+         self.kernel_size = kernel_size
+         self.wv_planes = wv_planes
+         self.embed_dim = embed_dim
+         self._num_kernel = self.kernel_size * self.kernel_size * self.embed_dim
+         self.inter_dim = inter_dim
+         self.patch_size = (kernel_size, kernel_size)
+         self.num_patches = -1
+
+         ## CopernicusFM: Fourier embedding for wavelength and bandwidth
+         # min wavelength: ultraviolet light (100 nm)
+         # max wavelength: radio waves (1 m)
+         self.spectrum_central_expansion = FourierExpansion(100, 1e9)
+         # min bandwidth: S2 ~ 10 nm
+         # max bandwidth: S1 ~ 1 m
+         self.spectrum_bandwidth_expansion = FourierExpansion(1, 1e9)
+
+         self.weight_generator = TransformerWeightGenerator(
+             wv_planes, self._num_kernel, embed_dim
+         )
+         self.scaler = 0.01
+
+         self.fclayer = FCResLayer(wv_planes)
+
+         self._init_weights()
+
+     def _get_weights(self, waves):
+         dynamic_weights = self.weight_generator(waves)
+
+         return dynamic_weights
+
+     def weight_init(self, m):
+         if isinstance(m, nn.Linear):
+             init.xavier_uniform_(m.weight)
+             m.bias.data.fill_(0.01)
+
+     def _init_weights(self):
+         """Initialize the base weights and dynamic MLP weights."""
+         self.weight_generator.apply(self.weight_init)
+         self.fclayer.apply(self.weight_init)
+
+     def forward(self, img_feat, wvs, bandwidths, kernel_size=None):
+         """Embed `img_feat` given per-channel spectral metadata.
+
+         Args:
+             wvs: Central wavelengths in nm.
+             bandwidths: Bandwidths in nm.
+         """
+         inplanes = wvs.size(0)
+         # waves = get_1d_sincos_pos_embed_from_grid_torch(self.wv_planes, wvs * 1000)  # DOFA: fixed sincos pos embedding
+         # waves = get_1d_fourier_pos_embed_from_grid_torch(self.wv_planes, wvs * 1000)  # new: Fourier pos embedding
+         emb_central = self.spectrum_central_expansion(wvs, self.wv_planes)
+         emb_bandwidth = self.spectrum_bandwidth_expansion(bandwidths, self.wv_planes)
+         waves = (
+             emb_central + emb_bandwidth
+         )  # Simply add the two embeddings; this can be made more complex later.
+
+         waves = self.fclayer(waves)
+         weight, bias = self._get_weights(waves)  # 3x3x3
+
+         # Reshape the generated weights into convolution layout.
+         dynamic_weight = weight.view(
+             inplanes, self.kernel_size, self.kernel_size, self.embed_dim
+         )  # 9, 3, 3, 1024
+         dynamic_weight = dynamic_weight.permute([3, 0, 1, 2])  # 1024, 9, 3, 3
+         # Resize the weight to match a different preferred kernel size.
+         if kernel_size is not None and self.kernel_size != kernel_size:
+             dynamic_weight = pi_resize_patch_embed(
+                 dynamic_weight, (kernel_size, kernel_size)
+             )
+         else:
+             kernel_size = self.kernel_size
+
+         if bias is not None:
+             bias = bias.view([self.embed_dim]) * self.scaler
+
+         weights = dynamic_weight * self.scaler
+
+         dynamic_out = F.conv2d(
+             img_feat, weights, bias=bias, stride=kernel_size, padding=1, dilation=1
+         )
+
+         x = dynamic_out
+         x = x.flatten(2).transpose(1, 2)
+
+         return x, waves
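+
+ # Usage sketch (illustrative): Sentinel-2-like input with wavelengths and
+ # bandwidths in nm, resized to an 8x8 patch kernel.
+ # >>> embed = Dynamic_MLP_OFA_spectral(wv_planes=128, kernel_size=16, embed_dim=1024)
+ # >>> img = torch.randn(2, 4, 224, 224)
+ # >>> wvs = torch.tensor([490.0, 560.0, 665.0, 842.0])
+ # >>> bws = torch.tensor([65.0, 35.0, 30.0, 115.0])
+ # >>> tokens, waves = embed(img, wvs, bws, kernel_size=8)
+ # >>> tokens.shape  # stride 8, padding 1 -> 28x28 patches
+ # torch.Size([2, 784, 1024])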
+
+
+ class Dynamic_MLP_OFA_variable(nn.Module):
+     """Dynamic patch embedding conditioned on a language embedding of the variable name.
+
+     A 2048-dim language embedding is projected to `wv_planes` and fed to the
+     transformer weight generator, which produces a single-channel patch-embedding
+     convolution kernel and bias.
+
+     Args:
+         wv_planes: Dimensionality of the conditioning embedding.
+         inter_dim: Intermediate dimension.
+         kernel_size: Kernel size of the convolution, default 3x3.
+         embed_dim: Output embedding dimension.
+     """
+
+     def __init__(self, wv_planes, inter_dim=128, kernel_size=3, embed_dim=1024):
+         super().__init__()
+         self.kernel_size = kernel_size
+         self.wv_planes = wv_planes
+         self.embed_dim = embed_dim
+         self._num_kernel = self.kernel_size * self.kernel_size * self.embed_dim
+         self.inter_dim = inter_dim
+         self.patch_size = (kernel_size, kernel_size)
+         self.num_patches = -1
+
+         self.language_proj = nn.Linear(
+             2048, self.wv_planes
+         )  # Project to the same dimension as wv_planes.
+
+         self.weight_generator = TransformerWeightGenerator(
+             wv_planes, self._num_kernel, embed_dim
+         )
+         self.scaler = 0.01
+
+         self.fclayer = FCResLayer(wv_planes)
+
+         self._init_weights()
+
+     def _get_weights(self, waves):
+         dynamic_weights = self.weight_generator(waves)
+
+         return dynamic_weights
+
+     def weight_init(self, m):
+         if isinstance(m, nn.Linear):
+             init.xavier_uniform_(m.weight)
+             m.bias.data.fill_(0.01)
+
+     def _init_weights(self):
+         """Initialize the base weights and dynamic MLP weights."""
+         self.weight_generator.apply(self.weight_init)
+         self.fclayer.apply(self.weight_init)
+
+     def forward(self, img_feat, language_embed, kernel_size=None):
+         """Embed `img_feat` given a language embedding of the variable name.
+
+         Args:
+             language_embed: Language embedding of shape (2048,).
+         """
+         emb_language = language_embed.unsqueeze(0)
+         waves = self.language_proj(emb_language)
+
+         waves = self.fclayer(waves)
+         weight, bias = self._get_weights(waves)  # 3x3x3
+
+         inplanes = waves.size(0)
+         # Reshape the generated weights into convolution layout.
+         dynamic_weight = weight.view(
+             inplanes, self.kernel_size, self.kernel_size, self.embed_dim
+         )  # 1, 3, 3, 1024
+         dynamic_weight = dynamic_weight.permute([3, 0, 1, 2])  # 1024, 1, 3, 3
+
+         # Resize the weight to match a different preferred kernel size.
+         if kernel_size is not None and self.kernel_size != kernel_size:
+             dynamic_weight = pi_resize_patch_embed(
+                 dynamic_weight, (kernel_size, kernel_size)
+             )
+         else:
+             kernel_size = self.kernel_size
+
+         if bias is not None:
+             bias = bias.view([self.embed_dim]) * self.scaler
+
+         weights = dynamic_weight * self.scaler
+
+         dynamic_out = F.conv2d(
+             img_feat, weights, bias=bias, stride=kernel_size, padding=1, dilation=1
+         )
+
+         x = dynamic_out
+         x = x.flatten(2).transpose(1, 2)
+
+         return x, waves
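+
+ # Usage sketch (illustrative): a single-channel map described by a 2048-dim
+ # language embedding of its variable name.
+ # >>> embed = Dynamic_MLP_OFA_variable(wv_planes=128, kernel_size=16, embed_dim=1024)
+ # >>> img = torch.randn(2, 1, 224, 224)
+ # >>> lang = torch.randn(2048)
+ # >>> tokens, waves = embed(img, lang)
+ # >>> tokens.shape  # stride 16, padding 1 -> 14x14 patches
+ # torch.Size([2, 196, 1024])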