rslearn 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. rslearn/config/dataset.py +23 -4
  2. rslearn/data_sources/planetary_computer.py +52 -0
  3. rslearn/dataset/handler_summaries.py +1 -0
  4. rslearn/dataset/manage.py +16 -2
  5. rslearn/models/anysat.py +5 -1
  6. rslearn/models/dinov3.py +6 -1
  7. rslearn/models/feature_center_crop.py +50 -0
  8. rslearn/models/olmoearth_pretrain/model.py +88 -27
  9. rslearn/models/prithvi.py +9 -1
  10. rslearn/train/lightning_module.py +0 -3
  11. rslearn/train/prediction_writer.py +25 -8
  12. rslearn/train/tasks/classification.py +2 -2
  13. rslearn/train/tasks/detection.py +5 -5
  14. rslearn/train/tasks/embedding.py +116 -0
  15. rslearn/train/tasks/per_pixel_regression.py +5 -4
  16. rslearn/train/tasks/regression.py +5 -5
  17. rslearn/train/transforms/pad.py +3 -3
  18. rslearn/utils/raster_format.py +38 -0
  19. {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/METADATA +3 -2
  20. {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/RECORD +25 -31
  21. rslearn-0.0.13.dist-info/licenses/NOTICE +115 -0
  22. rslearn/models/copernicusfm.py +0 -228
  23. rslearn/models/copernicusfm_src/__init__.py +0 -1
  24. rslearn/models/copernicusfm_src/aurora/area.py +0 -50
  25. rslearn/models/copernicusfm_src/aurora/fourier.py +0 -134
  26. rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +0 -523
  27. rslearn/models/copernicusfm_src/flexivit/patch_embed.py +0 -260
  28. rslearn/models/copernicusfm_src/flexivit/utils.py +0 -69
  29. rslearn/models/copernicusfm_src/model_vit.py +0 -348
  30. rslearn/models/copernicusfm_src/util/pos_embed.py +0 -216
  31. {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/WHEEL +0 -0
  32. {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/entry_points.txt +0 -0
  33. {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/licenses/LICENSE +0 -0
  34. {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/top_level.txt +0 -0
@@ -1,260 +0,0 @@
1
- # type: ignore
2
- from collections.abc import Sequence
3
-
4
- import numpy as np
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from einops import rearrange
9
- from functorch import vmap
10
- from torch import Tensor
11
-
12
- from .utils import to_2tuple
13
-
14
-
15
- def pi_resize_patch_embed(
16
- patch_embed: Tensor,
17
- new_patch_size: tuple[int, int],
18
- interpolation: str = "bicubic",
19
- antialias: bool = True,
20
- ):
21
- """Resample patch embedding weights to a target resolution via pseudo-inverse
22
- resizing.
23
-
24
- Based on:
25
- https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py
26
- https://arxiv.org/abs/2212.08013
27
-
28
- Args:
29
- patch_embed: Patch embedding parameters of size [d, c, h, w]
30
- new_patch_size: Target [height, width] of embedding
31
- interpolation: Resize interpolation type
32
- antialias: Whether to apply antialiasing resizing
33
- Returns:
34
- Resized pos_embed of size [d, c h', w']
35
- """
36
- assert len(patch_embed.shape) == 4, "Patch embed kernel should be a 4D tensor"
37
- assert len(new_patch_size) == 2, "New patch size should only be (height, width)"
38
-
39
- old_patch_size = tuple(patch_embed.shape[2:])
40
-
41
- # Return original kernel if no resize is necessary
42
- if old_patch_size == new_patch_size:
43
- return patch_embed
44
-
45
- def resize(x: Tensor, shape: tuple[int, int]):
46
- x_resized = F.interpolate(
47
- x[None, None, ...],
48
- shape,
49
- mode=interpolation,
50
- antialias=antialias,
51
- )
52
- return x_resized[0, 0, ...]
53
-
54
- def calculate_pinv(old_shape: tuple[int, int], new_shape: tuple[int, int]):
55
- mat = []
56
- for i in range(np.prod(old_shape)):
57
- basis_vec = torch.zeros(old_shape)
58
- basis_vec[np.unravel_index(i, old_shape)] = 1.0
59
- mat.append(resize(basis_vec, new_shape).reshape(-1))
60
- resize_matrix = torch.stack(mat)
61
- return torch.linalg.pinv(resize_matrix)
62
-
63
- # Calculate pseudo-inverse of resize matrix
64
- resize_matrix_pinv = calculate_pinv(old_patch_size, new_patch_size)
65
- resize_matrix_pinv = resize_matrix_pinv.to(patch_embed.device)
66
-
67
- def resample_patch_embed(patch_embed: Tensor):
68
- h, w = new_patch_size
69
- resampled_kernel = resize_matrix_pinv @ patch_embed.reshape(-1)
70
- return rearrange(resampled_kernel, "(h w) -> h w", h=h, w=w)
71
-
72
- v_resample_patch_embed = vmap(vmap(resample_patch_embed, 0, 0), 1, 1)
73
-
74
- return v_resample_patch_embed(patch_embed)
75
-
76
-
77
- def interpolate_resize_patch_embed(
78
- patch_embed: Tensor,
79
- new_patch_size: tuple[int, int],
80
- interpolation: str = "bicubic",
81
- antialias: bool = True,
82
- ):
83
- """Resample patch embedding weights to a target resolution via interpolation
84
-
85
- Args:
86
- patch_embed: Patch embedding parameters of size [d, c, h, w]
87
- new_patch_size: Target [height, width] of embedding
88
- interpolation: Resize interpolation type
89
- antialias: Whether to apply antialiasing resizing
90
- Returns:
91
- Resized pos_embed of size [d, c h', w']
92
- """
93
- assert len(patch_embed.shape) == 4, "Patch embed kernel should be a 4D tensor"
94
- assert len(new_patch_size) == 2, "New patch size should only be (height, width)"
95
-
96
- patch_embed = F.interpolate(
97
- patch_embed, new_patch_size, mode=interpolation, antialias=antialias
98
- )
99
-
100
- return patch_embed
101
-
102
-
103
- class FlexiPatchEmbed(nn.Module):
104
- def __init__(
105
- self,
106
- img_size: int | tuple[int, int] = 240,
107
- patch_size: int | tuple[int, int] = 32,
108
- grid_size: int | tuple[int, int] = 7,
109
- in_chans: int = 3,
110
- embed_dim: int = 768,
111
- norm_layer: nn.Module | None = None,
112
- flatten: bool = True,
113
- bias: bool = True,
114
- patch_size_seq: Sequence[int] = (8, 10, 12, 15, 16, 20, 24, 30, 40, 48),
115
- patch_size_probs: Sequence[float] | None = None,
116
- interpolation: str = "bicubic",
117
- antialias: bool = True,
118
- ) -> None:
119
- """2D image to patch embedding w/ flexible patch sizes
120
- Extended from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/patch_embed.py#L24
121
-
122
- Args:
123
- img_size: Input image size
124
- patch_size: Base patch size. i.e the size of the parameter buffer
125
- grid_size: Size of pos_embed buffer
126
- in_chans: Number of input image channels
127
- embed_dim: Network embedding dimension size
128
- norm_layer: Optional normalization layer
129
- flatten: Whether to flatten the spatial dimensions of the output
130
- bias: Whether to use bias in convolution
131
- patch_size_seq: List of patch sizes to randomly sample from
132
- patch_size_probs: Optional list of probabilities to sample corresponding
133
- patch_size_seq elements. If None, then uniform distribution is used
134
- interpolation: Resize interpolation type
135
- antialias: Whether to apply antialiasing resizing
136
- """
137
- super().__init__()
138
-
139
- self.img_size = to_2tuple(img_size)
140
- self.patch_size = to_2tuple(patch_size)
141
- self.grid_size = to_2tuple(grid_size)
142
- self.num_patches = self.grid_size[0] * self.grid_size[1]
143
-
144
- self.flatten = flatten
145
- self.proj = nn.Conv2d(
146
- in_chans,
147
- embed_dim,
148
- kernel_size=self.patch_size,
149
- stride=self.patch_size,
150
- bias=bias,
151
- )
152
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
153
-
154
- # Flexi specific attributes
155
- self.interpolation = interpolation
156
- self.antialias = antialias
157
-
158
- self.patch_size_seq = patch_size_seq
159
-
160
- if self.patch_size_seq:
161
- if not patch_size_probs:
162
- n = len(self.patch_size_seq)
163
- self.patch_size_probs = [1.0 / n] * n
164
- else:
165
- self.patch_size_probs = [
166
- p / sum(patch_size_probs) for p in patch_size_probs
167
- ]
168
- else:
169
- self.patch_size_probs = []
170
-
171
- # Pre-calculate pinvs
172
- self.pinvs = self._cache_pinvs()
173
-
174
- def _cache_pinvs(self) -> dict:
175
- """Pre-calculate all pinv matrices"""
176
- pinvs = {}
177
- for ps in self.patch_size_seq:
178
- ps = to_2tuple(ps)
179
- pinvs[ps] = self._calculate_pinv(self.patch_size, ps)
180
- return pinvs
181
-
182
- def _resize(self, x: Tensor, shape: tuple[int, int]) -> Tensor:
183
- x_resized = F.interpolate(
184
- x[None, None, ...],
185
- shape,
186
- mode=self.interpolation,
187
- antialias=self.antialias,
188
- )
189
- return x_resized[0, 0, ...]
190
-
191
- def _calculate_pinv(
192
- self, old_shape: tuple[int, int], new_shape: tuple[int, int]
193
- ) -> Tensor:
194
- mat = []
195
- for i in range(np.prod(old_shape)):
196
- basis_vec = torch.zeros(old_shape)
197
- basis_vec[np.unravel_index(i, old_shape)] = 1.0
198
- mat.append(self._resize(basis_vec, new_shape).reshape(-1))
199
- resize_matrix = torch.stack(mat)
200
- return torch.linalg.pinv(resize_matrix)
201
-
202
- def resize_patch_embed(self, patch_embed: Tensor, new_patch_size: tuple[int, int]):
203
- """Resize patch_embed to target resolution via pseudo-inverse resizing"""
204
- # Return original kernel if no resize is necessary
205
- if self.patch_size == new_patch_size:
206
- return patch_embed
207
-
208
- # Calculate pseudo-inverse of resize matrix
209
- if new_patch_size not in self.pinvs:
210
- self.pinvs[new_patch_size] = self._calculate_pinv(
211
- self.patch_size, new_patch_size
212
- )
213
- pinv = self.pinvs[new_patch_size]
214
- pinv = pinv.to(patch_embed.device)
215
-
216
- def resample_patch_embed(patch_embed: Tensor):
217
- h, w = new_patch_size
218
- resampled_kernel = pinv @ patch_embed.reshape(-1)
219
- return rearrange(resampled_kernel, "(h w) -> h w", h=h, w=w)
220
-
221
- v_resample_patch_embed = vmap(vmap(resample_patch_embed, 0, 0), 1, 1)
222
-
223
- return v_resample_patch_embed(patch_embed)
224
-
225
- def forward(
226
- self,
227
- x: Tensor,
228
- patch_size: int | tuple[int, int] | None = None,
229
- return_patch_size: bool = False,
230
- ) -> Tensor | tuple[Tensor, tuple[int, int]]:
231
- if not patch_size and not self.training:
232
- # During evaluation use base patch size if not specified
233
- patch_size = self.patch_size
234
- elif not patch_size:
235
- # During training choose uniformly at random if not specified
236
- assert self.patch_size_seq, (
237
- "No patch size specified during forward and no patch_size_seq given to FlexiPatchEmbed"
238
- )
239
- patch_size = np.random.choice(self.patch_size_seq, p=self.patch_size_probs)
240
-
241
- patch_size = to_2tuple(patch_size)
242
-
243
- # Resize conv weights
244
- if patch_size == self.patch_size:
245
- weight = self.proj.weight
246
- else:
247
- weight = self.resize_patch_embed(self.proj.weight, patch_size)
248
-
249
- # Apply conv with resized weights
250
- x = F.conv2d(x, weight, bias=self.proj.bias, stride=patch_size)
251
-
252
- if self.flatten:
253
- x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
254
-
255
- x = self.norm(x)
256
-
257
- if return_patch_size:
258
- return x, patch_size
259
-
260
- return x
@@ -1,69 +0,0 @@
1
- # type: ignore
2
- import collections.abc
3
- import math
4
- from itertools import repeat
5
- from typing import Any
6
-
7
- import torch
8
- import torch.nn.functional as F
9
-
10
-
11
- def to_2tuple(x: Any) -> tuple:
12
- if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
13
- return tuple(x)
14
- return tuple(repeat(x, 2))
15
-
16
-
17
- def resize_abs_pos_embed(
18
- pos_embed: torch.Tensor,
19
- new_size: tuple[int, int],
20
- old_size: int | tuple[int, int] | None = None,
21
- num_prefix_tokens: int = 1,
22
- interpolation: str = "bicubic",
23
- antialias: bool = True,
24
- ) -> torch.Tensor:
25
- """Resize absolute position embeddings to a target resolution via interpolation
26
-
27
- Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/pos_embed.py
28
-
29
- Args:
30
- pos_embed: Position embeddings tensor of size [b, n, d]
31
- new_size: Target [height, width] of embedding
32
- old_size: Original [height, width] of embedding
33
- num_prefix_tokens: Number of non-spatial prefix tokens (eg. cls)
34
- interpolation: Resize interpolation type
35
- antialias: Whether to apply antialiasing resizing
36
- Returns:
37
- Resized pos_embed of size [b, n', d]
38
- """
39
- new_size = to_2tuple(new_size)
40
- new_ntok = new_size[0] * new_size[1]
41
-
42
- if not old_size:
43
- old_size = int(math.sqrt(pos_embed.shape[1] - num_prefix_tokens)) # type:ignore
44
- old_size = to_2tuple(old_size)
45
-
46
- # Return if no resize necessary
47
- if new_size == old_size:
48
- return pos_embed
49
-
50
- if num_prefix_tokens:
51
- posemb_prefix, pos_embed = (
52
- pos_embed[:, :num_prefix_tokens],
53
- pos_embed[:, num_prefix_tokens:],
54
- )
55
- else:
56
- posemb_prefix, pos_embed = None, pos_embed
57
-
58
- # Interpolate position embedding
59
- pos_embed = pos_embed.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
60
- pos_embed = F.interpolate(
61
- pos_embed, size=new_size, mode=interpolation, antialias=antialias
62
- )
63
- pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(1, new_ntok, -1)
64
-
65
- # Add back extra prefix tokens
66
- if posemb_prefix is not None:
67
- pos_embed = torch.cat([posemb_prefix, pos_embed], dim=1)
68
-
69
- return pos_embed
@@ -1,348 +0,0 @@
1
- # mypy: ignore-errors
2
- # Copyright (c) Meta Platforms, Inc. and affiliates.
3
- # All rights reserved.
4
-
5
- # This source code is licensed under the license found in the
6
- # LICENSE file in the root directory of this source tree.
7
- # --------------------------------------------------------
8
- # References:
9
- # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
10
- # DeiT: https://github.com/facebookresearch/deit
11
- # --------------------------------------------------------
12
-
13
- import math
14
- from functools import partial
15
-
16
- import torch
17
- import torch.nn as nn
18
- from timm.models.vision_transformer import Block
19
-
20
- from .aurora.fourier import FourierExpansion
21
- from .dynamic_hypernetwork import Dynamic_MLP_OFA_spectral, Dynamic_MLP_OFA_variable
22
- from .flexivit.utils import resize_abs_pos_embed
23
-
24
-
25
- class CopernicusFMViT(nn.Module):
26
- """CopernicusFM: VisionTransformer backbone"""
27
-
28
- def __init__(
29
- self,
30
- img_size=224,
31
- patch_size=16,
32
- drop_rate=0.0,
33
- embed_dim=1024,
34
- depth=24,
35
- num_heads=16,
36
- wv_planes=128,
37
- num_classes=0,
38
- global_pool=True,
39
- mlp_ratio=4.0,
40
- norm_layer=nn.LayerNorm,
41
- loc_option="lonlat",
42
- return_intermediate=False,
43
- intermediate_indices=None,
44
- ):
45
- super().__init__()
46
-
47
- self.wv_planes = wv_planes
48
- self.global_pool = global_pool
49
- if self.global_pool:
50
- norm_layer = norm_layer
51
- embed_dim = embed_dim
52
- self.fc_norm = norm_layer(embed_dim)
53
- else:
54
- self.norm = norm_layer(embed_dim)
55
-
56
- self.patch_embed_spectral = Dynamic_MLP_OFA_spectral(
57
- wv_planes=128, inter_dim=128, kernel_size=16, embed_dim=embed_dim
58
- )
59
- self.patch_embed_variable = Dynamic_MLP_OFA_variable(
60
- wv_planes=128, inter_dim=128, kernel_size=16, embed_dim=embed_dim
61
- )
62
-
63
- self.num_patches = (img_size // patch_size) ** 2
64
-
65
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
66
- # ---------------------------------------------------------------------------
67
-
68
- self.pos_embed = nn.Parameter(
69
- torch.zeros(1, self.num_patches + 1, embed_dim), requires_grad=False
70
- ) # fixed sin-cos embedding
71
-
72
- self.loc_option = loc_option
73
- if loc_option == "cartesian":
74
- self.coord_expansion = FourierExpansion(1e-7, 2)
75
- elif loc_option == "lonlat":
76
- self.coord_expansion = FourierExpansion(0.0001, 720)
77
-
78
- self.scale_expansion = FourierExpansion(0.001, 5.1e8) # 1m2 to 5.1e8 km2
79
- self.time_expansion = FourierExpansion(
80
- 1, 365.25, assert_range=False
81
- ) # 1 to 365.25 days, enable more than 1 year
82
- self.coord_fc = nn.Linear(embed_dim, embed_dim)
83
- self.scale_fc = nn.Linear(embed_dim, embed_dim)
84
- self.time_fc = nn.Linear(embed_dim, embed_dim)
85
- # if meta info is not available, set to a learned parameter
86
- self.coord_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
87
- self.scale_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
88
- self.time_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
89
-
90
- self.blocks = nn.ModuleList(
91
- [
92
- Block(
93
- embed_dim,
94
- num_heads,
95
- mlp_ratio,
96
- qkv_bias=True,
97
- norm_layer=norm_layer,
98
- )
99
- for i in range(depth)
100
- ]
101
- )
102
-
103
- self.head_drop = nn.Dropout(drop_rate)
104
- self.head = (
105
- nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
106
- )
107
-
108
- self.return_intermediate = return_intermediate
109
- self.intermediate_indices = intermediate_indices
110
-
111
- def get_coord_pos_embed(self, lons, lats, embed_dim):
112
- if self.loc_option == "cartesian":
113
- # convert to spherical coordinates
114
- spherical_x = (
115
- torch.cos(lons * math.pi / 180) * torch.cos(lats * math.pi / 180)
116
- + 1
117
- + 1e-7
118
- )
119
- spherical_y = (
120
- torch.sin(lons * math.pi / 180) * torch.cos(lats * math.pi / 180)
121
- + 1
122
- + 1e-7
123
- )
124
- spherical_z = torch.sin(lats * math.pi / 180) + 1 + 1e-7
125
- coord_embed_spherical_x = self.coord_expansion(spherical_x, embed_dim // 3)
126
- coord_embed_spherical_y = self.coord_expansion(spherical_y, embed_dim // 3)
127
- coord_embed_spherical_z = self.coord_expansion(spherical_z, embed_dim // 3)
128
- coord_embed = torch.cat(
129
- [
130
- coord_embed_spherical_x,
131
- coord_embed_spherical_y,
132
- coord_embed_spherical_z,
133
- ],
134
- dim=-1,
135
- ) # [B,D]
136
- elif self.loc_option == "lonlat":
137
- coord_embed_lon = self.coord_expansion(lons + 180, embed_dim // 2)
138
- coord_embed_lat = self.coord_expansion(lats + 90, embed_dim // 2)
139
- coord_embed = torch.cat([coord_embed_lon, coord_embed_lat], dim=-1)
140
-
141
- if coord_embed.shape[-1] < embed_dim:
142
- # pad zeros
143
- coord_embed = torch.cat(
144
- (
145
- coord_embed,
146
- torch.zeros(
147
- coord_embed.shape[0],
148
- embed_dim - coord_embed.shape[-1],
149
- device=coord_embed.device,
150
- ),
151
- ),
152
- dim=-1,
153
- )
154
-
155
- return coord_embed.unsqueeze(1) # [B,1,D]
156
-
157
- def get_area_pos_embed(self, areas, embed_dim):
158
- scale_embed = self.scale_expansion(areas, embed_dim) # B, D
159
- return scale_embed.unsqueeze(1) # [B,1,D]
160
-
161
- def get_time_pos_embed(self, times, embed_dim):
162
- time_embed = self.time_expansion(times, embed_dim) # B, D
163
- return time_embed.unsqueeze(1) # [B,1,D]
164
-
165
- def forward_features(
166
- self,
167
- x,
168
- meta_info,
169
- wave_list,
170
- bandwidth,
171
- language_embed,
172
- input_mode,
173
- kernel_size=None,
174
- ):
175
- # embed patches
176
- if input_mode == "spectral":
177
- wavelist = torch.tensor(wave_list, device=x.device).float()
178
- bandwidths = torch.tensor(bandwidth, device=x.device).float()
179
- self.waves = wavelist
180
- x, _ = self.patch_embed_spectral(x, self.waves, bandwidths, kernel_size)
181
- elif input_mode == "variable":
182
- x, _ = self.patch_embed_variable(x, language_embed, kernel_size)
183
-
184
- # resize pos embed
185
- num_patches = x.size(1)
186
- num_patches_sqrt = int(math.sqrt(num_patches))
187
- num_patches_sqrt_origin = int(math.sqrt(self.num_patches))
188
- pos_embed = resize_abs_pos_embed(
189
- self.pos_embed,
190
- num_patches_sqrt,
191
- (num_patches_sqrt_origin, num_patches_sqrt_origin),
192
- num_prefix_tokens=1,
193
- )
194
-
195
- # coord, scale and time pos embed
196
- lons, lats, times, areas = (
197
- meta_info[:, 0],
198
- meta_info[:, 1],
199
- meta_info[:, 2],
200
- meta_info[:, 3],
201
- )
202
- embed_dim = pos_embed.shape[-1]
203
- if torch.isnan(lons).any() or torch.isnan(lats).any():
204
- coord_embed = self.coord_token
205
- else:
206
- coord_embed = self.get_coord_pos_embed(lons, lats, embed_dim)
207
- coord_embed = self.coord_fc(coord_embed)
208
- if torch.isnan(areas).any():
209
- area_embed = self.scale_token
210
- else:
211
- area_embed = self.get_area_pos_embed(areas, embed_dim)
212
- area_embed = self.scale_fc(area_embed)
213
- if torch.isnan(times).any():
214
- time_embed = self.time_token
215
- else:
216
- time_embed = self.get_time_pos_embed(times, embed_dim)
217
- time_embed = self.time_fc(time_embed)
218
- pos_embed = pos_embed + coord_embed + area_embed + time_embed
219
-
220
- # add pos embed w/o cls token
221
- x = x + pos_embed[:, 1:, :]
222
-
223
- # append cls token
224
- cls_token = self.cls_token + pos_embed[:, :1, :]
225
- cls_tokens = cls_token.expand(x.shape[0], -1, -1)
226
- x = torch.cat((cls_tokens, x), dim=1)
227
-
228
- intermediate_features = []
229
- hw = num_patches_sqrt
230
- hw_shape = (hw, hw)
231
-
232
- # apply Transformer blocks
233
- for i, block in enumerate(self.blocks):
234
- x = block(x)
235
- if self.return_intermediate and (i in self.intermediate_indices):
236
- out = x[:, 1:]
237
- B, _, C = out.shape
238
- out = (
239
- out.reshape(B, hw_shape[0], hw_shape[1], C)
240
- .permute(0, 3, 1, 2)
241
- .contiguous()
242
- )
243
- intermediate_features.append(out)
244
-
245
- # if self.global_pool:
246
- # x = x[:, 1:, :].mean(dim=1) # global pool without cls token
247
- # outcome = self.fc_norm(x)
248
- # else:
249
- # x = self.norm(x)
250
- # outcome = x[:, 0]
251
-
252
- # if self.return_intermediate:
253
- # return outcome, intermediate_features
254
-
255
- # for segmentation tasks, ignore the norm
256
- # https://github.com/zhu-xlab/Copernicus-FM/blob/main/Copernicus-Bench/src/foundation_models/CopernicusFM/models_dwv_seg.py
257
- # for classification, we will apply the fc_norm in the wrapper
258
- return x[:, 1:, :]
259
-
260
- def forward_head(self, x, pre_logits=False):
261
- x = self.head_drop(x)
262
- return x if pre_logits else self.head(x)
263
-
264
- def forward(
265
- self,
266
- x,
267
- meta_info,
268
- wave_list,
269
- bandwidth,
270
- language_embed,
271
- input_mode,
272
- kernel_size=None,
273
- ):
274
- if self.return_intermediate:
275
- x, intermediate_features = self.forward_features(
276
- x,
277
- meta_info,
278
- wave_list,
279
- bandwidth,
280
- language_embed,
281
- input_mode,
282
- kernel_size,
283
- )
284
- return x, intermediate_features
285
- else:
286
- fx = self.forward_features(
287
- x,
288
- meta_info,
289
- wave_list,
290
- bandwidth,
291
- language_embed,
292
- input_mode,
293
- kernel_size,
294
- )
295
- x = self.forward_head(fx)
296
- return x, fx
297
-
298
-
299
- def vit_small_patch16(**kwargs):
300
- model = CopernicusFMViT(
301
- patch_size=16,
302
- embed_dim=384,
303
- depth=12,
304
- num_heads=6,
305
- mlp_ratio=4,
306
- norm_layer=partial(nn.LayerNorm, eps=1e-6),
307
- **kwargs,
308
- )
309
- return model
310
-
311
-
312
- def vit_base_patch16(**kwargs):
313
- model = CopernicusFMViT(
314
- patch_size=16,
315
- embed_dim=768,
316
- depth=12,
317
- num_heads=12,
318
- mlp_ratio=4,
319
- norm_layer=partial(nn.LayerNorm, eps=1e-6),
320
- **kwargs,
321
- )
322
- return model
323
-
324
-
325
- def vit_large_patch16(**kwargs):
326
- model = CopernicusFMViT(
327
- patch_size=16,
328
- embed_dim=1024,
329
- depth=24,
330
- num_heads=16,
331
- mlp_ratio=4,
332
- norm_layer=partial(nn.LayerNorm, eps=1e-6),
333
- **kwargs,
334
- )
335
- return model
336
-
337
-
338
- def vit_huge_patch14(**kwargs):
339
- model = CopernicusFMViT(
340
- patch_size=14,
341
- embed_dim=1280,
342
- depth=32,
343
- num_heads=16,
344
- mlp_ratio=4,
345
- norm_layer=partial(nn.LayerNorm, eps=1e-6),
346
- **kwargs,
347
- )
348
- return model