rslearn 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/dataset.py +23 -4
- rslearn/data_sources/planetary_computer.py +52 -0
- rslearn/dataset/handler_summaries.py +1 -0
- rslearn/dataset/manage.py +16 -2
- rslearn/models/anysat.py +5 -1
- rslearn/models/dinov3.py +6 -1
- rslearn/models/feature_center_crop.py +50 -0
- rslearn/models/olmoearth_pretrain/model.py +88 -27
- rslearn/models/prithvi.py +9 -1
- rslearn/train/lightning_module.py +0 -3
- rslearn/train/prediction_writer.py +25 -8
- rslearn/train/tasks/classification.py +2 -2
- rslearn/train/tasks/detection.py +5 -5
- rslearn/train/tasks/embedding.py +116 -0
- rslearn/train/tasks/per_pixel_regression.py +5 -4
- rslearn/train/tasks/regression.py +5 -5
- rslearn/train/transforms/pad.py +3 -3
- rslearn/utils/raster_format.py +38 -0
- {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/METADATA +3 -2
- {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/RECORD +25 -31
- rslearn-0.0.13.dist-info/licenses/NOTICE +115 -0
- rslearn/models/copernicusfm.py +0 -228
- rslearn/models/copernicusfm_src/__init__.py +0 -1
- rslearn/models/copernicusfm_src/aurora/area.py +0 -50
- rslearn/models/copernicusfm_src/aurora/fourier.py +0 -134
- rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +0 -523
- rslearn/models/copernicusfm_src/flexivit/patch_embed.py +0 -260
- rslearn/models/copernicusfm_src/flexivit/utils.py +0 -69
- rslearn/models/copernicusfm_src/model_vit.py +0 -348
- rslearn/models/copernicusfm_src/util/pos_embed.py +0 -216
- {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/WHEEL +0 -0
- {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/top_level.txt +0 -0
|
@@ -1,260 +0,0 @@
|
|
|
1
|
-
# type: ignore
|
|
2
|
-
from collections.abc import Sequence
|
|
3
|
-
|
|
4
|
-
import numpy as np
|
|
5
|
-
import torch
|
|
6
|
-
import torch.nn as nn
|
|
7
|
-
import torch.nn.functional as F
|
|
8
|
-
from einops import rearrange
|
|
9
|
-
from functorch import vmap
|
|
10
|
-
from torch import Tensor
|
|
11
|
-
|
|
12
|
-
from .utils import to_2tuple
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
def pi_resize_patch_embed(
    patch_embed: Tensor,
    new_patch_size: tuple[int, int],
    interpolation: str = "bicubic",
    antialias: bool = True,
):
    """Resample patch embedding weights to a target resolution via pseudo-inverse
    resizing.

    Based on:
        https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py
        https://arxiv.org/abs/2212.08013

    The kernel is chosen so that, for a patch ``x`` of the old size,
    ``<x, w_old> == <resize(x), w_new>`` holds (exactly when upsampling,
    in the least-squares sense otherwise).

    Args:
        patch_embed: Patch embedding parameters of size [d, c, h, w]
        new_patch_size: Target [height, width] of embedding
        interpolation: Resize interpolation type
        antialias: Whether to apply antialiasing resizing

    Returns:
        Resized patch_embed of size [d, c, h', w'], in the dtype of ``patch_embed``
        (returned unchanged when no resize is needed).
    """
    assert len(patch_embed.shape) == 4, "Patch embed kernel should be a 4D tensor"
    assert len(new_patch_size) == 2, "New patch size should only be (height, width)"

    old_patch_size = tuple(patch_embed.shape[2:])
    # Normalize (e.g. a list or NumPy scalars) so the equality check below works.
    new_patch_size = (int(new_patch_size[0]), int(new_patch_size[1]))

    # Return original kernel if no resize is necessary
    if old_patch_size == new_patch_size:
        return patch_embed

    def resize(x: Tensor, shape: tuple[int, int]) -> Tensor:
        x_resized = F.interpolate(
            x[None, None, ...],
            shape,
            mode=interpolation,
            antialias=antialias,
        )
        return x_resized[0, 0, ...]

    # Row i of the resize matrix is the resized i-th one-hot basis kernel, so it
    # has shape (old_h * old_w, new_h * new_w); its pinv is (new, old).
    num_old = int(np.prod(old_patch_size))
    rows = []
    for i in range(num_old):
        basis_vec = torch.zeros(old_patch_size)
        basis_vec[np.unravel_index(i, old_patch_size)] = 1.0
        rows.append(resize(basis_vec, new_patch_size).reshape(-1))
    resize_matrix_pinv = torch.linalg.pinv(torch.stack(rows))

    # Compute in at least float32 (the pinv's precision) so half-precision
    # kernels work, then cast back to the caller's dtype.
    compute_dtype = torch.promote_types(patch_embed.dtype, resize_matrix_pinv.dtype)
    resize_matrix_pinv = resize_matrix_pinv.to(
        device=patch_embed.device, dtype=compute_dtype
    )

    # One batched matmul over all (d, c) kernels replaces the nested vmap.
    d, c = patch_embed.shape[:2]
    h, w = new_patch_size
    flat = patch_embed.reshape(d, c, num_old).to(compute_dtype)
    resampled = flat @ resize_matrix_pinv.T
    return resampled.reshape(d, c, h, w).to(patch_embed.dtype)
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
def interpolate_resize_patch_embed(
    patch_embed: Tensor,
    new_patch_size: tuple[int, int],
    interpolation: str = "bicubic",
    antialias: bool = True,
):
    """Resize patch embedding weights to a new kernel size by plain interpolation.

    Unlike the pseudo-inverse variant, this directly interpolates the kernel's
    spatial dimensions with ``F.interpolate``.

    Args:
        patch_embed: Patch embedding parameters of size [d, c, h, w]
        new_patch_size: Target [height, width] of embedding
        interpolation: Resize interpolation type (``F.interpolate`` mode)
        antialias: Whether to apply antialiasing resizing

    Returns:
        Resized patch_embed of size [d, c, h', w']
    """
    assert patch_embed.dim() == 4, "Patch embed kernel should be a 4D tensor"
    assert len(new_patch_size) == 2, "New patch size should only be (height, width)"

    return F.interpolate(
        patch_embed,
        size=new_patch_size,
        mode=interpolation,
        antialias=antialias,
    )
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
class FlexiPatchEmbed(nn.Module):
    def __init__(
        self,
        img_size: int | tuple[int, int] = 240,
        patch_size: int | tuple[int, int] = 32,
        grid_size: int | tuple[int, int] = 7,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: nn.Module | None = None,
        flatten: bool = True,
        bias: bool = True,
        patch_size_seq: Sequence[int] = (8, 10, 12, 15, 16, 20, 24, 30, 40, 48),
        patch_size_probs: Sequence[float] | None = None,
        interpolation: str = "bicubic",
        antialias: bool = True,
    ) -> None:
        """2D image to patch embedding w/ flexible patch sizes
        Extended from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/patch_embed.py#L24

        Args:
            img_size: Input image size
            patch_size: Base patch size. i.e the size of the parameter buffer
            grid_size: Size of pos_embed buffer
            in_chans: Number of input image channels
            embed_dim: Network embedding dimension size
            norm_layer: Optional normalization layer (called as ``norm_layer(embed_dim)``)
            flatten: Whether to flatten the spatial dimensions of the output
            bias: Whether to use bias in convolution
            patch_size_seq: List of patch sizes to randomly sample from
            patch_size_probs: Optional list of probabilities to sample corresponding
                patch_size_seq elements. If None, then uniform distribution is used
                (given probabilities are renormalized to sum to 1)
            interpolation: Resize interpolation type
            antialias: Whether to apply antialiasing resizing
        """
        super().__init__()

        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        self.grid_size = to_2tuple(grid_size)
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.flatten = flatten
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=bias,
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

        # Flexi specific attributes
        self.interpolation = interpolation
        self.antialias = antialias

        self.patch_size_seq = patch_size_seq

        if self.patch_size_seq:
            if not patch_size_probs:
                n = len(self.patch_size_seq)
                self.patch_size_probs = [1.0 / n] * n
            else:
                total = sum(patch_size_probs)
                self.patch_size_probs = [p / total for p in patch_size_probs]
        else:
            self.patch_size_probs = []

        # Pre-calculate pinvs (CPU, float32) keyed by (h, w) tuples of ints
        self.pinvs = self._cache_pinvs()

    @staticmethod
    def _as_size(size) -> tuple[int, ...]:
        """Normalize an int or pair (possibly NumPy scalars) to a tuple of Python ints."""
        return tuple(int(v) for v in to_2tuple(size))

    def _cache_pinvs(self) -> dict:
        """Pre-calculate the pinv matrix for every size in ``patch_size_seq``."""
        pinvs = {}
        for ps in self.patch_size_seq:
            key = self._as_size(ps)
            pinvs[key] = self._calculate_pinv(self.patch_size, key)
        return pinvs

    def _resize(self, x: Tensor, shape: tuple[int, int]) -> Tensor:
        """Resize a single 2D tensor with the configured interpolation."""
        x_resized = F.interpolate(
            x[None, None, ...],
            shape,
            mode=self.interpolation,
            antialias=self.antialias,
        )
        return x_resized[0, 0, ...]

    def _calculate_pinv(
        self, old_shape: tuple[int, int], new_shape: tuple[int, int]
    ) -> Tensor:
        """Return pinv of the (old_h*old_w, new_h*new_w) resize matrix, shape (new, old)."""
        mat = []
        for i in range(int(np.prod(old_shape))):
            basis_vec = torch.zeros(old_shape)
            basis_vec[np.unravel_index(i, old_shape)] = 1.0
            mat.append(self._resize(basis_vec, new_shape).reshape(-1))
        resize_matrix = torch.stack(mat)
        return torch.linalg.pinv(resize_matrix)

    def resize_patch_embed(self, patch_embed: Tensor, new_patch_size: tuple[int, int]):
        """Resize patch_embed to target resolution via pseudo-inverse resizing.

        Args:
            patch_embed: kernel of shape [d, c, h, w] (base ``self.patch_size``)
            new_patch_size: target (height, width), or an int for a square size

        Returns:
            Kernel of shape [d, c, h', w'] in the dtype of ``patch_embed``.
        """
        new_patch_size = self._as_size(new_patch_size)
        # Return original kernel if no resize is necessary
        if self.patch_size == new_patch_size:
            return patch_embed

        # Calculate (and memoize) pseudo-inverse of resize matrix
        if new_patch_size not in self.pinvs:
            self.pinvs[new_patch_size] = self._calculate_pinv(
                self.patch_size, new_patch_size
            )
        pinv = self.pinvs[new_patch_size]

        # The pinv is float32; compute in at least that precision so fp16/bf16
        # kernels don't hit a matmul dtype mismatch, then cast back.
        compute_dtype = torch.promote_types(patch_embed.dtype, pinv.dtype)
        pinv = pinv.to(device=patch_embed.device, dtype=compute_dtype)

        d, c = patch_embed.shape[:2]
        h, w = new_patch_size
        flat = patch_embed.reshape(d, c, -1).to(compute_dtype)
        # Batched over all (d, c) kernels at once: (d, c, old) @ (old, new).
        resampled = flat @ pinv.T
        return resampled.reshape(d, c, h, w).to(patch_embed.dtype)

    def forward(
        self,
        x: Tensor,
        patch_size: int | tuple[int, int] | None = None,
        return_patch_size: bool = False,
    ) -> Tensor | tuple[Tensor, tuple[int, int]]:
        """Embed ``x`` into patches, resizing the conv kernel to ``patch_size``.

        If ``patch_size`` is not given, the base patch size is used in eval mode
        and a size is sampled from ``patch_size_seq`` in training mode.
        """
        if not patch_size and not self.training:
            # During evaluation use base patch size if not specified
            patch_size = self.patch_size
        elif not patch_size:
            # During training sample from patch_size_seq if not specified
            assert self.patch_size_seq, (
                "No patch size specified during forward and no patch_size_seq given to FlexiPatchEmbed"
            )
            patch_size = np.random.choice(self.patch_size_seq, p=self.patch_size_probs)

        # np.random.choice yields NumPy scalars; normalize to Python ints.
        patch_size = self._as_size(patch_size)

        # Resize conv weights
        if patch_size == self.patch_size:
            weight = self.proj.weight
        else:
            weight = self.resize_patch_embed(self.proj.weight, patch_size)

        # Apply conv with resized weights
        x = F.conv2d(x, weight, bias=self.proj.bias, stride=patch_size)

        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC

        x = self.norm(x)

        if return_patch_size:
            return x, patch_size

        return x
|
|
@@ -1,69 +0,0 @@
|
|
|
1
|
-
# type: ignore
|
|
2
|
-
import collections.abc
|
|
3
|
-
import math
|
|
4
|
-
from itertools import repeat
|
|
5
|
-
from typing import Any
|
|
6
|
-
|
|
7
|
-
import torch
|
|
8
|
-
import torch.nn.functional as F
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def to_2tuple(x: Any) -> tuple:
    """Coerce ``x`` to a tuple.

    Non-string iterables are converted with ``tuple(x)`` as-is (their length is
    not checked). Anything else, including strings, is duplicated into ``(x, x)``.
    """
    treat_as_sequence = isinstance(x, collections.abc.Iterable) and not isinstance(
        x, str
    )
    if not treat_as_sequence:
        return (x, x)
    return tuple(x)
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
def resize_abs_pos_embed(
    pos_embed: torch.Tensor,
    new_size: tuple[int, int],
    old_size: int | tuple[int, int] | None = None,
    num_prefix_tokens: int = 1,
    interpolation: str = "bicubic",
    antialias: bool = True,
) -> torch.Tensor:
    """Resize absolute position embeddings to a target resolution via interpolation

    Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/pos_embed.py

    Args:
        pos_embed: Position embeddings tensor of size [b, n, d], where the first
            ``num_prefix_tokens`` tokens are non-spatial and the rest form an
            ``old_size`` grid in row-major order
        new_size: Target [height, width] of embedding (an int means square)
        old_size: Original [height, width] of embedding; if falsy, a square grid
            of side isqrt(n - num_prefix_tokens) is assumed
        num_prefix_tokens: Number of non-spatial prefix tokens (eg. cls)
        interpolation: Resize interpolation type
        antialias: Whether to apply antialiasing resizing

    Returns:
        Resized pos_embed of size [b, n', d] (the input itself if no resize is needed)
    """
    new_size = to_2tuple(new_size)
    new_ntok = new_size[0] * new_size[1]

    if not old_size:
        # isqrt is exact, unlike int(math.sqrt(...)) for large token counts.
        old_size = math.isqrt(pos_embed.shape[1] - num_prefix_tokens)
    old_size = to_2tuple(old_size)

    # Return if no resize necessary
    if new_size == old_size:
        return pos_embed

    if num_prefix_tokens:
        posemb_prefix = pos_embed[:, :num_prefix_tokens]
        pos_embed = pos_embed[:, num_prefix_tokens:]
    else:
        posemb_prefix = None

    # Keep the batch dim separate: [b, h*w, d] -> [b, d, h, w]. Reshaping to a
    # batch of 1 would mix batch entries into the spatial grid when b > 1.
    batch = pos_embed.shape[0]
    grid = pos_embed.reshape(batch, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=new_size, mode=interpolation, antialias=antialias)
    pos_embed = grid.permute(0, 2, 3, 1).reshape(batch, new_ntok, -1)

    # Add back extra prefix tokens
    if posemb_prefix is not None:
        pos_embed = torch.cat([posemb_prefix, pos_embed], dim=1)

    return pos_embed
|
|
@@ -1,348 +0,0 @@
|
|
|
1
|
-
# mypy: ignore-errors
|
|
2
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
3
|
-
# All rights reserved.
|
|
4
|
-
|
|
5
|
-
# This source code is licensed under the license found in the
|
|
6
|
-
# LICENSE file in the root directory of this source tree.
|
|
7
|
-
# --------------------------------------------------------
|
|
8
|
-
# References:
|
|
9
|
-
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
|
10
|
-
# DeiT: https://github.com/facebookresearch/deit
|
|
11
|
-
# --------------------------------------------------------
|
|
12
|
-
|
|
13
|
-
import math
|
|
14
|
-
from functools import partial
|
|
15
|
-
|
|
16
|
-
import torch
|
|
17
|
-
import torch.nn as nn
|
|
18
|
-
from timm.models.vision_transformer import Block
|
|
19
|
-
|
|
20
|
-
from .aurora.fourier import FourierExpansion
|
|
21
|
-
from .dynamic_hypernetwork import Dynamic_MLP_OFA_spectral, Dynamic_MLP_OFA_variable
|
|
22
|
-
from .flexivit.utils import resize_abs_pos_embed
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
class CopernicusFMViT(nn.Module):
    """CopernicusFM: VisionTransformer backbone"""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        drop_rate=0.0,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        wv_planes=128,
        num_classes=0,
        global_pool=True,
        mlp_ratio=4.0,
        norm_layer=nn.LayerNorm,
        loc_option="lonlat",
        return_intermediate=False,
        intermediate_indices=None,
    ):
        """Build the backbone.

        Args:
            loc_option: "cartesian" or "lonlat" encoding of the coordinate metadata.
            return_intermediate: if True, ``forward`` returns
                ``(tokens, intermediate_features)``.
            intermediate_indices: block indices whose outputs are collected when
                ``return_intermediate`` is set; None collects nothing.

        Raises:
            ValueError: if ``loc_option`` is not "cartesian" or "lonlat".
        """
        super().__init__()

        if loc_option not in ("cartesian", "lonlat"):
            raise ValueError(
                f"loc_option must be 'cartesian' or 'lonlat', got {loc_option!r}"
            )

        self.wv_planes = wv_planes
        self.global_pool = global_pool
        if self.global_pool:
            self.fc_norm = norm_layer(embed_dim)
        else:
            self.norm = norm_layer(embed_dim)

        # NOTE(review): the dynamic patch embedders hard-code wv_planes=128 and
        # kernel_size=16 instead of using ``wv_planes`` / ``patch_size``;
        # presumably to match the released checkpoints — confirm before changing.
        self.patch_embed_spectral = Dynamic_MLP_OFA_spectral(
            wv_planes=128, inter_dim=128, kernel_size=16, embed_dim=embed_dim
        )
        self.patch_embed_variable = Dynamic_MLP_OFA_variable(
            wv_planes=128, inter_dim=128, kernel_size=16, embed_dim=embed_dim
        )

        self.num_patches = (img_size // patch_size) ** 2

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Not trained (requires_grad=False); the original comment calls it a fixed
        # sin-cos embedding, so its values presumably come from a checkpoint.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dim), requires_grad=False
        )

        self.loc_option = loc_option
        if loc_option == "cartesian":
            self.coord_expansion = FourierExpansion(1e-7, 2)
        else:  # "lonlat"
            self.coord_expansion = FourierExpansion(0.0001, 720)

        self.scale_expansion = FourierExpansion(0.001, 5.1e8)  # 1m2 to 5.1e8 km2
        # 1 to 365.25 days; assert_range=False enables more than 1 year
        self.time_expansion = FourierExpansion(1, 365.25, assert_range=False)
        self.coord_fc = nn.Linear(embed_dim, embed_dim)
        self.scale_fc = nn.Linear(embed_dim, embed_dim)
        self.time_fc = nn.Linear(embed_dim, embed_dim)
        # if meta info is not available, set to a learned parameter
        self.coord_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.scale_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.time_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for _ in range(depth)
            ]
        )

        self.head_drop = nn.Dropout(drop_rate)
        self.head = (
            nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

        self.return_intermediate = return_intermediate
        self.intermediate_indices = intermediate_indices

    def get_coord_pos_embed(self, lons, lats, embed_dim):
        """Fourier-encode longitude/latitude (degrees) into a [B, 1, embed_dim] tensor.

        lons/lats are presumably 1D tensors of length B — TODO confirm with caller.
        The encoding is zero-padded when embed_dim is not divisible by 3 (cartesian)
        or 2 (lonlat).
        """
        if self.loc_option == "cartesian":
            lon_rad = lons * math.pi / 180
            lat_rad = lats * math.pi / 180
            # Unit-sphere coordinates shifted by +1 (+1e-7) so every value is
            # >= 1e-7, the lower bound passed to FourierExpansion in __init__.
            spherical_x = torch.cos(lon_rad) * torch.cos(lat_rad) + 1 + 1e-7
            spherical_y = torch.sin(lon_rad) * torch.cos(lat_rad) + 1 + 1e-7
            spherical_z = torch.sin(lat_rad) + 1 + 1e-7
            coord_embed = torch.cat(
                [
                    self.coord_expansion(spherical_x, embed_dim // 3),
                    self.coord_expansion(spherical_y, embed_dim // 3),
                    self.coord_expansion(spherical_z, embed_dim // 3),
                ],
                dim=-1,
            )  # [B,D]
        elif self.loc_option == "lonlat":
            # Shift into non-negative ranges [0, 360] and [0, 180].
            coord_embed_lon = self.coord_expansion(lons + 180, embed_dim // 2)
            coord_embed_lat = self.coord_expansion(lats + 90, embed_dim // 2)
            coord_embed = torch.cat([coord_embed_lon, coord_embed_lat], dim=-1)
        else:
            raise ValueError(f"Unsupported loc_option {self.loc_option!r}")

        if coord_embed.shape[-1] < embed_dim:
            # pad zeros (matching dtype/device so torch.cat needs no promotion)
            coord_embed = torch.cat(
                (
                    coord_embed,
                    torch.zeros(
                        coord_embed.shape[0],
                        embed_dim - coord_embed.shape[-1],
                        device=coord_embed.device,
                        dtype=coord_embed.dtype,
                    ),
                ),
                dim=-1,
            )

        return coord_embed.unsqueeze(1)  # [B,1,D]

    def get_area_pos_embed(self, areas, embed_dim):
        """Fourier-encode areas into a [B, 1, embed_dim] tensor."""
        scale_embed = self.scale_expansion(areas, embed_dim)  # B, D
        return scale_embed.unsqueeze(1)  # [B,1,D]

    def get_time_pos_embed(self, times, embed_dim):
        """Fourier-encode times into a [B, 1, embed_dim] tensor."""
        time_embed = self.time_expansion(times, embed_dim)  # B, D
        return time_embed.unsqueeze(1)  # [B,1,D]

    def _forward_with_intermediates(
        self,
        x,
        meta_info,
        wave_list,
        bandwidth,
        language_embed,
        input_mode,
        kernel_size=None,
    ):
        """Run the backbone, also collecting the requested intermediate feature maps.

        Returns:
            ``(tokens, intermediate_features)``: patch tokens with the cls token
            removed, and a list of [B, C, hw, hw] maps for the blocks listed in
            ``intermediate_indices`` (empty unless ``return_intermediate``).

        Raises:
            ValueError: if ``input_mode`` is not "spectral" or "variable".
        """
        # embed patches
        if input_mode == "spectral":
            wavelist = torch.tensor(wave_list, device=x.device).float()
            bandwidths = torch.tensor(bandwidth, device=x.device).float()
            self.waves = wavelist
            x, _ = self.patch_embed_spectral(x, self.waves, bandwidths, kernel_size)
        elif input_mode == "variable":
            x, _ = self.patch_embed_variable(x, language_embed, kernel_size)
        else:
            raise ValueError(
                f"input_mode must be 'spectral' or 'variable', got {input_mode!r}"
            )

        # resize pos embed; both grids are assumed square
        num_patches = x.size(1)
        num_patches_sqrt = int(math.sqrt(num_patches))
        num_patches_sqrt_origin = int(math.sqrt(self.num_patches))
        pos_embed = resize_abs_pos_embed(
            self.pos_embed,
            num_patches_sqrt,
            (num_patches_sqrt_origin, num_patches_sqrt_origin),
            num_prefix_tokens=1,
        )

        # coord, scale and time pos embed; meta_info columns are lon, lat, time, area
        lons, lats, times, areas = (
            meta_info[:, 0],
            meta_info[:, 1],
            meta_info[:, 2],
            meta_info[:, 3],
        )
        embed_dim = pos_embed.shape[-1]
        # Any NaN in the batch makes the whole batch use the learned fallback token.
        if torch.isnan(lons).any() or torch.isnan(lats).any():
            coord_embed = self.coord_token
        else:
            coord_embed = self.get_coord_pos_embed(lons, lats, embed_dim)
            coord_embed = self.coord_fc(coord_embed)
        if torch.isnan(areas).any():
            area_embed = self.scale_token
        else:
            area_embed = self.get_area_pos_embed(areas, embed_dim)
            area_embed = self.scale_fc(area_embed)
        if torch.isnan(times).any():
            time_embed = self.time_token
        else:
            time_embed = self.get_time_pos_embed(times, embed_dim)
            time_embed = self.time_fc(time_embed)
        pos_embed = pos_embed + coord_embed + area_embed + time_embed

        # add pos embed w/o cls token
        x = x + pos_embed[:, 1:, :]

        # append cls token
        cls_token = self.cls_token + pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        intermediate_features = []
        hw = num_patches_sqrt
        collect = set(self.intermediate_indices or ()) if self.return_intermediate else set()

        # apply Transformer blocks
        for i, block in enumerate(self.blocks):
            x = block(x)
            if i in collect:
                out = x[:, 1:]
                B, _, C = out.shape
                out = out.reshape(B, hw, hw, C).permute(0, 3, 1, 2).contiguous()
                intermediate_features.append(out)

        # For segmentation tasks the final norm is skipped, following
        # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/foundation_models/CopernicusFM/models_dwv_seg.py
        # for classification, fc_norm is applied in the wrapper.
        return x[:, 1:, :], intermediate_features

    def forward_features(
        self,
        x,
        meta_info,
        wave_list,
        bandwidth,
        language_embed,
        input_mode,
        kernel_size=None,
    ):
        """Return the patch tokens (cls token removed), without the final norm."""
        tokens, _ = self._forward_with_intermediates(
            x, meta_info, wave_list, bandwidth, language_embed, input_mode, kernel_size
        )
        return tokens

    def forward_head(self, x, pre_logits=False):
        """Apply dropout and, unless ``pre_logits``, the classification head."""
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(
        self,
        x,
        meta_info,
        wave_list,
        bandwidth,
        language_embed,
        input_mode,
        kernel_size=None,
    ):
        """Run the backbone.

        Returns:
            ``(tokens, intermediate_features)`` when ``return_intermediate`` is set,
            otherwise ``(head(tokens), tokens)``.
        """
        if self.return_intermediate:
            # forward_features only returns the tokens, so use the helper that
            # also returns the collected intermediate maps.
            return self._forward_with_intermediates(
                x,
                meta_info,
                wave_list,
                bandwidth,
                language_embed,
                input_mode,
                kernel_size,
            )
        fx = self.forward_features(
            x,
            meta_info,
            wave_list,
            bandwidth,
            language_embed,
            input_mode,
            kernel_size,
        )
        x = self.forward_head(fx)
        return x, fx
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
def vit_small_patch16(**kwargs):
    """Construct a CopernicusFMViT with patch 16, width 384, 12 blocks and 6 heads.

    Any extra keyword arguments are passed through to ``CopernicusFMViT``.
    """
    return CopernicusFMViT(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
def vit_base_patch16(**kwargs):
    """Construct a CopernicusFMViT with patch 16, width 768, 12 blocks and 12 heads.

    Any extra keyword arguments are passed through to ``CopernicusFMViT``.
    """
    return CopernicusFMViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
def vit_large_patch16(**kwargs):
    """Construct a CopernicusFMViT with patch 16, width 1024, 24 blocks and 16 heads.

    Any extra keyword arguments are passed through to ``CopernicusFMViT``.
    """
    return CopernicusFMViT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
def vit_huge_patch14(**kwargs):
    """Construct a CopernicusFMViT with patch 14, width 1280, 32 blocks and 16 heads.

    Any extra keyword arguments are passed through to ``CopernicusFMViT``.
    """
    return CopernicusFMViT(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
|